1 #pragma once 2 3 #include <c10/core/Device.h> 4 #include <c10/util/Exception.h> 5 6 #include <c10/core/Stream.h> 7 #include <c10/util/Registry.h> 8 9 #include <c10/core/Allocator.h> 10 11 #include <c10/util/python_stub.h> 12 #include <ATen/detail/AcceleratorHooksInterface.h> 13 14 #include <string> 15 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") 16 namespace at { 17 class Context; 18 } 19 20 namespace at { 21 constexpr const char* MTIA_HELP = 22 "The MTIA backend requires MTIA extension for PyTorch;" 23 "this error has occurred because you are trying " 24 "to use some MTIA's functionality without MTIA extension included."; 25 26 struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { 27 // this fails the implementation if MTIAHooks functions are called, but 28 // MTIA backend is not present. 29 #define FAIL_MTIAHOOKS_FUNC(func) \ 30 TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); 31 32 ~MTIAHooksInterface() override = default; 33 initMTIAMTIAHooksInterface34 virtual void initMTIA() const { 35 // Avoid logging here, since MTIA needs init devices first then it will know 36 // how many devices are available. Make it as no-op if mtia extension is not 37 // dynamically loaded. 38 return; 39 } 40 hasMTIAMTIAHooksInterface41 virtual bool hasMTIA() const { 42 return false; 43 } 44 deviceCountMTIAHooksInterface45 DeviceIndex deviceCount() const override { 46 return 0; 47 } 48 deviceSynchronizeMTIAHooksInterface49 virtual void deviceSynchronize(c10::DeviceIndex device_index) const { 50 FAIL_MTIAHOOKS_FUNC(__func__); 51 } 52 showConfigMTIAHooksInterface53 virtual std::string showConfig() const { 54 FAIL_MTIAHOOKS_FUNC(__func__); 55 } 56 hasPrimaryContextMTIAHooksInterface57 bool hasPrimaryContext(DeviceIndex device_index) const override { 58 return false; 59 } 60 setCurrentDeviceMTIAHooksInterface61 void setCurrentDevice(DeviceIndex device) const override { 62 FAIL_MTIAHOOKS_FUNC(__func__); 63 } 64 getCurrentDeviceMTIAHooksInterface65 DeviceIndex getCurrentDevice() const override { 66 FAIL_MTIAHOOKS_FUNC(__func__); 67 return -1; 68 } 69 exchangeDeviceMTIAHooksInterface70 DeviceIndex exchangeDevice(DeviceIndex device) const override { 71 FAIL_MTIAHOOKS_FUNC(__func__); 72 return -1; 73 } 74 maybeExchangeDeviceMTIAHooksInterface75 DeviceIndex maybeExchangeDevice(DeviceIndex device) const override { 76 FAIL_MTIAHOOKS_FUNC(__func__); 77 return -1; 78 } 79 getCurrentStreamMTIAHooksInterface80 virtual c10::Stream getCurrentStream(DeviceIndex device) const { 81 FAIL_MTIAHOOKS_FUNC(__func__); 82 return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); 83 } 84 getDefaultStreamMTIAHooksInterface85 virtual c10::Stream getDefaultStream(DeviceIndex device) const { 86 FAIL_MTIAHOOKS_FUNC(__func__); 87 return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); 88 } 89 setCurrentStreamMTIAHooksInterface90 virtual void setCurrentStream(const c10::Stream& stream) const { 91 FAIL_MTIAHOOKS_FUNC(__func__); 92 } 93 isPinnedPtrMTIAHooksInterface94 bool isPinnedPtr(const void* data) const override { 95 return false; 96 } 97 getPinnedMemoryAllocatorMTIAHooksInterface98 Allocator* getPinnedMemoryAllocator() const override { 99 FAIL_MTIAHOOKS_FUNC(__func__); 100 return nullptr; 101 } 102 memoryStatsMTIAHooksInterface103 virtual PyObject* memoryStats(DeviceIndex device) const { 104 FAIL_MTIAHOOKS_FUNC(__func__); 105 return nullptr; 106 } 107 }; 108 109 struct TORCH_API MTIAHooksArgs {}; 110 111 C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs); 112 #define REGISTER_MTIA_HOOKS(clsname) \ 113 C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname) 114 115 namespace detail { 116 TORCH_API const MTIAHooksInterface& getMTIAHooks(); 117 TORCH_API bool isMTIAHooksBuilt(); 118 } // namespace detail 119 } // namespace at 120 C10_DIAGNOSTIC_POP() 121