1 #pragma once 2 3 #include <c10/core/impl/DeviceGuardImplInterface.h> 4 5 namespace c10::impl { 6 7 /** 8 * An implementation of DeviceGuardImplInterface which delegates 9 * to virtual dispatch on the DeviceGuardImpl registry. 10 */ 11 class VirtualGuardImpl final : public DeviceGuardImplInterface { 12 public: VirtualGuardImpl(DeviceType device_type)13 VirtualGuardImpl(DeviceType device_type) 14 : impl_(getDeviceGuardImpl(device_type)) {} 15 // This constructor exists purely for testing VirtualGuardImpl(const DeviceGuardImplInterface * impl)16 VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {} 17 18 // Copying and moving is OK! 19 VirtualGuardImpl(const VirtualGuardImpl&) = default; 20 VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default; 21 VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default; 22 VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default; 23 type()24 DeviceType type() const override { 25 return impl_->type(); 26 } exchangeDevice(Device d)27 Device exchangeDevice(Device d) const override { 28 return impl_->exchangeDevice(d); 29 } getDevice()30 Device getDevice() const override { 31 return impl_->getDevice(); 32 } setDevice(Device d)33 void setDevice(Device d) const override { 34 impl_->setDevice(d); 35 } uncheckedSetDevice(Device d)36 void uncheckedSetDevice(Device d) const noexcept override { 37 impl_->uncheckedSetDevice(d); 38 } getStream(Device d)39 Stream getStream(Device d) const noexcept override { 40 return impl_->getStream(d); 41 } 42 Stream getNewStream(Device d, int priority = 0) const override { 43 return impl_->getNewStream(d, priority); 44 } getDefaultStream(Device d)45 Stream getDefaultStream(Device d) const override { 46 return impl_->getDefaultStream(d); 47 } 48 Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) 49 const override { 50 return impl_->getStreamFromGlobalPool(d, isHighPriority); 51 } exchangeStream(Stream s)52 Stream exchangeStream(Stream s) const noexcept override { 53 return impl_->exchangeStream(s); 54 } deviceCount()55 DeviceIndex deviceCount() const noexcept override { 56 return impl_->deviceCount(); 57 } 58 59 // Event functions record(void ** event,const Stream & stream,const DeviceIndex device_index,const EventFlag flag)60 void record( 61 void** event, 62 const Stream& stream, 63 const DeviceIndex device_index, 64 const EventFlag flag) const override { 65 impl_->record(event, stream, device_index, flag); 66 } block(void * event,const Stream & stream)67 void block(void* event, const Stream& stream) const override { 68 impl_->block(event, stream); 69 } queryEvent(void * event)70 bool queryEvent(void* event) const override { 71 return impl_->queryEvent(event); 72 } destroyEvent(void * event,const DeviceIndex device_index)73 void destroyEvent(void* event, const DeviceIndex device_index) 74 const noexcept override { 75 impl_->destroyEvent(event, device_index); 76 } 77 queryStream(const Stream & stream)78 bool queryStream(const Stream& stream) const override { 79 return impl_->queryStream(stream); 80 } synchronizeStream(const Stream & stream)81 void synchronizeStream(const Stream& stream) const override { 82 impl_->synchronizeStream(stream); 83 } 84 recordDataPtrOnStream(const c10::DataPtr & data_ptr,const Stream & stream)85 void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) 86 const override { 87 impl_->recordDataPtrOnStream(data_ptr, stream); 88 } 89 elapsedTime(void * event1,void * event2,const DeviceIndex device_index)90 double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) 91 const override { 92 return impl_->elapsedTime(event1, event2, device_index); 93 } 94 synchronizeEvent(void * event)95 void synchronizeEvent(void* event) const override { 96 return impl_->synchronizeEvent(event); 97 } 98 99 private: 100 const DeviceGuardImplInterface* impl_ = nullptr; 101 }; 102 103 } // namespace c10::impl 104