1 #pragma once
2
3 #include <c10/core/DeviceType.h>
4 #include <c10/core/DispatchKey.h>
5 #include <c10/core/DispatchKeySet.h>
6 #include <c10/util/Exception.h>
7
8 #include <stdexcept>
9
10 namespace c10 {
11
12 /**
13 * This legacy enum class defines the set of backends supported by old school,
14 * code generated Type-based ATen. A "backend" in this sense roughly
15 * corresponds to the cartesian product of (device type, layout), but restricted
16 * only to combinations which we actually have kernels for. Backend does NOT
17 * include dtype.
18 *
19 * The reason we are sunsetting this enum class is because it doesn't allow for
20 * open registration; e.g., if you want to add SparseXLA, you'd have to
21 * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is
22 * the replacement for Backend which supports open registration.
23 *
24 * NB: The concept of 'Backend' here disagrees with the notion of backend
25 * exposed to users in torch.backends. Backend here is something like "CPU"
26 * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
27 * "CUDNN".
28 */
enum class Backend {
  // Dense backends, one per device type.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse (COO) backends.
  SparseCPU,
  SparseCUDA,
  // Sparse compressed (CSR/CSC/BSR/BSC) backends.
  SparseCsrCPU,
  SparseCsrCUDA,
  SparseHIP,
  SparseVE,
  SparseXPU,
  SparsePrivateUse1,
  SparseCsrHIP,
  SparseCsrVE,
  SparseCsrXPU,
  SparseCsrPrivateUse1,
  MAIA,
  XLA,
  Vulkan,
  Metal,
  Meta,
  // Quantized backends.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  QuantizedPrivateUse1,
  // Sentinel for "no backend"; also used for uninitialized tensors.
  Undefined,
  MkldnnCPU,
  MPS,
  HPU,
  Lazy,
  MTIA,
  // Out-of-tree backend slot (see DispatchKey::PrivateUse1).
  PrivateUse1,
  // NB: not a real backend; keep last so it counts the entries above.
  NumOptions
};
67
dispatchKeyToBackend(DispatchKey t)68 inline Backend dispatchKeyToBackend(DispatchKey t) {
69 if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) {
70 return Backend::CPU;
71 } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) {
72 return Backend::CUDA;
73 } else if (t == DispatchKey::HIP) {
74 return Backend::HIP;
75 } else if (t == DispatchKey::VE) {
76 return Backend::VE;
77 } else if (t == DispatchKey::FPGA) {
78 return Backend::FPGA;
79 } else if (t == DispatchKey::MAIA) {
80 return Backend::MAIA;
81 } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
82 return Backend::XLA;
83 } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
84 return Backend::Lazy;
85 } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) {
86 return Backend::MPS;
87 } else if (t == DispatchKey::Vulkan) {
88 return Backend::Vulkan;
89 } else if (t == DispatchKey::Metal) {
90 return Backend::Metal;
91 } else if (t == DispatchKey::Meta) {
92 return Backend::Meta;
93 } else if (t == DispatchKey::SparseCPU) {
94 return Backend::SparseCPU;
95 } else if (t == DispatchKey::SparseCUDA) {
96 return Backend::SparseCUDA;
97 } else if (t == DispatchKey::SparseHIP) {
98 return Backend::SparseHIP;
99 } else if (t == DispatchKey::SparseVE) {
100 return Backend::SparseVE;
101 } else if (t == DispatchKey::SparsePrivateUse1) {
102 return Backend::SparsePrivateUse1;
103 } else if (t == DispatchKey::SparseCsrCPU) {
104 return Backend::SparseCsrCPU;
105 } else if (t == DispatchKey::SparseCsrCUDA) {
106 return Backend::SparseCsrCUDA;
107 } else if (t == DispatchKey::SparseCsrHIP) {
108 return Backend::SparseCsrHIP;
109 } else if (t == DispatchKey::SparseCsrVE) {
110 return Backend::SparseCsrVE;
111 } else if (t == DispatchKey::SparseCsrPrivateUse1) {
112 return Backend::SparseCsrPrivateUse1;
113 } else if (t == DispatchKey::MkldnnCPU) {
114 return Backend::MkldnnCPU;
115 } else if (t == DispatchKey::QuantizedCPU) {
116 return Backend::QuantizedCPU;
117 } else if (t == DispatchKey::QuantizedCUDA) {
118 return Backend::QuantizedCUDA;
119 } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
120 return Backend::IPU;
121 } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
122 return Backend::XPU;
123 } else if (t == DispatchKey::SparseXPU) {
124 return Backend::SparseXPU;
125 } else if (t == DispatchKey::SparseCsrXPU) {
126 return Backend::SparseCsrXPU;
127 } else if (t == DispatchKey::QuantizedXPU) {
128 return Backend::QuantizedXPU;
129 } else if (t == DispatchKey::QuantizedPrivateUse1) {
130 return Backend::QuantizedPrivateUse1;
131 } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
132 return Backend::HPU;
133 } else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) {
134 return Backend::MTIA;
135 } else if (
136 t == DispatchKey::PrivateUse1 || t == DispatchKey::AutogradPrivateUse1) {
137 return Backend::PrivateUse1;
138 } else if (t == DispatchKey::Undefined) {
139 return Backend::Undefined;
140 } else {
141 TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
142 }
143 }
144
backendToDispatchKey(Backend b)145 inline DispatchKey backendToDispatchKey(Backend b) {
146 switch (b) {
147 case Backend::CPU:
148 return DispatchKey::CPU;
149 case Backend::CUDA:
150 return DispatchKey::CUDA;
151 case Backend::HIP:
152 return DispatchKey::HIP;
153 case Backend::VE:
154 return DispatchKey::VE;
155 case Backend::FPGA:
156 return DispatchKey::FPGA;
157 case Backend::MAIA:
158 return DispatchKey::MAIA;
159 case Backend::XLA:
160 return DispatchKey::XLA;
161 case Backend::Lazy:
162 return DispatchKey::Lazy;
163 case Backend::IPU:
164 return DispatchKey::IPU;
165 case Backend::XPU:
166 return DispatchKey::XPU;
167 case Backend::SparseXPU:
168 return DispatchKey::SparseXPU;
169 case Backend::SparseCsrXPU:
170 return DispatchKey::SparseCsrXPU;
171 case Backend::SparseCPU:
172 return DispatchKey::SparseCPU;
173 case Backend::SparseCUDA:
174 return DispatchKey::SparseCUDA;
175 case Backend::SparseHIP:
176 return DispatchKey::SparseHIP;
177 case Backend::SparseVE:
178 return DispatchKey::SparseVE;
179 case Backend::SparsePrivateUse1:
180 return DispatchKey::SparsePrivateUse1;
181 case Backend::SparseCsrCPU:
182 return DispatchKey::SparseCsrCPU;
183 case Backend::SparseCsrCUDA:
184 return DispatchKey::SparseCsrCUDA;
185 case Backend::SparseCsrHIP:
186 return DispatchKey::SparseCsrHIP;
187 case Backend::SparseCsrVE:
188 return DispatchKey::SparseCsrVE;
189 case Backend::SparseCsrPrivateUse1:
190 return DispatchKey::SparseCsrPrivateUse1;
191 case Backend::MkldnnCPU:
192 return DispatchKey::MkldnnCPU;
193 case Backend::Vulkan:
194 return DispatchKey::Vulkan;
195 case Backend::Metal:
196 return DispatchKey::Metal;
197 case Backend::Meta:
198 return DispatchKey::Meta;
199 case Backend::QuantizedCPU:
200 return DispatchKey::QuantizedCPU;
201 case Backend::QuantizedCUDA:
202 return DispatchKey::QuantizedCUDA;
203 case Backend::QuantizedPrivateUse1:
204 return DispatchKey::QuantizedPrivateUse1;
205 case Backend::Undefined:
206 return DispatchKey::Undefined;
207 case Backend::MPS:
208 return DispatchKey::MPS;
209 case Backend::HPU:
210 return DispatchKey::HPU;
211 case Backend::MTIA:
212 return DispatchKey::MTIA;
213 case Backend::PrivateUse1:
214 return DispatchKey::PrivateUse1;
215 default:
216 throw std::runtime_error("Unknown backend");
217 }
218 }
219
backendToDeviceType(Backend b)220 inline DeviceType backendToDeviceType(Backend b) {
221 switch (b) {
222 case Backend::CPU:
223 case Backend::MkldnnCPU:
224 case Backend::SparseCPU:
225 case Backend::SparseCsrCPU:
226 case Backend::QuantizedCPU:
227 return DeviceType::CPU;
228 case Backend::CUDA:
229 case Backend::SparseCUDA:
230 case Backend::QuantizedCUDA:
231 case Backend::SparseCsrCUDA:
232 return DeviceType::CUDA;
233 case Backend::HIP:
234 return DeviceType::HIP;
235 case Backend::VE:
236 return DeviceType::VE;
237 case Backend::FPGA:
238 return DeviceType::FPGA;
239 case Backend::MAIA:
240 return DeviceType::MAIA;
241 case Backend::XLA:
242 return DeviceType::XLA;
243 case Backend::Lazy:
244 return DeviceType::Lazy;
245 case Backend::SparseHIP:
246 return DeviceType::HIP;
247 case Backend::SparseVE:
248 return DeviceType::VE;
249 case Backend::SparseCsrHIP:
250 return DeviceType::HIP;
251 case Backend::SparseCsrVE:
252 return DeviceType::VE;
253 case Backend::IPU:
254 return DeviceType::IPU;
255 case Backend::XPU:
256 case Backend::SparseXPU:
257 case Backend::SparseCsrXPU:
258 case Backend::QuantizedXPU:
259 return DeviceType::XPU;
260 case Backend::Vulkan:
261 return DeviceType::Vulkan;
262 case Backend::Metal:
263 return DeviceType::Metal;
264 case Backend::Meta:
265 return DeviceType::Meta;
266 case Backend::MPS:
267 return DeviceType::MPS;
268 case Backend::HPU:
269 return DeviceType::HPU;
270 case Backend::MTIA:
271 return DeviceType::MTIA;
272 case Backend::PrivateUse1:
273 case Backend::SparsePrivateUse1:
274 case Backend::SparseCsrPrivateUse1:
275 case Backend::QuantizedPrivateUse1:
276 return DeviceType::PrivateUse1;
277 case Backend::Undefined:
278 TORCH_CHECK(false, "Undefined backend is not a valid device type");
279 default:
280 TORCH_CHECK(false, "Unknown backend");
281 }
282 }
283
toString(Backend b)284 inline const char* toString(Backend b) {
285 switch (b) {
286 case Backend::CPU:
287 return "CPU";
288 case Backend::CUDA:
289 return "CUDA";
290 case Backend::HIP:
291 return "HIP";
292 case Backend::VE:
293 return "VE";
294 case Backend::FPGA:
295 return "FPGA";
296 case Backend::XPU:
297 return "XPU";
298 case Backend::IPU:
299 return "IPU";
300 case Backend::MAIA:
301 return "MAIA";
302 case Backend::XLA:
303 return "XLA";
304 case Backend::Lazy:
305 return "Lazy";
306 case Backend::MPS:
307 return "MPS";
308 case Backend::SparseCPU:
309 return "SparseCPU";
310 case Backend::SparseCUDA:
311 return "SparseCUDA";
312 case Backend::SparseHIP:
313 return "SparseHIP";
314 case Backend::SparseVE:
315 return "SparseVE";
316 case Backend::SparseXPU:
317 return "SparseXPU";
318 case Backend::SparsePrivateUse1:
319 return "SparsePrivateUse1";
320 case Backend::SparseCsrCPU:
321 return "SparseCsrCPU";
322 case Backend::SparseCsrCUDA:
323 return "SparseCsrCUDA";
324 case Backend::SparseCsrHIP:
325 return "SparseCsrHIP";
326 case Backend::SparseCsrVE:
327 return "SparseCsrVE";
328 case Backend::SparseCsrXPU:
329 return "SparseCsrXPU";
330 case Backend::SparseCsrPrivateUse1:
331 return "SparseCsrPrivateUse1";
332 case Backend::MkldnnCPU:
333 return "MkldnnCPU";
334 case Backend::Vulkan:
335 return "Vulkan";
336 case Backend::Metal:
337 return "Metal";
338 case Backend::Meta:
339 return "Meta";
340 case Backend::QuantizedCPU:
341 return "QuantizedCPU";
342 case Backend::QuantizedCUDA:
343 return "QuantizedCUDA";
344 case Backend::QuantizedXPU:
345 return "QuantizedXPU";
346 case Backend::QuantizedPrivateUse1:
347 return "QuantizedPrivateUse1";
348 case Backend::HPU:
349 return "HPU";
350 case Backend::MTIA:
351 return "MTIA";
352 case Backend::PrivateUse1:
353 return "PrivateUseOne";
354 default:
355 return "UNKNOWN_BACKEND";
356 }
357 }
358
isSparse(Backend b)359 inline bool isSparse(Backend b) {
360 switch (b) {
361 case Backend::SparseXPU:
362 case Backend::SparseCPU:
363 case Backend::SparseCUDA:
364 case Backend::SparseHIP:
365 case Backend::SparseVE:
366 case Backend::SparsePrivateUse1:
367 return true;
368 default:
369 return false;
370 }
371 }
372
isSparseCsr(Backend b)373 inline bool isSparseCsr(Backend b) {
374 switch (b) {
375 case Backend::SparseCsrXPU:
376 case Backend::SparseCsrCPU:
377 case Backend::SparseCsrCUDA:
378 case Backend::SparseCsrHIP:
379 case Backend::SparseCsrVE:
380 case Backend::SparseCsrPrivateUse1:
381 return true;
382 default:
383 return false;
384 }
385 }
386
387 } // namespace c10
388