1 // Copyright © 2022 Apple Inc. 2 3 #pragma once 4 5 #include <ATen/core/Generator.h> 6 #include <ATen/core/PhiloxRNGEngine.h> 7 #include <c10/core/GeneratorImpl.h> 8 #include <optional> 9 10 namespace at { 11 namespace mps::detail { 12 13 constexpr uint32_t PHILOX_STATE_N = 7; 14 struct rng_data_pod { 15 std::array<uint32_t, PHILOX_STATE_N> state{1}; 16 uint64_t seed = default_rng_seed_val; 17 }; 18 19 TORCH_API const Generator& getDefaultMPSGenerator(); 20 TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val); 21 22 } // namespace mps::detail 23 24 struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl { 25 // Constructors 26 MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val); 27 ~MPSGeneratorImpl() override = default; 28 29 // MPSGeneratorImpl methods 30 std::shared_ptr<MPSGeneratorImpl> clone() const; 31 void set_current_seed(uint64_t seed) override; 32 void set_offset(uint64_t offset) override; 33 uint64_t get_offset() const override; 34 uint64_t current_seed() const override; 35 uint64_t seed() override; 36 void set_state(const c10::TensorImpl& new_state) override; 37 c10::intrusive_ptr<c10::TensorImpl> get_state() const override; 38 void update_philox_counters(); 39 set_engineMPSGeneratorImpl40 void set_engine(at::Philox4_32 engine) { engine_ = engine; }; engineMPSGeneratorImpl41 at::Philox4_32 engine() { return engine_; }; state_dataMPSGeneratorImpl42 uint32_t* state_data() { return data_.state.data(); } device_typeMPSGeneratorImpl43 static DeviceType device_type() { return DeviceType::MPS; }; 44 45 private: 46 mps::detail::rng_data_pod data_; 47 at::Philox4_32 engine_; 48 49 MPSGeneratorImpl* clone_impl() const override; 50 }; 51 52 } // namespace at 53