xref: /aosp_15_r20/external/pytorch/aten/src/ATen/detail/MTIAHooksInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/util/Exception.h>
5 
6 #include <c10/core/Stream.h>
7 #include <c10/util/Registry.h>
8 
9 #include <c10/core/Allocator.h>
10 
11 #include <c10/util/python_stub.h>
12 #include <ATen/detail/AcceleratorHooksInterface.h>
13 
14 #include <string>
15 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
16 namespace at {
17 class Context;
18 }
19 
20 namespace at {
// Help text telling the user that MTIA functionality needs the MTIA extension
// for PyTorch. The adjacent string literals below are concatenated by the
// compiler with nothing in between, so each piece must carry its own
// separating whitespace (previously the message read "PyTorch;this error").
constexpr const char* MTIA_HELP =
    "The MTIA backend requires MTIA extension for PyTorch; "
    "this error has occurred because you are trying "
    "to use some MTIA's functionality without MTIA extension included.";
25 
26 struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
27 // this fails the implementation if MTIAHooks functions are called, but
28 // MTIA backend is not present.
29 #define FAIL_MTIAHOOKS_FUNC(func) \
30   TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
31 
32   ~MTIAHooksInterface() override = default;
33 
initMTIAMTIAHooksInterface34   virtual void initMTIA() const {
35     // Avoid logging here, since MTIA needs init devices first then it will know
36     // how many devices are available. Make it as no-op if mtia extension is not
37     // dynamically loaded.
38     return;
39   }
40 
hasMTIAMTIAHooksInterface41   virtual bool hasMTIA() const {
42     return false;
43   }
44 
deviceCountMTIAHooksInterface45   DeviceIndex deviceCount() const override {
46     return 0;
47   }
48 
deviceSynchronizeMTIAHooksInterface49   virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
50     FAIL_MTIAHOOKS_FUNC(__func__);
51   }
52 
showConfigMTIAHooksInterface53   virtual std::string showConfig() const {
54     FAIL_MTIAHOOKS_FUNC(__func__);
55   }
56 
hasPrimaryContextMTIAHooksInterface57   bool hasPrimaryContext(DeviceIndex device_index) const override {
58     return false;
59   }
60 
setCurrentDeviceMTIAHooksInterface61   void setCurrentDevice(DeviceIndex device) const override {
62     FAIL_MTIAHOOKS_FUNC(__func__);
63   }
64 
getCurrentDeviceMTIAHooksInterface65   DeviceIndex getCurrentDevice() const override {
66     FAIL_MTIAHOOKS_FUNC(__func__);
67     return -1;
68   }
69 
exchangeDeviceMTIAHooksInterface70   DeviceIndex exchangeDevice(DeviceIndex device) const override {
71     FAIL_MTIAHOOKS_FUNC(__func__);
72     return -1;
73   }
74 
maybeExchangeDeviceMTIAHooksInterface75   DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
76     FAIL_MTIAHOOKS_FUNC(__func__);
77     return -1;
78   }
79 
getCurrentStreamMTIAHooksInterface80   virtual c10::Stream getCurrentStream(DeviceIndex device) const {
81     FAIL_MTIAHOOKS_FUNC(__func__);
82     return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
83   }
84 
getDefaultStreamMTIAHooksInterface85   virtual c10::Stream getDefaultStream(DeviceIndex device) const {
86     FAIL_MTIAHOOKS_FUNC(__func__);
87     return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
88   }
89 
setCurrentStreamMTIAHooksInterface90   virtual void setCurrentStream(const c10::Stream& stream) const {
91     FAIL_MTIAHOOKS_FUNC(__func__);
92   }
93 
isPinnedPtrMTIAHooksInterface94   bool isPinnedPtr(const void* data) const override {
95     return false;
96   }
97 
getPinnedMemoryAllocatorMTIAHooksInterface98   Allocator* getPinnedMemoryAllocator() const override {
99     FAIL_MTIAHOOKS_FUNC(__func__);
100     return nullptr;
101   }
102 
memoryStatsMTIAHooksInterface103   virtual PyObject* memoryStats(DeviceIndex device) const {
104     FAIL_MTIAHOOKS_FUNC(__func__);
105     return nullptr;
106   }
107 };
108 
109 struct TORCH_API MTIAHooksArgs {};
110 
111 C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
112 #define REGISTER_MTIA_HOOKS(clsname) \
113   C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
114 
115 namespace detail {
116 TORCH_API const MTIAHooksInterface& getMTIAHooks();
117 TORCH_API bool isMTIAHooksBuilt();
118 } // namespace detail
119 } // namespace at
120 C10_DIAGNOSTIC_POP()
121