xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/XPUGeneratorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Generator.h>
4 
5 namespace at {
6 
// Random-number generator implementation for the XPU backend.
// Publicly derives from GeneratorImpl and overrides its seed / offset / state
// virtuals; additionally declares a few non-virtual Philox-related helpers.
// Only declarations live here; the definitions are in a separate translation
// unit that is not visible from this header.
7 struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
8   // Constructors
     // NOTE(review): the default device_index of -1 presumably means "use the
     // current device" -- confirm against the definition.
     // The constructor is not `explicit`, so a DeviceIndex converts implicitly.
9   XPUGeneratorImpl(DeviceIndex device_index = -1);
10   ~XPUGeneratorImpl() override = default;
11 
12   // XPUGeneratorImpl methods
     // Non-virtual clone returning the concrete type in a shared_ptr.
     // NOTE(review): presumably wraps the private clone_impl() -- confirm.
13   std::shared_ptr<XPUGeneratorImpl> clone() const;
     // Overrides of the GeneratorImpl seed/offset interface.
14   void set_current_seed(uint64_t seed) override;
15   void set_offset(uint64_t offset) override;
16   uint64_t get_offset() const override;
17   uint64_t current_seed() const override;
18   uint64_t seed() override;
     // Overrides for (de)serializing generator state through a tensor.
     // The tensor layout is defined by the implementation, not by this header.
19   void set_state(const c10::TensorImpl& new_state) override;
20   c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
     // Non-virtual setter/getter pair; presumably backed by
     // philox_offset_per_thread_ below -- TODO confirm in the definition.
21   void set_philox_offset_per_thread(uint64_t offset);
22   uint64_t philox_offset_per_thread() const;
     // Returns a pair of 64-bit values for a Philox engine.
     // NOTE(review): presumably (seed, offset), with the stored offset advanced
     // by `increment` -- this cannot be verified from the declaration alone.
23   std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
     // Static query for the device type this generator belongs to
     // (presumably the XPU device type -- confirm).
24   static c10::DeviceType device_type();
25 
26  private:
     // Covariant override of the base-class cloning hook; returns a raw pointer.
     // NOTE(review): ownership is presumably transferred to the caller -- confirm.
27   XPUGeneratorImpl* clone_impl() const override;
     // Seed value; starts at default_rng_seed_val.
28   uint64_t seed_ = default_rng_seed_val;
     // Philox offset; starts at 0.
29   uint64_t philox_offset_per_thread_ = 0;
30 };
31 
// Free-function entry points for obtaining XPU generators. Both are exported
// via TORCH_XPU_API and are only declared here.
32 namespace xpu::detail {
33 
     // Returns a const reference to a Generator for `device`.
     // NOTE(review): presumably a per-device default instance that outlives the
     // call, with -1 meaning "current device" -- confirm against the definition.
34 TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1);
35 
     // Returns a Generator by value for `device`.
     // NOTE(review): presumably a freshly created, independent generator; the
     // meaning of the default -1 should be confirmed against the definition.
36 TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1);
37 
38 } // namespace xpu::detail
39 } // namespace at
40