1 #pragma once 2 3 #include <ATen/core/Generator.h> 4 5 namespace at { 6 7 struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { 8 // Constructors 9 XPUGeneratorImpl(DeviceIndex device_index = -1); 10 ~XPUGeneratorImpl() override = default; 11 12 // XPUGeneratorImpl methods 13 std::shared_ptr<XPUGeneratorImpl> clone() const; 14 void set_current_seed(uint64_t seed) override; 15 void set_offset(uint64_t offset) override; 16 uint64_t get_offset() const override; 17 uint64_t current_seed() const override; 18 uint64_t seed() override; 19 void set_state(const c10::TensorImpl& new_state) override; 20 c10::intrusive_ptr<c10::TensorImpl> get_state() const override; 21 void set_philox_offset_per_thread(uint64_t offset); 22 uint64_t philox_offset_per_thread() const; 23 std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment); 24 static c10::DeviceType device_type(); 25 26 private: 27 XPUGeneratorImpl* clone_impl() const override; 28 uint64_t seed_ = default_rng_seed_val; 29 uint64_t philox_offset_per_thread_ = 0; 30 }; 31 32 namespace xpu::detail { 33 34 TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1); 35 36 TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1); 37 38 } // namespace xpu::detail 39 } // namespace at 40