xref: /aosp_15_r20/external/pytorch/c10/core/Backend.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DeviceType.h>
4 #include <c10/core/DispatchKey.h>
5 #include <c10/core/DispatchKeySet.h>
6 #include <c10/util/Exception.h>
7 
8 #include <stdexcept>
9 
10 namespace c10 {
11 
/**
 * Legacy enumeration of the backends supported by the old-school,
 * code-generated, Type-based ATen. A "backend" in this sense roughly
 * corresponds to the Cartesian product of (device type, layout), restricted
 * to the combinations we actually have kernels for. Backend does NOT include
 * dtype.
 *
 * This enum is being sunset because it does not allow open registration:
 * adding, e.g., SparseXLA means editing this enum, which cannot be done out
 * of tree. DispatchKey is the replacement for Backend and does support open
 * registration.
 *
 * NB: 'Backend' here disagrees with the notion of backend exposed to users in
 * torch.backends. Backend here is something like "CPU" or "SparseCUDA";
 * backend in torch.backends is something like "MKL" or "CUDNN".
 *
 * Enumerators take implicit, declaration-order values (shown in the trailing
 * comments), and NumOptions is declared last, so its value equals the number
 * of real backends. The order is historical rather than grouped by family.
 * NOTE(review): whether any code persists or indexes by these numeric values
 * is not visible here; to be safe, append new backends immediately before
 * NumOptions instead of inserting them in the middle.
 */
enum class Backend {
  CPU, // 0
  CUDA, // 1
  HIP, // 2
  VE, // 3
  FPGA, // 4
  IPU, // 5
  XPU, // 6
  SparseCPU, // 7
  SparseCUDA, // 8
  SparseCsrCPU, // 9
  SparseCsrCUDA, // 10
  SparseHIP, // 11
  SparseVE, // 12
  SparseXPU, // 13
  SparsePrivateUse1, // 14
  SparseCsrHIP, // 15
  SparseCsrVE, // 16
  SparseCsrXPU, // 17
  SparseCsrPrivateUse1, // 18
  MAIA, // 19
  XLA, // 20
  Vulkan, // 21
  Metal, // 22
  Meta, // 23
  QuantizedCPU, // 24
  QuantizedCUDA, // 25
  QuantizedXPU, // 26
  QuantizedPrivateUse1, // 27
  Undefined, // 28
  MkldnnCPU, // 29
  MPS, // 30
  HPU, // 31
  Lazy, // 32
  MTIA, // 33
  PrivateUse1, // 34
  NumOptions // 35: count of the backends above; keep last.
};
67 
dispatchKeyToBackend(DispatchKey t)68 inline Backend dispatchKeyToBackend(DispatchKey t) {
69   if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) {
70     return Backend::CPU;
71   } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) {
72     return Backend::CUDA;
73   } else if (t == DispatchKey::HIP) {
74     return Backend::HIP;
75   } else if (t == DispatchKey::VE) {
76     return Backend::VE;
77   } else if (t == DispatchKey::FPGA) {
78     return Backend::FPGA;
79   } else if (t == DispatchKey::MAIA) {
80     return Backend::MAIA;
81   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
82     return Backend::XLA;
83   } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
84     return Backend::Lazy;
85   } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) {
86     return Backend::MPS;
87   } else if (t == DispatchKey::Vulkan) {
88     return Backend::Vulkan;
89   } else if (t == DispatchKey::Metal) {
90     return Backend::Metal;
91   } else if (t == DispatchKey::Meta) {
92     return Backend::Meta;
93   } else if (t == DispatchKey::SparseCPU) {
94     return Backend::SparseCPU;
95   } else if (t == DispatchKey::SparseCUDA) {
96     return Backend::SparseCUDA;
97   } else if (t == DispatchKey::SparseHIP) {
98     return Backend::SparseHIP;
99   } else if (t == DispatchKey::SparseVE) {
100     return Backend::SparseVE;
101   } else if (t == DispatchKey::SparsePrivateUse1) {
102     return Backend::SparsePrivateUse1;
103   } else if (t == DispatchKey::SparseCsrCPU) {
104     return Backend::SparseCsrCPU;
105   } else if (t == DispatchKey::SparseCsrCUDA) {
106     return Backend::SparseCsrCUDA;
107   } else if (t == DispatchKey::SparseCsrHIP) {
108     return Backend::SparseCsrHIP;
109   } else if (t == DispatchKey::SparseCsrVE) {
110     return Backend::SparseCsrVE;
111   } else if (t == DispatchKey::SparseCsrPrivateUse1) {
112     return Backend::SparseCsrPrivateUse1;
113   } else if (t == DispatchKey::MkldnnCPU) {
114     return Backend::MkldnnCPU;
115   } else if (t == DispatchKey::QuantizedCPU) {
116     return Backend::QuantizedCPU;
117   } else if (t == DispatchKey::QuantizedCUDA) {
118     return Backend::QuantizedCUDA;
119   } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
120     return Backend::IPU;
121   } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
122     return Backend::XPU;
123   } else if (t == DispatchKey::SparseXPU) {
124     return Backend::SparseXPU;
125   } else if (t == DispatchKey::SparseCsrXPU) {
126     return Backend::SparseCsrXPU;
127   } else if (t == DispatchKey::QuantizedXPU) {
128     return Backend::QuantizedXPU;
129   } else if (t == DispatchKey::QuantizedPrivateUse1) {
130     return Backend::QuantizedPrivateUse1;
131   } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
132     return Backend::HPU;
133   } else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) {
134     return Backend::MTIA;
135   } else if (
136       t == DispatchKey::PrivateUse1 || t == DispatchKey::AutogradPrivateUse1) {
137     return Backend::PrivateUse1;
138   } else if (t == DispatchKey::Undefined) {
139     return Backend::Undefined;
140   } else {
141     TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
142   }
143 }
144 
backendToDispatchKey(Backend b)145 inline DispatchKey backendToDispatchKey(Backend b) {
146   switch (b) {
147     case Backend::CPU:
148       return DispatchKey::CPU;
149     case Backend::CUDA:
150       return DispatchKey::CUDA;
151     case Backend::HIP:
152       return DispatchKey::HIP;
153     case Backend::VE:
154       return DispatchKey::VE;
155     case Backend::FPGA:
156       return DispatchKey::FPGA;
157     case Backend::MAIA:
158       return DispatchKey::MAIA;
159     case Backend::XLA:
160       return DispatchKey::XLA;
161     case Backend::Lazy:
162       return DispatchKey::Lazy;
163     case Backend::IPU:
164       return DispatchKey::IPU;
165     case Backend::XPU:
166       return DispatchKey::XPU;
167     case Backend::SparseXPU:
168       return DispatchKey::SparseXPU;
169     case Backend::SparseCsrXPU:
170       return DispatchKey::SparseCsrXPU;
171     case Backend::SparseCPU:
172       return DispatchKey::SparseCPU;
173     case Backend::SparseCUDA:
174       return DispatchKey::SparseCUDA;
175     case Backend::SparseHIP:
176       return DispatchKey::SparseHIP;
177     case Backend::SparseVE:
178       return DispatchKey::SparseVE;
179     case Backend::SparsePrivateUse1:
180       return DispatchKey::SparsePrivateUse1;
181     case Backend::SparseCsrCPU:
182       return DispatchKey::SparseCsrCPU;
183     case Backend::SparseCsrCUDA:
184       return DispatchKey::SparseCsrCUDA;
185     case Backend::SparseCsrHIP:
186       return DispatchKey::SparseCsrHIP;
187     case Backend::SparseCsrVE:
188       return DispatchKey::SparseCsrVE;
189     case Backend::SparseCsrPrivateUse1:
190       return DispatchKey::SparseCsrPrivateUse1;
191     case Backend::MkldnnCPU:
192       return DispatchKey::MkldnnCPU;
193     case Backend::Vulkan:
194       return DispatchKey::Vulkan;
195     case Backend::Metal:
196       return DispatchKey::Metal;
197     case Backend::Meta:
198       return DispatchKey::Meta;
199     case Backend::QuantizedCPU:
200       return DispatchKey::QuantizedCPU;
201     case Backend::QuantizedCUDA:
202       return DispatchKey::QuantizedCUDA;
203     case Backend::QuantizedPrivateUse1:
204       return DispatchKey::QuantizedPrivateUse1;
205     case Backend::Undefined:
206       return DispatchKey::Undefined;
207     case Backend::MPS:
208       return DispatchKey::MPS;
209     case Backend::HPU:
210       return DispatchKey::HPU;
211     case Backend::MTIA:
212       return DispatchKey::MTIA;
213     case Backend::PrivateUse1:
214       return DispatchKey::PrivateUse1;
215     default:
216       throw std::runtime_error("Unknown backend");
217   }
218 }
219 
backendToDeviceType(Backend b)220 inline DeviceType backendToDeviceType(Backend b) {
221   switch (b) {
222     case Backend::CPU:
223     case Backend::MkldnnCPU:
224     case Backend::SparseCPU:
225     case Backend::SparseCsrCPU:
226     case Backend::QuantizedCPU:
227       return DeviceType::CPU;
228     case Backend::CUDA:
229     case Backend::SparseCUDA:
230     case Backend::QuantizedCUDA:
231     case Backend::SparseCsrCUDA:
232       return DeviceType::CUDA;
233     case Backend::HIP:
234       return DeviceType::HIP;
235     case Backend::VE:
236       return DeviceType::VE;
237     case Backend::FPGA:
238       return DeviceType::FPGA;
239     case Backend::MAIA:
240       return DeviceType::MAIA;
241     case Backend::XLA:
242       return DeviceType::XLA;
243     case Backend::Lazy:
244       return DeviceType::Lazy;
245     case Backend::SparseHIP:
246       return DeviceType::HIP;
247     case Backend::SparseVE:
248       return DeviceType::VE;
249     case Backend::SparseCsrHIP:
250       return DeviceType::HIP;
251     case Backend::SparseCsrVE:
252       return DeviceType::VE;
253     case Backend::IPU:
254       return DeviceType::IPU;
255     case Backend::XPU:
256     case Backend::SparseXPU:
257     case Backend::SparseCsrXPU:
258     case Backend::QuantizedXPU:
259       return DeviceType::XPU;
260     case Backend::Vulkan:
261       return DeviceType::Vulkan;
262     case Backend::Metal:
263       return DeviceType::Metal;
264     case Backend::Meta:
265       return DeviceType::Meta;
266     case Backend::MPS:
267       return DeviceType::MPS;
268     case Backend::HPU:
269       return DeviceType::HPU;
270     case Backend::MTIA:
271       return DeviceType::MTIA;
272     case Backend::PrivateUse1:
273     case Backend::SparsePrivateUse1:
274     case Backend::SparseCsrPrivateUse1:
275     case Backend::QuantizedPrivateUse1:
276       return DeviceType::PrivateUse1;
277     case Backend::Undefined:
278       TORCH_CHECK(false, "Undefined backend is not a valid device type");
279     default:
280       TORCH_CHECK(false, "Unknown backend");
281   }
282 }
283 
toString(Backend b)284 inline const char* toString(Backend b) {
285   switch (b) {
286     case Backend::CPU:
287       return "CPU";
288     case Backend::CUDA:
289       return "CUDA";
290     case Backend::HIP:
291       return "HIP";
292     case Backend::VE:
293       return "VE";
294     case Backend::FPGA:
295       return "FPGA";
296     case Backend::XPU:
297       return "XPU";
298     case Backend::IPU:
299       return "IPU";
300     case Backend::MAIA:
301       return "MAIA";
302     case Backend::XLA:
303       return "XLA";
304     case Backend::Lazy:
305       return "Lazy";
306     case Backend::MPS:
307       return "MPS";
308     case Backend::SparseCPU:
309       return "SparseCPU";
310     case Backend::SparseCUDA:
311       return "SparseCUDA";
312     case Backend::SparseHIP:
313       return "SparseHIP";
314     case Backend::SparseVE:
315       return "SparseVE";
316     case Backend::SparseXPU:
317       return "SparseXPU";
318     case Backend::SparsePrivateUse1:
319       return "SparsePrivateUse1";
320     case Backend::SparseCsrCPU:
321       return "SparseCsrCPU";
322     case Backend::SparseCsrCUDA:
323       return "SparseCsrCUDA";
324     case Backend::SparseCsrHIP:
325       return "SparseCsrHIP";
326     case Backend::SparseCsrVE:
327       return "SparseCsrVE";
328     case Backend::SparseCsrXPU:
329       return "SparseCsrXPU";
330     case Backend::SparseCsrPrivateUse1:
331       return "SparseCsrPrivateUse1";
332     case Backend::MkldnnCPU:
333       return "MkldnnCPU";
334     case Backend::Vulkan:
335       return "Vulkan";
336     case Backend::Metal:
337       return "Metal";
338     case Backend::Meta:
339       return "Meta";
340     case Backend::QuantizedCPU:
341       return "QuantizedCPU";
342     case Backend::QuantizedCUDA:
343       return "QuantizedCUDA";
344     case Backend::QuantizedXPU:
345       return "QuantizedXPU";
346     case Backend::QuantizedPrivateUse1:
347       return "QuantizedPrivateUse1";
348     case Backend::HPU:
349       return "HPU";
350     case Backend::MTIA:
351       return "MTIA";
352     case Backend::PrivateUse1:
353       return "PrivateUseOne";
354     default:
355       return "UNKNOWN_BACKEND";
356   }
357 }
358 
isSparse(Backend b)359 inline bool isSparse(Backend b) {
360   switch (b) {
361     case Backend::SparseXPU:
362     case Backend::SparseCPU:
363     case Backend::SparseCUDA:
364     case Backend::SparseHIP:
365     case Backend::SparseVE:
366     case Backend::SparsePrivateUse1:
367       return true;
368     default:
369       return false;
370   }
371 }
372 
isSparseCsr(Backend b)373 inline bool isSparseCsr(Backend b) {
374   switch (b) {
375     case Backend::SparseCsrXPU:
376     case Backend::SparseCsrCPU:
377     case Backend::SparseCsrCUDA:
378     case Backend::SparseCsrHIP:
379     case Backend::SparseCsrVE:
380     case Backend::SparseCsrPrivateUse1:
381       return true;
382     default:
383       return false;
384   }
385 }
386 
387 } // namespace c10
388