1// Copyright © 2022 Apple Inc. 2 3#include <ATen/Utils.h> 4#include <ATen/mps/MPSGeneratorImpl.h> 5#include <algorithm> 6 7namespace at { 8namespace mps::detail { 9 10const Generator& getDefaultMPSGenerator() { 11 static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom()); 12 return default_gen_mps; 13} 14 15Generator createMPSGenerator(uint64_t seed_val) { 16 auto gen = make_generator<MPSGeneratorImpl>(seed_val); 17 gen.set_current_seed(seed_val); 18 return gen; 19} 20 21} // namespace mps::detail 22 23MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) 24 : c10::GeneratorImpl{Device(DeviceType::MPS, 0), DispatchKeySet(c10::DispatchKey::MPS)}, 25 data_({.seed = seed_in}), 26 engine_(seed_in, 0, 0) {} 27 28void MPSGeneratorImpl::set_current_seed(uint64_t seed) { 29 data_.seed = seed; 30 data_.state.fill(1); 31 // the two last state values are the Philox keys 32 // TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here 33 data_.state[5] = static_cast<uint32_t>(seed); 34 data_.state[6] = static_cast<uint32_t>(seed >> 32); 35 engine_.reset_state(seed); 36} 37 38void MPSGeneratorImpl::set_offset(uint64_t offset) { 39 engine_.set_offset(offset); 40} 41 42uint64_t MPSGeneratorImpl::get_offset() const { 43 return engine_.get_offset(); 44} 45 46uint64_t MPSGeneratorImpl::current_seed() const { 47 return data_.seed; 48} 49 50uint64_t MPSGeneratorImpl::seed() { 51 auto random = c10::detail::getNonDeterministicRandom(); 52 this->set_current_seed(random); 53 return random; 54} 55 56// See Note [Acquire lock when using random generators] 57void MPSGeneratorImpl::update_philox_counters() { 58 // calling engine_() would call operator() of philox_engine class to 59 // get each of the four newly generated counter values (see PhiloxRNGEngine.h). 60 for (int i = 1; i <= 4; i++) { 61 data_.state[i] = engine_(); 62 } 63} 64 65c10::intrusive_ptr<c10::TensorImpl> MPSGeneratorImpl::get_state() const { 66 constexpr size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); 67 constexpr size_t seed_size = sizeof(uint64_t); 68 constexpr size_t offset_size = sizeof(uint64_t); 69 constexpr size_t total_size = states_size + seed_size + offset_size; 70 71 auto state_tensor = at::detail::empty_cpu( 72 {(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); 73 auto rng_state = state_tensor.data_ptr<uint8_t>(); 74 auto current_seed = this->current_seed(); 75 auto current_offset = this->get_offset(); 76 77 static_assert(sizeof(decltype(current_seed)) == seed_size, "current_seed size is wrong"); 78 static_assert(sizeof(decltype(current_offset)) == offset_size, "current_offset size is wrong"); 79 80 memcpy(rng_state, this->data_.state.data(), states_size); 81 memcpy(rng_state + states_size, ¤t_seed, seed_size); 82 memcpy(rng_state + states_size + seed_size, ¤t_offset, offset_size); 83 84 return state_tensor.getIntrusivePtr(); 85} 86 87void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { 88 constexpr size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); 89 constexpr size_t seed_size = sizeof(uint64_t); 90 constexpr size_t offset_size = sizeof(uint64_t); 91 constexpr size_t total_size = states_size + seed_size + offset_size; 92 93 detail::check_rng_state(new_state); 94 95 auto new_state_size = new_state.numel(); 96 TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); 97 98 uint64_t input_seed = default_rng_seed_val; 99 uint64_t input_offset = 0; 100 auto new_rng_state = new_state.data_dtype_initialized<uint8_t>(); 101 memcpy(&input_seed, new_rng_state + states_size, seed_size); 102 this->set_current_seed(input_seed); 103 memcpy(&input_offset, new_rng_state + states_size + seed_size, offset_size); 104 this->set_offset(input_offset); 105 // state.data must be copied after input_seed to not reset the state in set_current_seed() 106 memcpy(this->state_data(), new_rng_state, states_size); 107} 108 109std::shared_ptr<MPSGeneratorImpl> MPSGeneratorImpl::clone() const { 110 return std::shared_ptr<MPSGeneratorImpl>(this->clone_impl()); 111} 112 113MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { 114 auto gen = new MPSGeneratorImpl(); 115 gen->set_current_seed(this->data_.seed); 116 return gen; 117} 118 119} // namespace at 120