1 #include <c10/core/DispatchKey.h>
2 #include <c10/core/DispatchKeySet.h>
3
4 #include <unordered_map>
5
6 namespace c10 {
7
toString(BackendComponent t)8 const char* toString(BackendComponent t) {
9 switch (t) {
10 case BackendComponent::CPUBit:
11 return "CPUBit";
12 case BackendComponent::CUDABit:
13 return "CUDABit";
14 case BackendComponent::HIPBit:
15 return "HIPBit";
16 case BackendComponent::XLABit:
17 return "XLABit";
18 case BackendComponent::LazyBit:
19 return "LazyBit";
20 case BackendComponent::MetaBit:
21 return "MetaBit";
22 case BackendComponent::XPUBit:
23 return "XPUBit";
24 case BackendComponent::IPUBit:
25 return "IPUBit";
26 case BackendComponent::MPSBit:
27 return "MPSBit";
28 case BackendComponent::HPUBit:
29 return "HPUBit";
30 case BackendComponent::VEBit:
31 return "VEBit";
32 case BackendComponent::MTIABit:
33 return "MTIA";
34 case BackendComponent::PrivateUse1Bit:
35 return "PrivateUse1Bit";
36 case BackendComponent::PrivateUse2Bit:
37 return "PrivateUse2Bit";
38 case BackendComponent::PrivateUse3Bit:
39 return "PrivateUse3Bit";
40 case BackendComponent::InvalidBit:
41 return "InvalidBit";
42 default:
43 return "UNKNOWN_BACKEND_BIT";
44 }
45 }
46
toBackendComponent(DeviceType device_type)47 BackendComponent toBackendComponent(DeviceType device_type) {
48 switch (device_type) {
49 #define DO_CASE(device, _) \
50 case DeviceType::device: { \
51 return toBackendComponent(DispatchKey::device); \
52 }
53 C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
54 #undef DO_CASE
55 default:
56 return BackendComponent::InvalidBit;
57 }
58 }
59
// Returns the printable name of a DispatchKey.
//
// Functionality keys, alias keys and standalone keys are listed explicitly.
// Every other key reaches `default`, where it is split into its backend
// component (`bc`) and functionality key (`fk`). The name is then assembled
// from string literals as the stringized functionality prefix followed by the
// backend name.
// NOTE(review): the prefixes come from C10_FORALL_FUNCTIONALITY_KEYS in the
// header. Presumably the Dense prefix is empty, so Dense+CPU prints as "CPU";
// confirm against DispatchKey.h.
//
// Fallback names:
//   "<prefix>Undefined" - known functionality, backend bit not enumerated
//   "Unknown<Backend>"  - functionality not enumerated, known backend
//   "UnknownUnknown"    - neither is recognized
// The returned pointer always refers to a string literal.
const char* toString(DispatchKey t) {
  switch (t) {
    case DispatchKey::Undefined:
      return "Undefined";

    case DispatchKey::Dense:
      return "Dense";
    case DispatchKey::FPGA:
      return "FPGA";
    case DispatchKey::MAIA:
      return "MAIA";
    case DispatchKey::Vulkan:
      return "Vulkan";
    case DispatchKey::Metal:
      return "Metal";

    // NOTE(review): Lazy/MPS/HPU/MTIA also have BackendComponent bits. If they
    // are per-backend Dense keys, `default` below would produce the same
    // names; these explicit cases look redundant but are harmless. Confirm
    // before removing.
    case DispatchKey::Lazy:
      return "Lazy";
    case DispatchKey::MPS:
      return "MPS";
    case DispatchKey::HPU:
      return "HPU";
    case DispatchKey::MTIA:
      return "MTIA";

    case DispatchKey::Quantized:
      return "Quantized";
    case DispatchKey::CustomRNGKeyId:
      return "CustomRNGKeyId";
    case DispatchKey::MkldnnCPU:
      return "MkldnnCPU";

    case DispatchKey::Sparse:
      return "Sparse";

    case DispatchKey::SparseCsr:
      return "SparseCsr";

    case DispatchKey::NestedTensor:
      return "NestedTensor";

    case DispatchKey::BackendSelect:
      return "BackendSelect";

    case DispatchKey::Python:
      return "Python";

    case DispatchKey::Fake:
      return "Fake";
    case DispatchKey::FuncTorchDynamicLayerBackMode:
      return "FuncTorchDynamicLayerBackMode";

    case DispatchKey::Functionalize:
      return "Functionalize";

    case DispatchKey::Named:
      return "Named";

    case DispatchKey::Conjugate:
      return "Conjugate";
    case DispatchKey::Negative:
      return "Negative";
    case DispatchKey::ZeroTensor:
      return "ZeroTensor";

    case DispatchKey::ADInplaceOrView:
      return "ADInplaceOrView";

    case DispatchKey::AutogradOther:
      return "AutogradOther";
    case DispatchKey::AutogradFunctionality:
      return "AutogradFunctionality";
    case DispatchKey::AutogradNestedTensor:
      return "AutogradNestedTensor";

    case DispatchKey::Tracer:
      return "Tracer";

    case DispatchKey::AutocastCPU:
      return "AutocastCPU";
    case DispatchKey::AutocastXPU:
      return "AutocastXPU";
    case DispatchKey::AutocastIPU:
      return "AutocastIPU";
    case DispatchKey::AutocastHPU:
      return "AutocastHPU";
    case DispatchKey::AutocastCUDA:
      return "AutocastCUDA";
    case DispatchKey::AutocastXLA:
      return "AutocastXLA";
    case DispatchKey::AutocastPrivateUse1:
      return "AutocastPrivateUse1";
    case DispatchKey::AutocastMPS:
      return "AutocastMPS";

    case DispatchKey::FuncTorchBatched:
      return "FuncTorchBatched";
    case DispatchKey::BatchedNestedTensor:
      return "BatchedNestedTensor";
    case DispatchKey::FuncTorchVmapMode:
      return "FuncTorchVmapMode";

    case DispatchKey::Batched:
      return "Batched";
    case DispatchKey::VmapMode:
      return "VmapMode";

    case DispatchKey::FuncTorchGradWrapper:
      return "FuncTorchGradWrapper";

    case DispatchKey::DeferredInit:
      return "DeferredInit";
    case DispatchKey::PythonTLSSnapshot:
      return "PythonTLSSnapshot";

    // Note [Out-of-tree vmap+grad prototype]
    // The following keys are used in the implementation of the out-of-tree
    // composable function transforms (vmap+grad) prototype that lives at
    // https://github.com/zou3519/functorch
    // We plan on eventually upstreaming the prototype into core, at which
    // point it will have a different design that should use fewer keys.
    case DispatchKey::FuncTorchDynamicLayerFrontMode:
      return "FuncTorchDynamicLayerFrontMode";

    case DispatchKey::TESTING_ONLY_GenericWrapper:
      return "TESTING_ONLY_GenericWrapper";

    case DispatchKey::TESTING_ONLY_GenericMode:
      return "TESTING_ONLY_GenericMode";

    case DispatchKey::PreDispatch:
      return "PreDispatch";

    case DispatchKey::PythonDispatcher:
      return "PythonDispatcher";

    // Aliases

    case DispatchKey::Autograd:
      return "Autograd";
    case DispatchKey::CompositeImplicitAutograd:
      return "CompositeImplicitAutograd";
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return "CompositeImplicitAutogradNestedTensor";
    case DispatchKey::CompositeExplicitAutograd:
      return "CompositeExplicitAutograd";
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return "CompositeExplicitAutogradNonFunctional";
    case DispatchKey::FuncTorchBatchedDecomposition:
      return "FuncTorchBatchedDecomposition";

    // Per-backend dispatch keys

    // Any key not matched above is decomposed into (functionality, backend)
    // and its name is composed from those two parts.
    default:
      auto bc = toBackendComponent(t);
      auto fk = toFunctionalityKey(t);

      switch (fk) {
// ENTRY(backend, functionality): one case per backend bit. It returns the
// compile-time concatenation of the stringized functionality prefix and the
// backend name.
#define ENTRY(backend, functionality)  \
  case BackendComponent::backend##Bit: \
    return #functionality #backend;

// FORALL_BC(dkname, prefix): one case per functionality key. It dispatches on
// the backend bit via ENTRY and falls back to "<prefix>Undefined" when the bit
// is not enumerated.
#define FORALL_BC(dkname, prefix)                  \
  case DispatchKey::dkname:                        \
    switch (bc) {                                  \
      C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
      default:                                     \
        return #prefix "Undefined";                \
    }

        C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)

        // Functionality not enumerated: still report the backend if known.
        default:
          switch (bc) {
            C10_FORALL_BACKEND_COMPONENTS(ENTRY, Unknown)
            default:
              return "UnknownUnknown";
          }

#undef FORALL_BC
#undef ENTRY
      }
  }
}
244
operator <<(std::ostream & str,DispatchKey rhs)245 std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
246 return str << toString(rhs);
247 }
operator <<(std::ostream & str,BackendComponent rhs)248 std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
249 return str << toString(rhs);
250 }
251
getAutogradKeyFromBackend(BackendComponent k)252 DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
253 // We want this to return an autograd key. We're relying on the fact that
254 // getAutogradRelatedKeySetFromBackend returns an autograd key +
255 // ADInplaceOrView, and autograd has higher precedence. The core mapping from
256 // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
257 // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
258 // hotpath function, and we want to make sure that it doesn't have to
259 // construct any DispatchKeySets at runtime.
260 return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
261 }
262
parseDispatchKey(const std::string & k)263 c10::DispatchKey parseDispatchKey(const std::string& k) {
264 static std::unordered_map<std::string, c10::DispatchKey> key_map = {
265 {"Undefined", c10::DispatchKey::Undefined},
266 {"Dense", c10::DispatchKey::Dense},
267 {"FPGA", c10::DispatchKey::FPGA},
268 {"MAIA", c10::DispatchKey::MAIA},
269 {"MPS", c10::DispatchKey::MPS},
270 {"Vulkan", c10::DispatchKey::Vulkan},
271 {"Metal", c10::DispatchKey::Metal},
272 {"VE", c10::DispatchKey::VE},
273 {"Meta", c10::DispatchKey::Meta},
274 {"Quantized", c10::DispatchKey::Quantized},
275 {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
276 {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
277 {"Sparse", c10::DispatchKey::Sparse},
278 {"SparseCsr", c10::DispatchKey::SparseCsr},
279 {"BackendSelect", c10::DispatchKey::BackendSelect},
280 {"Python", c10::DispatchKey::Python},
281 {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
282 {"Fake", c10::DispatchKey::Fake},
283 {"Named", c10::DispatchKey::Named},
284 {"Conjugate", c10::DispatchKey::Conjugate},
285 {"Negative", c10::DispatchKey::Negative},
286 {"ZeroTensor", c10::DispatchKey::ZeroTensor},
287 {"FuncTorchDynamicLayerBackMode",
288 c10::DispatchKey::FuncTorchDynamicLayerBackMode},
289 {"Functionalize", c10::DispatchKey::Functionalize},
290 {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
291 {"AutogradOther", c10::DispatchKey::AutogradOther},
292 {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
293 {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
294 {"Tracer", c10::DispatchKey::Tracer},
295 {"AutocastCPU", c10::DispatchKey::AutocastCPU},
296 {"AutocastXPU", c10::DispatchKey::AutocastXPU},
297 {"AutocastIPU", c10::DispatchKey::AutocastIPU},
298 {"AutocastHPU", c10::DispatchKey::AutocastHPU},
299 {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
300 {"AutocastXLA", c10::DispatchKey::AutocastXLA},
301 {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
302 {"AutocastMPS", c10::DispatchKey::AutocastMPS},
303 {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
304 {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor},
305 {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
306 {"Batched", c10::DispatchKey::Batched},
307 {"VmapMode", c10::DispatchKey::VmapMode},
308 {"DeferredInit", c10::DispatchKey::DeferredInit},
309 {"FuncTorchGradWrapper", c10::DispatchKey::FuncTorchGradWrapper},
310 {"FuncTorchDynamicLayerFrontMode",
311 c10::DispatchKey::FuncTorchDynamicLayerFrontMode},
312 {"TESTING_ONLY_GenericWrapper",
313 c10::DispatchKey::TESTING_ONLY_GenericWrapper},
314 {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
315 {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
316 {"PreDispatch", c10::DispatchKey::PreDispatch},
317
318 {"CPU", c10::DispatchKey::CPU},
319 {"CUDA", c10::DispatchKey::CUDA},
320 {"HIP", c10::DispatchKey::HIP},
321 {"XLA", c10::DispatchKey::XLA},
322 {"MPS", c10::DispatchKey::MPS},
323 {"XPU", c10::DispatchKey::XPU},
324 {"IPU", c10::DispatchKey::IPU},
325 {"HPU", c10::DispatchKey::HPU},
326 {"Lazy", c10::DispatchKey::Lazy},
327 {"MTIA", c10::DispatchKey::MTIA},
328 {"NestedTensor", c10::DispatchKey::NestedTensor},
329 {"NestedTensorCPU", c10::DispatchKey::NestedTensorCPU},
330 {"NestedTensorCUDA", c10::DispatchKey::NestedTensorCUDA},
331 {"NestedTensorMeta", c10::DispatchKey::NestedTensorMeta},
332 {"NestedTensorPrivateUse1", c10::DispatchKey::NestedTensorPrivateUse1},
333 {"PrivateUse1", c10::DispatchKey::PrivateUse1},
334 {"PrivateUse2", c10::DispatchKey::PrivateUse2},
335 {"PrivateUse3", c10::DispatchKey::PrivateUse3},
336
337 {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
338 {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
339 {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
340 {"QuantizedPrivateUse1", c10::DispatchKey::QuantizedPrivateUse1},
341
342 {"SparseCPU", c10::DispatchKey::SparseCPU},
343 {"SparseCUDA", c10::DispatchKey::SparseCUDA},
344 {"SparseHIP", c10::DispatchKey::SparseHIP},
345 {"SparseXPU", c10::DispatchKey::SparseXPU},
346 {"SparseVE", c10::DispatchKey::SparseVE},
347 {"SparseMeta", c10::DispatchKey::SparseMeta},
348 {"SparsePrivateUse1", c10::DispatchKey::SparsePrivateUse1},
349
350 {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
351 {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
352 {"SparseCsrHIP", c10::DispatchKey::SparseCsrHIP},
353 {"SparseCsrXPU", c10::DispatchKey::SparseCsrXPU},
354 {"SparseCsrVE", c10::DispatchKey::SparseCsrVE},
355 {"SparseCsrMeta", c10::DispatchKey::SparseCsrMeta},
356 {"SparseCsrPrivateUse1", c10::DispatchKey::SparseCsrPrivateUse1},
357
358 {"AutogradCPU", c10::DispatchKey::AutogradCPU},
359 {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
360 {"AutogradXLA", c10::DispatchKey::AutogradXLA},
361 {"AutogradLazy", c10::DispatchKey::AutogradLazy},
362 {"AutogradMeta", c10::DispatchKey::AutogradMeta},
363 {"AutogradIPU", c10::DispatchKey::AutogradIPU},
364 {"AutogradXPU", c10::DispatchKey::AutogradXPU},
365 {"AutogradMPS", c10::DispatchKey::AutogradMPS},
366 {"AutogradHPU", c10::DispatchKey::AutogradHPU},
367 {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
368 {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
369 {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
370
371 {"Autograd", c10::DispatchKey::Autograd},
372 {"CompositeImplicitAutograd",
373 c10::DispatchKey::CompositeImplicitAutograd},
374 {"CompositeImplicitAutogradNestedTensor",
375 c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
376 {"CompositeExplicitAutograd",
377 c10::DispatchKey::CompositeExplicitAutograd},
378 {"CompositeExplicitAutogradNonFunctional",
379 c10::DispatchKey::CompositeExplicitAutogradNonFunctional},
380 {"FuncTorchBatchedDecomposition",
381 c10::DispatchKey::FuncTorchBatchedDecomposition},
382 };
383 auto it = key_map.find(k);
384 TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k);
385 return it->second;
386 }
387
388 } // namespace c10
389