#include <ATen/autocast_mode.h>

#include <mutex>
#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>

namespace at::autocast {

bool is_autocast_enabled(at::DeviceType device_type) {
  at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
  return !c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
}

void set_autocast_enabled(at::DeviceType device_type, bool enabled) {
  at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
  c10::impl::tls_set_dispatch_key_excluded(dispatch_key, !enabled);
}
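
// A minimal usage sketch (illustrative only; "enabled" simply means the device's
// autocast dispatch key is not excluded in this thread's TLS):
//   at::autocast::set_autocast_enabled(at::kCUDA, true);
//   at::autocast::is_autocast_enabled(at::kCUDA);          // -> true
//   at::autocast::set_autocast_enabled(at::kCUDA, false);  // exclude the key again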

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's
// unchanged across shallow copies.
// The value is a tuple with a weakref to the source tensor's TensorImpl as the first
// element and the casted tensor as the second element.
//
// The weakref keeps the source's TensorImpl from being deleted. We need this because we're
// using the source TensorImpl* as the key. If it were deleted, another random Tensor could
// be allocated whose TensorImpl* happened to have the same value. This TensorImpl* would
// then mistakenly hit in the cache: a rare, intermittent, unpredictable bug.
//
// I'm not using the weak_intrusive_ptr as the key because it's more difficult to compare
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;

static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
  return cached_casts;
}
std::mutex cached_casts_mutex;

// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
// any instance of autocast (which should occur at the end of each forward pass),
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
thread_local int nesting = 0;
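
// A minimal sketch of the expected call sequence (illustrative; this is roughly what
// the Python-side context manager drives through the bindings):
//   at::autocast::increment_nesting();       // enter the outer autocast region -> 1
//   at::autocast::increment_nesting();       // enter a nested autocast region  -> 2
//   at::autocast::decrement_nesting();       // leave the nested region         -> 1
//   if (at::autocast::decrement_nesting() == 0) {
//     at::autocast::clear_cache();           // leaving the outermost region
//   }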

// The order of this array MUST exactly match the definition order of DeviceType
// in c10/core/DeviceType.h.
static_assert(
    at::COMPILE_TIME_MAX_DEVICE_TYPES == 21,
    "The definition of the default autocast data type per device backend doesn't match with the definition of the device type.");
thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
    autocast_dtype = {
        at::kBFloat16, // CPU
        at::kHalf, // CUDA.
        at::ScalarType::Undefined, // Reserved for explicit MKLDNN
        at::ScalarType::Undefined, // OpenGL
        at::ScalarType::Undefined, // OpenCL
        at::ScalarType::Undefined, // IDEEP.
        at::kHalf, // AMD HIP
        at::ScalarType::Undefined, // FPGA
        at::ScalarType::Undefined, // ONNX Runtime / Microsoft
        at::kBFloat16, // XLA / TPU
        at::ScalarType::Undefined, // Vulkan
        at::ScalarType::Undefined, // Metal
        at::kHalf, // XPU
        at::kHalf, // MPS
        at::ScalarType::Undefined, // Meta (tensors with no data)
        at::kBFloat16, // HPU / HABANA
        at::ScalarType::Undefined, // SX-Aurora / NEC
        at::ScalarType::Undefined, // Lazy Tensors
        at::kHalf, // Graphcore IPU
        at::ScalarType::Undefined, // Meta training and inference devices
        at::kHalf, // PrivateUse1 device
};
// Whether the cast cache should be enabled inside autocast.
thread_local bool cache_enabled = true;

} // anonymous namespace

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  get_cached_casts().clear();
}

int increment_nesting() {
  return ++nesting;
}

int decrement_nesting() {
  return --nesting;
}

at::ScalarType get_autocast_dtype(at::DeviceType device_type) {
  return autocast_dtype[static_cast<int>(device_type)];
}

void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) {
  autocast_dtype[static_cast<int>(device_type)] = dtype;
}
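
// A minimal usage sketch (values shown are the per-device defaults from the
// autocast_dtype array above; illustrative, not additional API):
//   at::autocast::get_autocast_dtype(at::kCUDA);            // at::kHalf by default
//   at::autocast::get_autocast_dtype(at::kCPU);             // at::kBFloat16 by default
//   at::autocast::set_autocast_dtype(at::kCPU, at::kHalf);  // override for this thread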

bool is_autocast_cache_enabled() {
  return cache_enabled;
}

void set_autocast_cache_enabled(bool enabled) {
  cache_enabled = enabled;
}

// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// extern thread_local in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
    // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
    // See cached_casts declaration above for detailed strategy.
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                          arg.is_leaf() && !arg.is_view() && cache_enabled &&
                          !at::caching::is_cached_tensor(arg));

    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
      if (it != get_cached_casts().end()) {
        return std::get<1>(it->second);
      } else {
        auto casted_arg = arg.to(to_type);
        get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
        return casted_arg;
      }
    } else {
      return arg.to(to_type);
    }
  } else {
    return arg;
  }
}
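
// A rough sketch of how a cast-policy wrapper is expected to use cached_cast
// (illustrative only; the real wrappers are generated by the KERNEL_* macros in
// ATen/autocast_mode.h, and my_op is a hypothetical operator):
//   Tensor my_op_autocast(const Tensor& input, const Tensor& weight) {
//     // Disable autocast for the redispatch so we don't recurse back here.
//     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCUDA);
//     auto to_type = get_lower_precision_fp_from_device_type(c10::DeviceType::CUDA);
//     return my_op(cached_cast(to_type, input, c10::DeviceType::CUDA),
//                  cached_cast(to_type, weight, c10::DeviceType::CUDA));
//   }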

/*******************************
Banned functions
*******************************/

static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional<Tensor>&, int64_t) {
  AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
           "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
           "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
           "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
           "safe to autocast.");
}

namespace {

/*****************************************
Explicit registration for out-of-place ops
*****************************************/

TORCH_LIBRARY_IMPL(_, Autocast, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
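  // Each KERNEL_CUDA(op, ..., policy) line below registers an Autocast-key wrapper for
  // aten::op that casts eligible Tensor arguments according to `policy` (for
  // lower_precision_fp this goes through cached_cast above) and then redispatches to
  // the regular kernel. The wrapper machinery itself lives in ATen/autocast_mode.h.
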
  // lower_precision_fp
#define _KERNEL_CUDA_LOW_PRECISION_FP(...) \
  KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
  KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
  KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

  // fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_CUDA_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
  KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
  // The following are commented out because they accept an explicit (non-optional) dtype,
  // which we shouldn't try to flip even when autocasting.
  // KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA)
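  // For example (an illustrative sketch, not generated code): under this policy an
  // autocast call that reaches
  //   at::norm(t, 2)
  // is redispatched roughly as
  //   at::norm(t, 2, /*dtype=*/at::kFloat)
  // so the result is computed in fp32 even if t is fp16/bf16.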

  // promote
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
  // lower_precision_fp
  KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
  KERNEL_MPS(_convolution, lower_precision_fp)
  KERNEL_MPS(conv1d, lower_precision_fp)
  KERNEL_MPS(conv2d, lower_precision_fp)
  KERNEL_MPS(conv_tbc, lower_precision_fp)
  KERNEL_MPS(conv_transpose1d, lower_precision_fp)
  KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
  KERNEL_MPS(convolution, lower_precision_fp)
  KERNEL_MPS(_mps_convolution, lower_precision_fp)
  KERNEL_MPS(prelu, lower_precision_fp)
  KERNEL_MPS(addmm, lower_precision_fp)
  KERNEL_MPS(addmv, lower_precision_fp)
  KERNEL_MPS(addr, lower_precision_fp)
  KERNEL_MPS(matmul, lower_precision_fp)
  KERNEL_MPS(einsum, lower_precision_fp)
  KERNEL_MPS(mm, lower_precision_fp)
  KERNEL_MPS(mv, lower_precision_fp)
  KERNEL_MPS(linear, lower_precision_fp)
  KERNEL_MPS(addbmm, lower_precision_fp)
  KERNEL_MPS(baddbmm, lower_precision_fp)
  KERNEL_MPS(bmm, lower_precision_fp)
  KERNEL_MPS(chain_matmul, lower_precision_fp)
  KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
  KERNEL_MPS(lstm_cell, lower_precision_fp)

  // fp32
  KERNEL_MPS(acos, fp32)
  KERNEL_MPS(asin, fp32)
  KERNEL_MPS(cosh, fp32)
  KERNEL_MPS(erfinv, fp32)
  KERNEL_MPS(exp, fp32)
  KERNEL_MPS(expm1, fp32)
  KERNEL_MPS(log, fp32)
  KERNEL_MPS(log10, fp32)
  KERNEL_MPS(log2, fp32)
  KERNEL_MPS(log1p, fp32)
  KERNEL_MPS(reciprocal, fp32)
  KERNEL_MPS(rsqrt, fp32)
  KERNEL_MPS(sinh, fp32)
  KERNEL_MPS(tan, fp32)
  KERNEL_MPS2(pow, Tensor_Scalar, fp32)
  KERNEL_MPS2(pow, Tensor_Tensor, fp32)
  KERNEL_MPS2(pow, Scalar, fp32)
  KERNEL_MPS(softplus, fp32)
  KERNEL_MPS(layer_norm, fp32)
  KERNEL_MPS(native_layer_norm, fp32)
  KERNEL_MPS(group_norm, fp32)
  KERNEL_MPS2(frobenius_norm, dim, fp32)
  KERNEL_MPS(nuclear_norm, fp32)
  KERNEL_MPS2(nuclear_norm, dim, fp32)
  KERNEL_MPS(batch_norm, fp32)
  KERNEL_MPS(cosine_similarity, fp32)
  KERNEL_MPS(poisson_nll_loss, fp32)
  KERNEL_MPS(cosine_embedding_loss, fp32)
  KERNEL_MPS(nll_loss, fp32)
  KERNEL_MPS(nll_loss2d, fp32)
  KERNEL_MPS(hinge_embedding_loss, fp32)
  KERNEL_MPS(kl_div, fp32)
  KERNEL_MPS(l1_loss, fp32)
  KERNEL_MPS(smooth_l1_loss, fp32)
  KERNEL_MPS(huber_loss, fp32)
  KERNEL_MPS(mse_loss, fp32)
  KERNEL_MPS(margin_ranking_loss, fp32)
  KERNEL_MPS(multilabel_margin_loss, fp32)
  KERNEL_MPS(soft_margin_loss, fp32)
  KERNEL_MPS(triplet_margin_loss, fp32)
  KERNEL_MPS(multi_margin_loss, fp32)
  KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
  KERNEL_MPS(dist, fp32)
  KERNEL_MPS(pdist, fp32)
  KERNEL_MPS(cdist, fp32)
  KERNEL_MPS(renorm, fp32)
  KERNEL_MPS(logsumexp, fp32)

  // fp32_set_opt_dtype
  KERNEL_MPS(prod, fp32)
  KERNEL_MPS2(prod, dim_int, fp32)
  KERNEL_MPS2(prod, dim_Dimname, fp32)
  KERNEL_MPS2(softmax, int, fp32)
  KERNEL_MPS2(softmax, Dimname, fp32)
  KERNEL_MPS2(log_softmax, int, fp32)
  KERNEL_MPS2(log_softmax, Dimname, fp32)
  KERNEL_MPS(cumprod, fp32)
  KERNEL_MPS2(cumprod, dimname, fp32)
  KERNEL_MPS(cumsum, fp32)
  KERNEL_MPS2(cumsum, dimname, fp32)
  KERNEL_MPS(linalg_vector_norm, fp32)
  KERNEL_MPS(linalg_matrix_norm, fp32)
  KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
  KERNEL_MPS(sum, fp32)
  KERNEL_MPS2(sum, dim_IntList, fp32)
  KERNEL_MPS2(sum, dim_DimnameList, fp32)

  // promote
  KERNEL_MPS(addcdiv, promote)
  KERNEL_MPS(addcmul, promote)
  KERNEL_MPS(atan2, promote)
  KERNEL_MPS(bilinear, promote)
  KERNEL_MPS(cross, promote)
  KERNEL_MPS(dot, promote)
  KERNEL_MPS(grid_sampler, promote)
  KERNEL_MPS(index_put, promote)
  KERNEL_MPS(tensordot, promote)
  KERNEL_MPS(scatter_add, promote)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // lower_precision_fp cast policy
  KERNEL_CPU(conv1d, lower_precision_fp)
  KERNEL_CPU(conv1d, padding, lower_precision_fp)
  KERNEL_CPU(conv2d, lower_precision_fp)
  KERNEL_CPU(conv2d, padding, lower_precision_fp)
  KERNEL_CPU(conv3d, lower_precision_fp)
  KERNEL_CPU(conv3d, padding, lower_precision_fp)
  KERNEL_CPU(bmm, lower_precision_fp)
  KERNEL_CPU(mm, lower_precision_fp)
  KERNEL_CPU(linalg_vecdot, lower_precision_fp)
  KERNEL_CPU(baddbmm, lower_precision_fp)
  KERNEL_CPU(addmm, lower_precision_fp)
  KERNEL_CPU(addbmm, lower_precision_fp)
  KERNEL_CPU(linear, lower_precision_fp)
  KERNEL_CPU(_convolution, deprecated, lower_precision_fp)
  KERNEL_CPU(matmul, lower_precision_fp)
  KERNEL_CPU(conv_tbc, lower_precision_fp)
  KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
  KERNEL_CPU(conv_transpose1d, lower_precision_fp)
  KERNEL_CPU(conv_transpose2d, input, lower_precision_fp)
  KERNEL_CPU(conv_transpose3d, input, lower_precision_fp)
  KERNEL_CPU(prelu, lower_precision_fp)
  KERNEL_CPU(scaled_dot_product_attention, lower_precision_fp)
  KERNEL_CPU(_native_multi_head_attention, lower_precision_fp)

  // fp32 cast policy
  KERNEL_CPU(avg_pool3d, fp32)
  KERNEL_CPU(binary_cross_entropy, fp32)
  KERNEL_CPU(grid_sampler, fp32)
  KERNEL_CPU(polar, fp32)
  KERNEL_CPU(prod, fp32)
  KERNEL_CPU(prod, dim_int, fp32)
  KERNEL_CPU(prod, dim_Dimname, fp32)
  KERNEL_CPU(quantile, fp32)
  KERNEL_CPU(quantile, scalar, fp32)
  KERNEL_CPU(nanquantile, fp32)
  KERNEL_CPU(nanquantile, scalar, fp32)
  KERNEL_CPU(stft, fp32)
  KERNEL_CPU(stft, center, fp32)
  KERNEL_CPU(cdist, fp32)
  KERNEL_CPU(grid_sampler_2d, fp32)
  KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32)
  KERNEL_CPU(grid_sampler_3d, fp32)
  KERNEL_CPU(trace, fp32)
  KERNEL_CPU(view_as_complex, fp32)
  KERNEL_CPU(cholesky, fp32)
  KERNEL_CPU(cholesky_inverse, fp32)
  KERNEL_CPU(cholesky_solve, fp32)
  KERNEL_CPU(inverse, fp32)
  KERNEL_CPU(lu_solve, fp32)
  KERNEL_CPU(orgqr, fp32)
  KERNEL_CPU(ormqr, fp32)
  KERNEL_CPU(pinverse, fp32)
  KERNEL_CPU(max_pool3d, fp32)
  KERNEL_CPU(max_unpool2d, fp32)
  KERNEL_CPU(max_unpool3d, fp32)
  KERNEL_CPU(adaptive_avg_pool3d, fp32)
  KERNEL_CPU(reflection_pad1d, fp32)
  KERNEL_CPU(reflection_pad2d, fp32)
  KERNEL_CPU(replication_pad1d, fp32)
  KERNEL_CPU(replication_pad2d, fp32)
  KERNEL_CPU(replication_pad3d, fp32)
  KERNEL_CPU(mse_loss, fp32)
  KERNEL_CPU(cosine_embedding_loss, fp32)
  KERNEL_CPU(nll_loss, fp32)
  KERNEL_CPU(nll_loss2d, fp32)
  KERNEL_CPU(hinge_embedding_loss, fp32)
  KERNEL_CPU(poisson_nll_loss, fp32)
  KERNEL_CPU(smooth_l1_loss, fp32)
  KERNEL_CPU(cross_entropy_loss, fp32)
  KERNEL_CPU(l1_loss, fp32)
  KERNEL_CPU(huber_loss, fp32)
  KERNEL_CPU(margin_ranking_loss, fp32)
  KERNEL_CPU(soft_margin_loss, fp32)
  KERNEL_CPU(triplet_margin_loss, fp32)
  KERNEL_CPU(multi_margin_loss, fp32)
  KERNEL_CPU(ctc_loss, IntList, fp32)
  KERNEL_CPU(ctc_loss, Tensor, fp32)
  KERNEL_CPU(kl_div, fp32)
  KERNEL_CPU(multilabel_margin_loss, fp32)
  KERNEL_CPU(binary_cross_entropy_with_logits, fp32)
  KERNEL_CPU(fft_fft, fp32)
  KERNEL_CPU(fft_ifft, fp32)
  KERNEL_CPU(fft_fft2, fp32)
  KERNEL_CPU(fft_ifft2, fp32)
  KERNEL_CPU(fft_fftn, fp32)
  KERNEL_CPU(fft_ifftn, fp32)
  KERNEL_CPU(fft_rfft, fp32)
  KERNEL_CPU(fft_irfft, fp32)
  KERNEL_CPU(fft_rfft2, fp32)
  KERNEL_CPU(fft_irfft2, fp32)
  KERNEL_CPU(fft_rfftn, fp32)
  KERNEL_CPU(fft_irfftn, fp32)
  KERNEL_CPU(fft_hfft, fp32)
  KERNEL_CPU(fft_ihfft, fp32)
  KERNEL_CPU(linalg_cond, fp32)
  KERNEL_CPU(linalg_cond, p_str, fp32)
  KERNEL_CPU(linalg_matrix_rank, fp32)
  KERNEL_CPU(linalg_matrix_rank, tol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_float, fp32)
  KERNEL_CPU(linalg_solve, fp32)
  KERNEL_CPU(linalg_cholesky, fp32)
  KERNEL_CPU(linalg_svdvals, fp32)
  KERNEL_CPU(linalg_eigvals, fp32)
  KERNEL_CPU(linalg_eigvalsh, fp32)
  KERNEL_CPU(linalg_inv, fp32)
  KERNEL_CPU(linalg_householder_product, fp32)
  KERNEL_CPU(linalg_tensorinv, fp32)
  KERNEL_CPU(linalg_tensorsolve, fp32)
  KERNEL_CPU(fake_quantize_per_tensor_affine, fp32)
  KERNEL_CPU(geqrf, fp32)
  KERNEL_CPU(_lu_with_info, fp32)
  KERNEL_CPU(qr, fp32)
  KERNEL_CPU(svd, fp32)
  KERNEL_CPU(triangular_solve, fp32)
  KERNEL_CPU(fractional_max_pool2d, fp32)
  KERNEL_CPU(fractional_max_pool3d, fp32)
  KERNEL_CPU(adaptive_max_pool3d, fp32)
  KERNEL_CPU(multilabel_margin_loss_forward, fp32)
  KERNEL_CPU(linalg_qr, fp32)
  KERNEL_CPU(linalg_cholesky_ex, fp32)
  KERNEL_CPU(linalg_svd, fp32)
  KERNEL_CPU(linalg_eig, fp32)
  KERNEL_CPU(linalg_eigh, fp32)
  KERNEL_CPU(linalg_lstsq, fp32)
  KERNEL_CPU(linalg_inv_ex, fp32)

  // promote
  KERNEL_CPU(stack, promote)
  KERNEL_CPU(cat, promote)
  KERNEL_CPU(index_copy, promote)
  KERNEL_CPU(index_copy, dimname, promote)
}

TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
  // lower_precision_fp
#define _KERNEL_XPU_LOW_PRECISION_FP(...) \
  KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)

  // fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_XPU_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
  KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU)

  // promote
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

} // namespace
} // namespace at::autocast