xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSGeneratorImpl.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2022 Apple Inc.
2
3#include <ATen/Utils.h>
4#include <ATen/mps/MPSGeneratorImpl.h>
5#include <algorithm>
6
7namespace at {
8namespace mps::detail {
9
10const Generator& getDefaultMPSGenerator() {
11  static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom());
12  return default_gen_mps;
13}
14
15Generator createMPSGenerator(uint64_t seed_val) {
16  auto gen = make_generator<MPSGeneratorImpl>(seed_val);
17  gen.set_current_seed(seed_val);
18  return gen;
19}
20
21} // namespace mps::detail
22
23MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
24    : c10::GeneratorImpl{Device(DeviceType::MPS, 0), DispatchKeySet(c10::DispatchKey::MPS)},
25      data_({.seed = seed_in}),
26      engine_(seed_in, 0, 0) {}
27
28void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
29  data_.seed = seed;
30  data_.state.fill(1);
31  // the two last state values are the Philox keys
32  // TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here
33  data_.state[5] = static_cast<uint32_t>(seed);
34  data_.state[6] = static_cast<uint32_t>(seed >> 32);
35  engine_.reset_state(seed);
36}
37
38void MPSGeneratorImpl::set_offset(uint64_t offset) {
39  engine_.set_offset(offset);
40}
41
42uint64_t MPSGeneratorImpl::get_offset() const {
43  return engine_.get_offset();
44}
45
46uint64_t MPSGeneratorImpl::current_seed() const {
47  return data_.seed;
48}
49
50uint64_t MPSGeneratorImpl::seed() {
51  auto random = c10::detail::getNonDeterministicRandom();
52  this->set_current_seed(random);
53  return random;
54}
55
56// See Note [Acquire lock when using random generators]
57void MPSGeneratorImpl::update_philox_counters() {
58  // calling engine_() would call operator() of philox_engine class to
59  // get each of the four newly generated counter values (see PhiloxRNGEngine.h).
60  for (int i = 1; i <= 4; i++) {
61    data_.state[i] = engine_();
62  }
63}
64
65c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const {
66  constexpr size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t);
67  constexpr size_t seed_size = sizeof(uint64_t);
68  constexpr size_t offset_size = sizeof(uint64_t);
69  constexpr size_t total_size = states_size + seed_size + offset_size;
70
71  auto state_tensor = at::detail::empty_cpu(
72      {(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
73  auto rng_state = state_tensor.data_ptr<uint8_t>();
74  auto current_seed = this->current_seed();
75  auto current_offset = this->get_offset();
76
77  static_assert(sizeof(decltype(current_seed)) == seed_size, "current_seed size is wrong");
78  static_assert(sizeof(decltype(current_offset)) == offset_size, "current_offset size is wrong");
79
80  memcpy(rng_state, this->data_.state.data(), states_size);
81  memcpy(rng_state + states_size, &current_seed, seed_size);
82  memcpy(rng_state + states_size + seed_size, &current_offset, offset_size);
83
84  return state_tensor.getIntrusivePtr();
85}
86
87void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
88  constexpr size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t);
89  constexpr size_t seed_size = sizeof(uint64_t);
90  constexpr size_t offset_size = sizeof(uint64_t);
91  constexpr size_t total_size = states_size + seed_size + offset_size;
92
93  detail::check_rng_state(new_state);
94
95  auto new_state_size = new_state.numel();
96  TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
97
98  uint64_t input_seed = default_rng_seed_val;
99  uint64_t input_offset = 0;
100  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
101  memcpy(&input_seed, new_rng_state + states_size, seed_size);
102  this->set_current_seed(input_seed);
103  memcpy(&input_offset, new_rng_state + states_size + seed_size, offset_size);
104  this->set_offset(input_offset);
105  // state.data must be copied after input_seed to not reset the state in set_current_seed()
106  memcpy(this->state_data(), new_rng_state, states_size);
107}
108
109std::shared_ptr<MPSGeneratorImpl> MPSGeneratorImpl::clone() const {
110  return std::shared_ptr<MPSGeneratorImpl>(this->clone_impl());
111}
112
113MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const {
114  auto gen = new MPSGeneratorImpl();
115  gen->set_current_seed(this->data_.seed);
116  return gen;
117}
118
119} // namespace at
120