xref: /aosp_15_r20/external/pytorch/c10/core/impl/VirtualGuardImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/DeviceGuardImplInterface.h>
4 
5 namespace c10::impl {
6 
7 /**
8  * An implementation of DeviceGuardImplInterface which delegates
9  * to virtual dispatch on the DeviceGuardImpl registry.
10  */
11 class VirtualGuardImpl final : public DeviceGuardImplInterface {
12  public:
VirtualGuardImpl(DeviceType device_type)13   VirtualGuardImpl(DeviceType device_type)
14       : impl_(getDeviceGuardImpl(device_type)) {}
15   // This constructor exists purely for testing
VirtualGuardImpl(const DeviceGuardImplInterface * impl)16   VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}
17 
18   // Copying and moving is OK!
19   VirtualGuardImpl(const VirtualGuardImpl&) = default;
20   VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default;
21   VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default;
22   VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default;
23 
type()24   DeviceType type() const override {
25     return impl_->type();
26   }
exchangeDevice(Device d)27   Device exchangeDevice(Device d) const override {
28     return impl_->exchangeDevice(d);
29   }
getDevice()30   Device getDevice() const override {
31     return impl_->getDevice();
32   }
setDevice(Device d)33   void setDevice(Device d) const override {
34     impl_->setDevice(d);
35   }
uncheckedSetDevice(Device d)36   void uncheckedSetDevice(Device d) const noexcept override {
37     impl_->uncheckedSetDevice(d);
38   }
getStream(Device d)39   Stream getStream(Device d) const noexcept override {
40     return impl_->getStream(d);
41   }
42   Stream getNewStream(Device d, int priority = 0) const override {
43     return impl_->getNewStream(d, priority);
44   }
getDefaultStream(Device d)45   Stream getDefaultStream(Device d) const override {
46     return impl_->getDefaultStream(d);
47   }
48   Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
49       const override {
50     return impl_->getStreamFromGlobalPool(d, isHighPriority);
51   }
exchangeStream(Stream s)52   Stream exchangeStream(Stream s) const noexcept override {
53     return impl_->exchangeStream(s);
54   }
deviceCount()55   DeviceIndex deviceCount() const noexcept override {
56     return impl_->deviceCount();
57   }
58 
59   // Event functions
record(void ** event,const Stream & stream,const DeviceIndex device_index,const EventFlag flag)60   void record(
61       void** event,
62       const Stream& stream,
63       const DeviceIndex device_index,
64       const EventFlag flag) const override {
65     impl_->record(event, stream, device_index, flag);
66   }
block(void * event,const Stream & stream)67   void block(void* event, const Stream& stream) const override {
68     impl_->block(event, stream);
69   }
queryEvent(void * event)70   bool queryEvent(void* event) const override {
71     return impl_->queryEvent(event);
72   }
destroyEvent(void * event,const DeviceIndex device_index)73   void destroyEvent(void* event, const DeviceIndex device_index)
74       const noexcept override {
75     impl_->destroyEvent(event, device_index);
76   }
77 
queryStream(const Stream & stream)78   bool queryStream(const Stream& stream) const override {
79     return impl_->queryStream(stream);
80   }
synchronizeStream(const Stream & stream)81   void synchronizeStream(const Stream& stream) const override {
82     impl_->synchronizeStream(stream);
83   }
84 
recordDataPtrOnStream(const c10::DataPtr & data_ptr,const Stream & stream)85   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
86       const override {
87     impl_->recordDataPtrOnStream(data_ptr, stream);
88   }
89 
elapsedTime(void * event1,void * event2,const DeviceIndex device_index)90   double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
91       const override {
92     return impl_->elapsedTime(event1, event2, device_index);
93   }
94 
synchronizeEvent(void * event)95   void synchronizeEvent(void* event) const override {
96     return impl_->synchronizeEvent(event);
97   }
98 
99  private:
100   const DeviceGuardImplInterface* impl_ = nullptr;
101 };
102 
103 } // namespace c10::impl
104