xref: /aosp_15_r20/external/pytorch/c10/core/DispatchKey.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/DispatchKey.h>
2 #include <c10/core/DispatchKeySet.h>
3 
4 #include <unordered_map>
5 
6 namespace c10 {
7 
toString(BackendComponent t)8 const char* toString(BackendComponent t) {
9   switch (t) {
10     case BackendComponent::CPUBit:
11       return "CPUBit";
12     case BackendComponent::CUDABit:
13       return "CUDABit";
14     case BackendComponent::HIPBit:
15       return "HIPBit";
16     case BackendComponent::XLABit:
17       return "XLABit";
18     case BackendComponent::LazyBit:
19       return "LazyBit";
20     case BackendComponent::MetaBit:
21       return "MetaBit";
22     case BackendComponent::XPUBit:
23       return "XPUBit";
24     case BackendComponent::IPUBit:
25       return "IPUBit";
26     case BackendComponent::MPSBit:
27       return "MPSBit";
28     case BackendComponent::HPUBit:
29       return "HPUBit";
30     case BackendComponent::VEBit:
31       return "VEBit";
32     case BackendComponent::MTIABit:
33       return "MTIA";
34     case BackendComponent::PrivateUse1Bit:
35       return "PrivateUse1Bit";
36     case BackendComponent::PrivateUse2Bit:
37       return "PrivateUse2Bit";
38     case BackendComponent::PrivateUse3Bit:
39       return "PrivateUse3Bit";
40     case BackendComponent::InvalidBit:
41       return "InvalidBit";
42     default:
43       return "UNKNOWN_BACKEND_BIT";
44   }
45 }
46 
toBackendComponent(DeviceType device_type)47 BackendComponent toBackendComponent(DeviceType device_type) {
48   switch (device_type) {
49 #define DO_CASE(device, _)                          \
50   case DeviceType::device: {                        \
51     return toBackendComponent(DispatchKey::device); \
52   }
53     C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
54 #undef DO_CASE
55     default:
56       return BackendComponent::InvalidBit;
57   }
58 }
59 
toString(DispatchKey t)60 const char* toString(DispatchKey t) {
61   switch (t) {
62     case DispatchKey::Undefined:
63       return "Undefined";
64 
65     case DispatchKey::Dense:
66       return "Dense";
67     case DispatchKey::FPGA:
68       return "FPGA";
69     case DispatchKey::MAIA:
70       return "MAIA";
71     case DispatchKey::Vulkan:
72       return "Vulkan";
73     case DispatchKey::Metal:
74       return "Metal";
75 
76     case DispatchKey::Lazy:
77       return "Lazy";
78     case DispatchKey::MPS:
79       return "MPS";
80     case DispatchKey::HPU:
81       return "HPU";
82     case DispatchKey::MTIA:
83       return "MTIA";
84 
85     case DispatchKey::Quantized:
86       return "Quantized";
87     case DispatchKey::CustomRNGKeyId:
88       return "CustomRNGKeyId";
89     case DispatchKey::MkldnnCPU:
90       return "MkldnnCPU";
91 
92     case DispatchKey::Sparse:
93       return "Sparse";
94 
95     case DispatchKey::SparseCsr:
96       return "SparseCsr";
97 
98     case DispatchKey::NestedTensor:
99       return "NestedTensor";
100 
101     case DispatchKey::BackendSelect:
102       return "BackendSelect";
103 
104     case DispatchKey::Python:
105       return "Python";
106 
107     case DispatchKey::Fake:
108       return "Fake";
109     case DispatchKey::FuncTorchDynamicLayerBackMode:
110       return "FuncTorchDynamicLayerBackMode";
111 
112     case DispatchKey::Functionalize:
113       return "Functionalize";
114 
115     case DispatchKey::Named:
116       return "Named";
117 
118     case DispatchKey::Conjugate:
119       return "Conjugate";
120     case DispatchKey::Negative:
121       return "Negative";
122     case DispatchKey::ZeroTensor:
123       return "ZeroTensor";
124 
125     case DispatchKey::ADInplaceOrView:
126       return "ADInplaceOrView";
127 
128     case DispatchKey::AutogradOther:
129       return "AutogradOther";
130     case DispatchKey::AutogradFunctionality:
131       return "AutogradFunctionality";
132     case DispatchKey::AutogradNestedTensor:
133       return "AutogradNestedTensor";
134 
135     case DispatchKey::Tracer:
136       return "Tracer";
137 
138     case DispatchKey::AutocastCPU:
139       return "AutocastCPU";
140     case DispatchKey::AutocastXPU:
141       return "AutocastXPU";
142     case DispatchKey::AutocastIPU:
143       return "AutocastIPU";
144     case DispatchKey::AutocastHPU:
145       return "AutocastHPU";
146     case DispatchKey::AutocastCUDA:
147       return "AutocastCUDA";
148     case DispatchKey::AutocastXLA:
149       return "AutocastXLA";
150     case DispatchKey::AutocastPrivateUse1:
151       return "AutocastPrivateUse1";
152     case DispatchKey::AutocastMPS:
153       return "AutocastMPS";
154 
155     case DispatchKey::FuncTorchBatched:
156       return "FuncTorchBatched";
157     case DispatchKey::BatchedNestedTensor:
158       return "BatchedNestedTensor";
159     case DispatchKey::FuncTorchVmapMode:
160       return "FuncTorchVmapMode";
161 
162     case DispatchKey::Batched:
163       return "Batched";
164     case DispatchKey::VmapMode:
165       return "VmapMode";
166 
167     case DispatchKey::FuncTorchGradWrapper:
168       return "FuncTorchGradWrapper";
169 
170     case DispatchKey::DeferredInit:
171       return "DeferredInit";
172     case DispatchKey::PythonTLSSnapshot:
173       return "PythonTLSSnapshot";
174 
175     // Note [Out-of-tree vmap+grad prototype]
176     // The following keys are used in the implementation of the out-of-tree
177     // composable functions transforms (vmap+grad) prototype that lives at
178     // https://github.com/zou3519/functorch
179     // We plan on eventually upstreaming the prototype into core, at which
180     // point it will have a different design that should use fewer keys.
181     case DispatchKey::FuncTorchDynamicLayerFrontMode:
182       return "FuncTorchDynamicLayerFrontMode";
183 
184     case DispatchKey::TESTING_ONLY_GenericWrapper:
185       return "TESTING_ONLY_GenericWrapper";
186 
187     case DispatchKey::TESTING_ONLY_GenericMode:
188       return "TESTING_ONLY_GenericMode";
189 
190     case DispatchKey::PreDispatch:
191       return "PreDispatch";
192 
193     case DispatchKey::PythonDispatcher:
194       return "PythonDispatcher";
195 
196       // Aliases
197 
198     case DispatchKey::Autograd:
199       return "Autograd";
200     case DispatchKey::CompositeImplicitAutograd:
201       return "CompositeImplicitAutograd";
202     case DispatchKey::CompositeImplicitAutogradNestedTensor:
203       return "CompositeImplicitAutogradNestedTensor";
204     case DispatchKey::CompositeExplicitAutograd:
205       return "CompositeExplicitAutograd";
206     case DispatchKey::CompositeExplicitAutogradNonFunctional:
207       return "CompositeExplicitAutogradNonFunctional";
208     case DispatchKey::FuncTorchBatchedDecomposition:
209       return "FuncTorchBatchedDecomposition";
210 
211       // Per-backend dispatch keys
212 
213     default:
214       auto bc = toBackendComponent(t);
215       auto fk = toFunctionalityKey(t);
216 
217       switch (fk) {
218 #define ENTRY(backend, functionality)  \
219   case BackendComponent::backend##Bit: \
220     return #functionality #backend;
221 
222 #define FORALL_BC(dkname, prefix)                  \
223   case DispatchKey::dkname:                        \
224     switch (bc) {                                  \
225       C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
226       default:                                     \
227         return #prefix "Undefined";                \
228     }
229 
230         C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)
231 
232         default:
233           switch (bc) {
234             C10_FORALL_BACKEND_COMPONENTS(ENTRY, Unknown)
235             default:
236               return "UnknownUnknown";
237           }
238 
239 #undef FORALL_BC
240 #undef ENTRY
241       }
242   }
243 }
244 
operator <<(std::ostream & str,DispatchKey rhs)245 std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
246   return str << toString(rhs);
247 }
operator <<(std::ostream & str,BackendComponent rhs)248 std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
249   return str << toString(rhs);
250 }
251 
getAutogradKeyFromBackend(BackendComponent k)252 DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
253   // We want this to return an autograd key. We're relying on the fact that
254   // getAutogradRelatedKeySetFromBackend returns an autograd key +
255   // ADInplaceOrView, and autograd has higher precedence. The core mapping from
256   // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
257   // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
258   // hotpath function, and we want to make sure that it doesn't have to
259   // construct any DispatchKeySets at runtime.
260   return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
261 }
262 
parseDispatchKey(const std::string & k)263 c10::DispatchKey parseDispatchKey(const std::string& k) {
264   static std::unordered_map<std::string, c10::DispatchKey> key_map = {
265       {"Undefined", c10::DispatchKey::Undefined},
266       {"Dense", c10::DispatchKey::Dense},
267       {"FPGA", c10::DispatchKey::FPGA},
268       {"MAIA", c10::DispatchKey::MAIA},
269       {"MPS", c10::DispatchKey::MPS},
270       {"Vulkan", c10::DispatchKey::Vulkan},
271       {"Metal", c10::DispatchKey::Metal},
272       {"VE", c10::DispatchKey::VE},
273       {"Meta", c10::DispatchKey::Meta},
274       {"Quantized", c10::DispatchKey::Quantized},
275       {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
276       {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
277       {"Sparse", c10::DispatchKey::Sparse},
278       {"SparseCsr", c10::DispatchKey::SparseCsr},
279       {"BackendSelect", c10::DispatchKey::BackendSelect},
280       {"Python", c10::DispatchKey::Python},
281       {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
282       {"Fake", c10::DispatchKey::Fake},
283       {"Named", c10::DispatchKey::Named},
284       {"Conjugate", c10::DispatchKey::Conjugate},
285       {"Negative", c10::DispatchKey::Negative},
286       {"ZeroTensor", c10::DispatchKey::ZeroTensor},
287       {"FuncTorchDynamicLayerBackMode",
288        c10::DispatchKey::FuncTorchDynamicLayerBackMode},
289       {"Functionalize", c10::DispatchKey::Functionalize},
290       {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
291       {"AutogradOther", c10::DispatchKey::AutogradOther},
292       {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
293       {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
294       {"Tracer", c10::DispatchKey::Tracer},
295       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
296       {"AutocastXPU", c10::DispatchKey::AutocastXPU},
297       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
298       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
299       {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
300       {"AutocastXLA", c10::DispatchKey::AutocastXLA},
301       {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
302       {"AutocastMPS", c10::DispatchKey::AutocastMPS},
303       {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
304       {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor},
305       {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
306       {"Batched", c10::DispatchKey::Batched},
307       {"VmapMode", c10::DispatchKey::VmapMode},
308       {"DeferredInit", c10::DispatchKey::DeferredInit},
309       {"FuncTorchGradWrapper", c10::DispatchKey::FuncTorchGradWrapper},
310       {"FuncTorchDynamicLayerFrontMode",
311        c10::DispatchKey::FuncTorchDynamicLayerFrontMode},
312       {"TESTING_ONLY_GenericWrapper",
313        c10::DispatchKey::TESTING_ONLY_GenericWrapper},
314       {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
315       {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
316       {"PreDispatch", c10::DispatchKey::PreDispatch},
317 
318       {"CPU", c10::DispatchKey::CPU},
319       {"CUDA", c10::DispatchKey::CUDA},
320       {"HIP", c10::DispatchKey::HIP},
321       {"XLA", c10::DispatchKey::XLA},
322       {"MPS", c10::DispatchKey::MPS},
323       {"XPU", c10::DispatchKey::XPU},
324       {"IPU", c10::DispatchKey::IPU},
325       {"HPU", c10::DispatchKey::HPU},
326       {"Lazy", c10::DispatchKey::Lazy},
327       {"MTIA", c10::DispatchKey::MTIA},
328       {"NestedTensor", c10::DispatchKey::NestedTensor},
329       {"NestedTensorCPU", c10::DispatchKey::NestedTensorCPU},
330       {"NestedTensorCUDA", c10::DispatchKey::NestedTensorCUDA},
331       {"NestedTensorMeta", c10::DispatchKey::NestedTensorMeta},
332       {"NestedTensorPrivateUse1", c10::DispatchKey::NestedTensorPrivateUse1},
333       {"PrivateUse1", c10::DispatchKey::PrivateUse1},
334       {"PrivateUse2", c10::DispatchKey::PrivateUse2},
335       {"PrivateUse3", c10::DispatchKey::PrivateUse3},
336 
337       {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
338       {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
339       {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
340       {"QuantizedPrivateUse1", c10::DispatchKey::QuantizedPrivateUse1},
341 
342       {"SparseCPU", c10::DispatchKey::SparseCPU},
343       {"SparseCUDA", c10::DispatchKey::SparseCUDA},
344       {"SparseHIP", c10::DispatchKey::SparseHIP},
345       {"SparseXPU", c10::DispatchKey::SparseXPU},
346       {"SparseVE", c10::DispatchKey::SparseVE},
347       {"SparseMeta", c10::DispatchKey::SparseMeta},
348       {"SparsePrivateUse1", c10::DispatchKey::SparsePrivateUse1},
349 
350       {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
351       {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
352       {"SparseCsrHIP", c10::DispatchKey::SparseCsrHIP},
353       {"SparseCsrXPU", c10::DispatchKey::SparseCsrXPU},
354       {"SparseCsrVE", c10::DispatchKey::SparseCsrVE},
355       {"SparseCsrMeta", c10::DispatchKey::SparseCsrMeta},
356       {"SparseCsrPrivateUse1", c10::DispatchKey::SparseCsrPrivateUse1},
357 
358       {"AutogradCPU", c10::DispatchKey::AutogradCPU},
359       {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
360       {"AutogradXLA", c10::DispatchKey::AutogradXLA},
361       {"AutogradLazy", c10::DispatchKey::AutogradLazy},
362       {"AutogradMeta", c10::DispatchKey::AutogradMeta},
363       {"AutogradIPU", c10::DispatchKey::AutogradIPU},
364       {"AutogradXPU", c10::DispatchKey::AutogradXPU},
365       {"AutogradMPS", c10::DispatchKey::AutogradMPS},
366       {"AutogradHPU", c10::DispatchKey::AutogradHPU},
367       {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
368       {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
369       {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
370 
371       {"Autograd", c10::DispatchKey::Autograd},
372       {"CompositeImplicitAutograd",
373        c10::DispatchKey::CompositeImplicitAutograd},
374       {"CompositeImplicitAutogradNestedTensor",
375        c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
376       {"CompositeExplicitAutograd",
377        c10::DispatchKey::CompositeExplicitAutograd},
378       {"CompositeExplicitAutogradNonFunctional",
379        c10::DispatchKey::CompositeExplicitAutogradNonFunctional},
380       {"FuncTorchBatchedDecomposition",
381        c10::DispatchKey::FuncTorchBatchedDecomposition},
382   };
383   auto it = key_map.find(k);
384   TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k);
385   return it->second;
386 }
387 
388 } // namespace c10
389