1 #include <ATen/detail/MTIAHooksInterface.h> 2 #include <c10/core/Device.h> 3 #include <c10/core/Stream.h> 4 #include <c10/core/impl/DeviceGuardImplInterface.h> 5 #include <c10/util/Logging.h> 6 #include <torch/csrc/utils/device_lazy_init.h> 7 #include <thread> 8 namespace torch::mtia { 9 10 constexpr c10::DeviceType kMTIADeviceType = c10::DeviceType::MTIA; 11 constexpr c10::DeviceIndex kMTIADeviceCount = 2; 12 static thread_local c10::DeviceIndex current_device = 0; 13 static thread_local std::array<c10::Stream, kMTIADeviceCount> current_streams = 14 {c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), 15 c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; 16 static int64_t stream_id_gen = 1; 17 static int64_t event_id_gen = 1; 18 static std::array<c10::Stream, kMTIADeviceCount> default_streams = { 19 c10::Stream::unpack3(0, 0, c10::DeviceType::MTIA), 20 c10::Stream::unpack3(0, 1, c10::DeviceType::MTIA)}; 21 struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { 22 MTIAGuardImpl() = default; MTIAGuardImpltorch::mtia::MTIAGuardImpl23 explicit MTIAGuardImpl(c10::DeviceType t) { 24 TORCH_INTERNAL_ASSERT(t == kMTIADeviceType); 25 } typetorch::mtia::MTIAGuardImpl26 c10::DeviceType type() const override { 27 return kMTIADeviceType; 28 } exchangeDevicetorch::mtia::MTIAGuardImpl29 c10::Device exchangeDevice(c10::Device d) const override { 30 c10::Device old_device = getDevice(); 31 if (old_device.index() != d.index()) { 32 setDevice(d); 33 } 34 return old_device; 35 } getDevicetorch::mtia::MTIAGuardImpl36 c10::Device getDevice() const override { 37 return c10::Device(kMTIADeviceType, current_device); 38 } 39 setDevicetorch::mtia::MTIAGuardImpl40 void setDevice(c10::Device d) const override { 41 c10::Device current_device = getDevice(); 42 if (current_device.index() != d.index()) { 43 current_device = d; 44 } 45 } uncheckedSetDevicetorch::mtia::MTIAGuardImpl46 void uncheckedSetDevice(c10::Device d) const noexcept override { 47 (void)d; 48 } getStreamtorch::mtia::MTIAGuardImpl49 c10::Stream getStream(c10::Device d) const noexcept override { 50 return current_streams[d.index()]; 51 } getNewStreamtorch::mtia::MTIAGuardImpl52 c10::Stream getNewStream(c10::Device d, int priority = 0) const override { 53 (void)priority; 54 return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); 55 } getDefaultStreamtorch::mtia::MTIAGuardImpl56 c10::Stream getDefaultStream(c10::Device d) const override { 57 return default_streams[d.index()]; 58 } getStreamFromGlobalPooltorch::mtia::MTIAGuardImpl59 c10::Stream getStreamFromGlobalPool( 60 c10::Device d, 61 bool isHighPriority = false) const override { 62 return c10::Stream::unpack3(stream_id_gen++, d.index(), d.type()); 63 } 64 // NB: These do NOT set the current device exchangeStreamtorch::mtia::MTIAGuardImpl65 c10::Stream exchangeStream(c10::Stream s) const noexcept override { 66 c10::Stream old_stream = getStream(s.device()); 67 return old_stream; 68 } deviceCounttorch::mtia::MTIAGuardImpl69 c10::DeviceIndex deviceCount() const noexcept override { 70 return kMTIADeviceCount; 71 } 72 destroyEventtorch::mtia::MTIAGuardImpl73 void destroyEvent(void* event, const c10::DeviceIndex device_index) 74 const noexcept override { 75 (void)device_index; 76 } 77 recordtorch::mtia::MTIAGuardImpl78 void record( 79 void** event, 80 const c10::Stream& stream, 81 const c10::DeviceIndex device_index, 82 const c10::EventFlag flag) const override { 83 TORCH_CHECK( 84 device_index == -1 || device_index == stream.device_index(), 85 "Event device index ", 86 device_index, 87 " does not match recording stream's device index ", 88 stream.device_index(), 89 "."); 90 91 const auto orig_device = getDevice(); 92 93 setDevice(stream.device()); 94 95 if (*event == nullptr) { 96 *event = reinterpret_cast<void*>(event_id_gen++); 97 } 98 setDevice(orig_device); 99 } 100 blocktorch::mtia::MTIAGuardImpl101 void block(void* event, const c10::Stream& stream) const override { 102 (void)event; 103 (void)stream; 104 } 105 106 // May be called from any device queryEventtorch::mtia::MTIAGuardImpl107 bool queryEvent(void* event) const override { 108 (void)event; 109 return true; 110 } 111 112 // Stream-related functions queryStreamtorch::mtia::MTIAGuardImpl113 bool queryStream(const c10::Stream& stream) const override { 114 (void)stream; 115 return true; 116 } 117 synchronizeStreamtorch::mtia::MTIAGuardImpl118 void synchronizeStream(const c10::Stream& stream) const override { 119 (void)stream; 120 } 121 recordDataPtrOnStreamtorch::mtia::MTIAGuardImpl122 void recordDataPtrOnStream( 123 const c10::DataPtr& data_ptr, 124 const c10::Stream& stream) const override { 125 (void)data_ptr; 126 (void)stream; 127 } 128 elapsedTimetorch::mtia::MTIAGuardImpl129 double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index) const override { 130 (void)device_index; 131 uint64_t elapsed_time = 1e6; 132 return (double)(elapsed_time / 1e6); 133 } 134 synchronizeEventtorch::mtia::MTIAGuardImpl135 void synchronizeEvent(void* event) const override { 136 (void)event; 137 } 138 }; 139 140 struct MTIAHooks : public at::MTIAHooksInterface { MTIAHookstorch::mtia::MTIAHooks141 explicit MTIAHooks(at::MTIAHooksArgs) {} initMTIAtorch::mtia::MTIAHooks142 void initMTIA() const override {} 143 hasMTIAtorch::mtia::MTIAHooks144 bool hasMTIA() const override { 145 return true; 146 } 147 deviceCounttorch::mtia::MTIAHooks148 c10::DeviceIndex deviceCount() const override { 149 torch::utils::device_lazy_init(at::kMTIA); 150 return c10::DeviceIndex(2); 151 } 152 deviceSynchronizetorch::mtia::MTIAHooks153 void deviceSynchronize(c10::DeviceIndex device_index) const override { 154 torch::utils::device_lazy_init(at::kMTIA); 155 (void)device_index; 156 } 157 showConfigtorch::mtia::MTIAHooks158 std::string showConfig() const override { 159 return "None config"; 160 } 161 exchangeDevicetorch::mtia::MTIAHooks162 c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override { 163 torch::utils::device_lazy_init(at::kMTIA); 164 auto orig_device = current_device; 165 if (current_device != device) { 166 current_device = device; 167 } 168 return orig_device; 169 } 170 maybeExchangeDevicetorch::mtia::MTIAHooks171 c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override { 172 torch::utils::device_lazy_init(at::kMTIA); 173 174 auto orig_device = current_device; 175 if (current_device != device) { 176 current_device = device; 177 } 178 return orig_device; 179 } 180 getDefaultStreamtorch::mtia::MTIAHooks181 c10::Stream getDefaultStream(c10::DeviceIndex device) const override { 182 torch::utils::device_lazy_init(at::kMTIA); 183 184 return default_streams[device]; 185 } 186 getCurrentStreamtorch::mtia::MTIAHooks187 c10::Stream getCurrentStream(c10::DeviceIndex device) const override { 188 torch::utils::device_lazy_init(at::kMTIA); 189 190 return current_streams[device]; 191 } 192 setCurrentStreamtorch::mtia::MTIAHooks193 void setCurrentStream(const c10::Stream& stream) const override { 194 torch::utils::device_lazy_init(at::kMTIA); 195 196 current_streams[stream.device_index()] = stream; 197 } 198 getCurrentDevicetorch::mtia::MTIAHooks199 c10::DeviceIndex getCurrentDevice() const override { 200 torch::utils::device_lazy_init(at::kMTIA); 201 202 return current_device; 203 } 204 setCurrentDevicetorch::mtia::MTIAHooks205 void setCurrentDevice(c10::DeviceIndex device) const override { 206 torch::utils::device_lazy_init(at::kMTIA); 207 208 if (current_device != device) { 209 current_device = device; 210 } 211 } 212 }; 213 214 using at::MTIAHooksRegistry; 215 using at::RegistererMTIAHooksRegistry; 216 217 REGISTER_MTIA_HOOKS(MTIAHooks); 218 C10_REGISTER_GUARD_IMPL(MTIA, MTIAGuardImpl); 219 220 } // namespace torch::mtia 221