xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSGeneratorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #pragma once
4 
#include <ATen/core/Generator.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/core/GeneratorImpl.h>

#include <array>
#include <cstdint>
#include <memory>
#include <optional>
9 
10 namespace at {
11 namespace mps::detail {
12 
13 constexpr uint32_t PHILOX_STATE_N = 7;
14 struct rng_data_pod {
15   std::array<uint32_t, PHILOX_STATE_N> state{1};
16   uint64_t seed = default_rng_seed_val;
17 };
18 
19 TORCH_API const Generator& getDefaultMPSGenerator();
20 TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
21 
22 } // namespace mps::detail
23 
24 struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
25   // Constructors
26   MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
27   ~MPSGeneratorImpl() override = default;
28 
29   // MPSGeneratorImpl methods
30   std::shared_ptr<MPSGeneratorImpl> clone() const;
31   void set_current_seed(uint64_t seed) override;
32   void set_offset(uint64_t offset) override;
33   uint64_t get_offset() const override;
34   uint64_t current_seed() const override;
35   uint64_t seed() override;
36   void set_state(const c10::TensorImpl& new_state) override;
37   c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
38   void update_philox_counters();
39 
set_engineMPSGeneratorImpl40   void set_engine(at::Philox4_32 engine) { engine_ = engine; };
engineMPSGeneratorImpl41   at::Philox4_32 engine() { return engine_; };
state_dataMPSGeneratorImpl42   uint32_t* state_data() { return data_.state.data(); }
device_typeMPSGeneratorImpl43   static DeviceType device_type() { return DeviceType::MPS; };
44 
45 private:
46   mps::detail::rng_data_pod data_;
47   at::Philox4_32 engine_;
48 
49   MPSGeneratorImpl* clone_impl() const override;
50 };
51 
52 } // namespace at
53