xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalGuardImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/DeviceGuardImplInterface.h>
2 #include <c10/macros/Macros.h>
3 
4 namespace at {
5 namespace detail {
6 
7 struct MetalGuardImpl final : public c10::impl::DeviceGuardImplInterface {
8   MetalGuardImpl() = default;
9 
MetalGuardImplat::detail::MetalGuardImpl10   explicit MetalGuardImpl(DeviceType t) {
11     TORCH_INTERNAL_ASSERT(t == DeviceType::Metal);
12   }
13 
typeat::detail::MetalGuardImpl14   DeviceType type() const override {
15     return DeviceType::Metal;
16   }
exchangeDeviceat::detail::MetalGuardImpl17   Device exchangeDevice(Device) const override {
18     // no-op
19     return Device(DeviceType::Metal, -1);
20   }
getDeviceat::detail::MetalGuardImpl21   Device getDevice() const override {
22     return Device(DeviceType::Metal, -1);
23   }
setDeviceat::detail::MetalGuardImpl24   void setDevice(Device) const override {
25     // no-op
26   }
uncheckedSetDeviceat::detail::MetalGuardImpl27   void uncheckedSetDevice(Device d) const noexcept override {
28     // no-op
29   }
getStreamat::detail::MetalGuardImpl30   Stream getStream(Device d) const noexcept override {
31     // no-op
32     return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1));
33   }
34   // NB: These do NOT set the current device
exchangeStreamat::detail::MetalGuardImpl35   Stream exchangeStream(Stream s) const noexcept override {
36     // no-op
37     return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1));
38   }
deviceCountat::detail::MetalGuardImpl39   DeviceIndex deviceCount() const noexcept override {
40     return 1;
41   }
42 
43   // Event-related functions
recordat::detail::MetalGuardImpl44   void record(
45       void** event,
46       const Stream& stream,
47       const DeviceIndex device_index,
48       const EventFlag flag) const override {
49     TORCH_CHECK(false, "Metal backend doesn't support events.");
50   }
blockat::detail::MetalGuardImpl51   void block(void* event, const Stream& stream) const override {
52     TORCH_CHECK(false, "Metal backend doesn't support events.")
53   }
queryEventat::detail::MetalGuardImpl54   bool queryEvent(void* event) const override {
55     TORCH_CHECK(false, "Metal backend doesn't support events.")
56   }
destroyEventat::detail::MetalGuardImpl57   void destroyEvent(void* event, const DeviceIndex device_index) const
58       noexcept override {}
59 };
60 
// Registers MetalGuardImpl for the Metal device type through the c10
// guard-impl registration macro.
// NOTE(review): the macro definition is not visible in this file; confirm the
// exact registration mechanics against c10/core/impl/DeviceGuardImplInterface.h.
61 C10_REGISTER_GUARD_IMPL(Metal, MetalGuardImpl);
62 
63 } // namespace detail
64 } // namespace at
65