1 #include <c10/core/impl/DeviceGuardImplInterface.h> 2 #include <c10/macros/Macros.h> 3 4 namespace at { 5 namespace detail { 6 7 struct MetalGuardImpl final : public c10::impl::DeviceGuardImplInterface { 8 MetalGuardImpl() = default; 9 MetalGuardImplat::detail::MetalGuardImpl10 explicit MetalGuardImpl(DeviceType t) { 11 TORCH_INTERNAL_ASSERT(t == DeviceType::Metal); 12 } 13 typeat::detail::MetalGuardImpl14 DeviceType type() const override { 15 return DeviceType::Metal; 16 } exchangeDeviceat::detail::MetalGuardImpl17 Device exchangeDevice(Device) const override { 18 // no-op 19 return Device(DeviceType::Metal, -1); 20 } getDeviceat::detail::MetalGuardImpl21 Device getDevice() const override { 22 return Device(DeviceType::Metal, -1); 23 } setDeviceat::detail::MetalGuardImpl24 void setDevice(Device) const override { 25 // no-op 26 } uncheckedSetDeviceat::detail::MetalGuardImpl27 void uncheckedSetDevice(Device d) const noexcept override { 28 // no-op 29 } getStreamat::detail::MetalGuardImpl30 Stream getStream(Device d) const noexcept override { 31 // no-op 32 return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1)); 33 } 34 // NB: These do NOT set the current device exchangeStreamat::detail::MetalGuardImpl35 Stream exchangeStream(Stream s) const noexcept override { 36 // no-op 37 return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1)); 38 } deviceCountat::detail::MetalGuardImpl39 DeviceIndex deviceCount() const noexcept override { 40 return 1; 41 } 42 43 // Event-related functions recordat::detail::MetalGuardImpl44 void record( 45 void** event, 46 const Stream& stream, 47 const DeviceIndex device_index, 48 const EventFlag flag) const override { 49 TORCH_CHECK(false, "Metal backend doesn't support events."); 50 } blockat::detail::MetalGuardImpl51 void block(void* event, const Stream& stream) const override { 52 TORCH_CHECK(false, "Metal backend doesn't support events.") 53 } queryEventat::detail::MetalGuardImpl54 bool queryEvent(void* event) const override { 55 TORCH_CHECK(false, "Metal backend doesn't support events.") 56 } destroyEventat::detail::MetalGuardImpl57 void destroyEvent(void* event, const DeviceIndex device_index) const 58 noexcept override {} 59 }; 60 61 C10_REGISTER_GUARD_IMPL(Metal, MetalGuardImpl); 62 63 } // namespace detail 64 } // namespace at 65