// Source: /aosp_15_r20/external/pytorch/test/cpp_extensions/mtia_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/detail/MTIAHooksInterface.h>
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Logging.h>
#include <torch/csrc/utils/device_lazy_init.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <string>
#include <thread>
8 namespace torch::mtia {
9 
10 constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA;
11 constexpr c10::DeviceIndex kMTIADeviceCount = 2;
12 static thread_local c10::DeviceIndex current_device = 0;
13 static thread_local std::array<c10::Stream, kMTIADeviceCount> current_streams =
14     {c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
15      c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};
16 static int64_t stream_id_gen = 1;
17 static int64_t event_id_gen = 1;
18 static std::array<c10::Stream, kMTIADeviceCount> default_streams = {
19     c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA),
20     c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)};
21 struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
22   MTIAGuardImpl() = default;
MTIAGuardImpltorch::mtia::MTIAGuardImpl23   explicit MTIAGuardImpl(c10::DeviceType t) {
24     TORCH_INTERNAL_ASSERT(t == kMTIADeviceType);
25   }
typetorch::mtia::MTIAGuardImpl26   c10::DeviceType type() const override {
27     return kMTIADeviceType;
28   }
exchangeDevicetorch::mtia::MTIAGuardImpl29   c10::Device exchangeDevice(c10::Device d) const override {
30     c10::Device old_device = getDevice();
31     if (old_device.index() != d.index()) {
32       setDevice(d);
33     }
34     return old_device;
35   }
getDevicetorch::mtia::MTIAGuardImpl36   c10::Device getDevice() const override {
37     return c10::Device(kMTIADeviceType, current_device);
38   }
39 
setDevicetorch::mtia::MTIAGuardImpl40   void setDevice(c10::Device d) const override {
41     c10::Device current_device = getDevice();
42     if (current_device.index() != d.index()) {
43       current_device = d;
44     }
45   }
uncheckedSetDevicetorch::mtia::MTIAGuardImpl46   void uncheckedSetDevice(c10::Device d) const noexcept override {
47     (void)d;
48   }
getStreamtorch::mtia::MTIAGuardImpl49   c10::Stream getStream(c10::Device d) const noexcept override {
50     return current_streams[d.index()];
51   }
getNewStreamtorch::mtia::MTIAGuardImpl52   c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
53     (void)priority;
54     return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
55   }
getDefaultStreamtorch::mtia::MTIAGuardImpl56   c10::Stream getDefaultStream(c10::Device d) const override {
57     return default_streams[d.index()];
58   }
getStreamFromGlobalPooltorch::mtia::MTIAGuardImpl59   c10::Stream getStreamFromGlobalPool(
60       c10::Device d,
61       bool isHighPriority = false) const override {
62     return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type());
63   }
64   // NB: These do NOT set the current device
exchangeStreamtorch::mtia::MTIAGuardImpl65   c10::Stream exchangeStream(c10::Stream s) const noexcept override {
66     c10::Stream old_stream = getStream(s.device());
67     return old_stream;
68   }
deviceCounttorch::mtia::MTIAGuardImpl69   c10::DeviceIndex deviceCount() const noexcept override {
70     return kMTIADeviceCount;
71   }
72 
destroyEventtorch::mtia::MTIAGuardImpl73   void destroyEvent(void* event, const c10::DeviceIndex device_index)
74       const noexcept override {
75     (void)device_index;
76   }
77 
recordtorch::mtia::MTIAGuardImpl78   void record(
79       void** event,
80       const c10::Stream& stream,
81       const c10::DeviceIndex device_index,
82       const c10::EventFlag flag) const override {
83     TORCH_CHECK(
84         device_index == -1 || device_index == stream.device_index(),
85         "Event device index ",
86         device_index,
87         " does not match recording stream's device index ",
88         stream.device_index(),
89         ".");
90 
91     const auto orig_device = getDevice();
92 
93     setDevice(stream.device());
94 
95     if (*event == nullptr) {
96       *event = reinterpret_cast<void*>(event_id_gen++);
97     }
98     setDevice(orig_device);
99   }
100 
blocktorch::mtia::MTIAGuardImpl101   void block(void* event, const c10::Stream& stream) const override {
102     (void)event;
103     (void)stream;
104   }
105 
106   // May be called from any device
queryEventtorch::mtia::MTIAGuardImpl107   bool queryEvent(void* event) const override {
108     (void)event;
109     return true;
110   }
111 
112   // Stream-related functions
queryStreamtorch::mtia::MTIAGuardImpl113   bool queryStream(const c10::Stream& stream) const override {
114     (void)stream;
115     return true;
116   }
117 
synchronizeStreamtorch::mtia::MTIAGuardImpl118   void synchronizeStream(const c10::Stream& stream) const override {
119     (void)stream;
120   }
121 
recordDataPtrOnStreamtorch::mtia::MTIAGuardImpl122   void recordDataPtrOnStream(
123       const c10::DataPtr& data_ptr,
124       const c10::Stream& stream) const override {
125     (void)data_ptr;
126     (void)stream;
127   }
128 
elapsedTimetorch::mtia::MTIAGuardImpl129   double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index) const override {
130     (void)device_index;
131     uint64_t elapsed_time = 1e6;
132     return (double)(elapsed_time / 1e6);
133   }
134 
synchronizeEventtorch::mtia::MTIAGuardImpl135   void synchronizeEvent(void* event) const override {
136     (void)event;
137   }
138 };
139 
140 struct MTIAHooks : public at::MTIAHooksInterface {
MTIAHookstorch::mtia::MTIAHooks141   explicit MTIAHooks(at::MTIAHooksArgs) {}
initMTIAtorch::mtia::MTIAHooks142   void initMTIA() const override {}
143 
hasMTIAtorch::mtia::MTIAHooks144   bool hasMTIA() const override {
145     return true;
146   }
147 
deviceCounttorch::mtia::MTIAHooks148   c10::DeviceIndex deviceCount() const override {
149     torch::utils::device_lazy_init(at::kMTIA);
150     return c10::DeviceIndex(2);
151   }
152 
deviceSynchronizetorch::mtia::MTIAHooks153   void deviceSynchronize(c10::DeviceIndex device_index) const override {
154     torch::utils::device_lazy_init(at::kMTIA);
155     (void)device_index;
156   }
157 
showConfigtorch::mtia::MTIAHooks158   std::string showConfig() const override {
159     return "None config";
160   }
161 
exchangeDevicetorch::mtia::MTIAHooks162   c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override {
163     torch::utils::device_lazy_init(at::kMTIA);
164     auto orig_device = current_device;
165     if (current_device != device) {
166       current_device = device;
167     }
168     return orig_device;
169   }
170 
maybeExchangeDevicetorch::mtia::MTIAHooks171   c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override {
172     torch::utils::device_lazy_init(at::kMTIA);
173 
174     auto orig_device = current_device;
175     if (current_device != device) {
176       current_device = device;
177     }
178     return orig_device;
179   }
180 
getDefaultStreamtorch::mtia::MTIAHooks181   c10::Stream getDefaultStream(c10::DeviceIndex device) const override {
182     torch::utils::device_lazy_init(at::kMTIA);
183 
184     return default_streams[device];
185   }
186 
getCurrentStreamtorch::mtia::MTIAHooks187   c10::Stream getCurrentStream(c10::DeviceIndex device) const override {
188     torch::utils::device_lazy_init(at::kMTIA);
189 
190     return current_streams[device];
191   }
192 
setCurrentStreamtorch::mtia::MTIAHooks193   void setCurrentStream(const c10::Stream& stream) const override {
194     torch::utils::device_lazy_init(at::kMTIA);
195 
196     current_streams[stream.device_index()] = stream;
197   }
198 
getCurrentDevicetorch::mtia::MTIAHooks199   c10::DeviceIndex getCurrentDevice() const override {
200     torch::utils::device_lazy_init(at::kMTIA);
201 
202     return current_device;
203   }
204 
setCurrentDevicetorch::mtia::MTIAHooks205   void setCurrentDevice(c10::DeviceIndex device) const override {
206     torch::utils::device_lazy_init(at::kMTIA);
207 
208     if (current_device != device) {
209       current_device = device;
210     }
211   }
212 };
213 
214 using at::MTIAHooksRegistry;
215 using at::RegistererMTIAHooksRegistry;
216 
217 REGISTER_MTIA_HOOKS(MTIAHooks);
218 C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl);
219 
220 } // namespace torch::mtia
221