// aten/src/ATen/autocast_mode.cpp
#include <ATen/autocast_mode.h>

#include <mutex>
#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>

namespace at::autocast {

bool is_autocast_enabled(at::DeviceType device_type) {
  at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
  return !c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
}

void set_autocast_enabled(at::DeviceType device_type, bool enabled) {
  at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
  c10::impl::tls_set_dispatch_key_excluded(dispatch_key, !enabled);
}
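
// Illustrative example (added commentary, not part of the original source):
// autocast is toggled per device type by excluding/including its Autocast
// dispatch key in thread-local dispatcher state, e.g.
//
//   bool prev = at::autocast::is_autocast_enabled(at::kCUDA);
//   at::autocast::set_autocast_enabled(at::kCUDA, true);   // enter autocast
//   // ... run ops that should be autocast ...
//   at::autocast::set_autocast_enabled(at::kCUDA, prev);   // restore previous state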

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's
// unchanged across shallow copies.
// The value is a tuple with a weakref to the source tensor's TensorImpl as the first
// element and the casted tensor as the second element.
//
// The weakref keeps the source's TensorImpl from being deleted.  We need this because we're
// using the source TensorImpl* as the key.  If it were deleted, another random Tensor could
// be allocated whose TensorImpl* happened to have the same value.  This TensorImpl* would
// then mistakenly hit in the cache:  a rare, intermittent, unpredictable bug.
//
// I'm not using the weak_intrusive_ptr as the key because it's more difficult to compare
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;

static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
  return cached_casts;
}
std::mutex cached_casts_mutex;

// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
// any instance of autocast (which should occur at the end of each forward pass)
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
thread_local int nesting = 0;
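
// Illustrative sketch (added commentary; assumes the Python-side torch.autocast
// context manager, not part of the original source): entering/exiting that
// context manager roughly maps onto the functions defined below:
//
//   at::autocast::increment_nesting();                       // __enter__
//   at::autocast::set_autocast_enabled(device_type, true);
//   // ... forward pass ...
//   if (at::autocast::decrement_nesting() == 0) {            // __exit__
//     at::autocast::clear_cache();                           // drop cached casts
//   }
//   at::autocast::set_autocast_enabled(device_type, prev_enabled);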

// The order of this array MUST exactly match the definition order of DeviceType
// in c10/core/DeviceType.h.
static_assert(
    at::COMPILE_TIME_MAX_DEVICE_TYPES == 21,
    "The definition of the default autocast data type per device backend doesn't match the definition of the device types.");
thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
    autocast_dtype = {
        at::kBFloat16, // CPU
        at::kHalf, // CUDA
        at::ScalarType::Undefined, // Reserved for explicit MKLDNN
        at::ScalarType::Undefined, // OpenGL
        at::ScalarType::Undefined, // OpenCL
        at::ScalarType::Undefined, // IDEEP
        at::kHalf, // AMD HIP
        at::ScalarType::Undefined, // FPGA
        at::ScalarType::Undefined, // ONNX Runtime / Microsoft
        at::kBFloat16, // XLA / TPU
        at::ScalarType::Undefined, // Vulkan
        at::ScalarType::Undefined, // Metal
        at::kHalf, // XPU
        at::kHalf, // MPS
        at::ScalarType::Undefined, // Meta (tensors with no data)
        at::kBFloat16, // HPU / HABANA
        at::ScalarType::Undefined, // SX-Aurora / NEC
        at::ScalarType::Undefined, // Lazy Tensors
        at::kHalf, // Graphcore IPU
        at::ScalarType::Undefined, // Meta training and inference devices
        at::kHalf, // PrivateUse1 device
};

// Whether the cast cache should be enabled inside autocast.
thread_local bool cache_enabled = true;

} // anonymous namespace

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  get_cached_casts().clear();
}

int increment_nesting() {
  return ++nesting;
}

int decrement_nesting() {
  return --nesting;
}

at::ScalarType get_autocast_dtype(at::DeviceType device_type) {
  return autocast_dtype[static_cast<int>(device_type)];
}

void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) {
  autocast_dtype[static_cast<int>(device_type)] = dtype;
}
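
// Illustrative example (added commentary, not part of the original source):
// the per-device default dtype from the table above can be queried or
// overridden at runtime, e.g.
//
//   at::autocast::set_autocast_dtype(at::kCPU, at::kHalf);   // override the bf16 default
//   auto dt = at::autocast::get_autocast_dtype(at::kCPU);    // now at::kHalf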

bool is_autocast_cache_enabled() {
  return cache_enabled;
}

void set_autocast_cache_enabled(bool enabled) {
  cache_enabled = enabled;
}

// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// extern thread_local in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
    // Heuristic:  Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
    // See cached_casts declaration above for detailed strategy.
    bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                         arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                         arg.is_leaf() && !arg.is_view() && cache_enabled &&
                         !at::caching::is_cached_tensor(arg));

    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
      if (it != get_cached_casts().end()) {
        return std::get<1>(it->second);
      } else {
        auto casted_arg = arg.to(to_type);
        get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
        return casted_arg;
      }
    } else {
      return arg.to(to_type);
    }
  } else {
    return arg;
  }
}
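
// Illustrative sketch (added commentary, not part of the original source):
// within an autocast region, repeated casts of the same fp32 leaf parameter hit
// the cache, so the lower-precision copy is materialized only once per forward pass:
//
//   at::Tensor weight = at::randn({16, 16}, at::TensorOptions().requires_grad(true));  // fp32 CPU leaf
//   auto a = cached_cast(at::kBFloat16, weight, c10::DeviceType::CPU);
//   auto b = cached_cast(at::kBFloat16, weight, c10::DeviceType::CPU);
//   // a and b refer to the same cached bf16 tensor until clear_cache() runs.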

/*******************************
Banned functions
*******************************/

static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional<Tensor>&, int64_t) {
  AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
           "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
           "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
           "or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are\n"
           "safe to autocast.");
}
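
// Illustrative example (added commentary, not part of the original source;
// logits and targets are placeholder tensors): under autocast, the banned
// kernel registered below turns a direct call into a hard error, while the
// fused op is handled by the ordinary fp32 cast policy:
//
//   // throws: "... binary_cross_entropy ... unsafe to autocast"
//   auto loss  = at::binary_cross_entropy(at::sigmoid(logits), targets);
//   // fine: runs in fp32 under autocast
//   auto loss2 = at::binary_cross_entropy_with_logits(logits, targets);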

namespace {

/*****************************************
Explicit registration for out-of-place ops
*****************************************/

TORCH_LIBRARY_IMPL(_, Autocast, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  // lower_precision_fp
#define _KERNEL_CUDA_LOW_PRECISION_FP(...) \
  KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
  KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
  KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

  // fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_CUDA_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
  KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
  // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
  // when autocasting.
  // KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA)

  // promote
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
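
// Illustrative example (added commentary, not part of the original source):
// with the Autocast key active for CUDA, an op registered under the
// lower_precision_fp policy has its floating-point args run through
// cached_cast before redispatching, e.g.
//
//   at::Tensor a = at::randn({4, 4}, at::kCUDA);   // fp32
//   at::Tensor b = at::randn({4, 4}, at::kCUDA);
//   // inside an autocast-enabled region:
//   auto c = at::mm(a, b);   // computed in get_autocast_dtype(at::kCUDA), fp16 by default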

TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
  // lower_precision_fp
  KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
  KERNEL_MPS(_convolution, lower_precision_fp)
  KERNEL_MPS(conv1d, lower_precision_fp)
  KERNEL_MPS(conv2d, lower_precision_fp)
  KERNEL_MPS(conv_tbc, lower_precision_fp)
  KERNEL_MPS(conv_transpose1d, lower_precision_fp)
  KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
  KERNEL_MPS(convolution, lower_precision_fp)
  KERNEL_MPS(_mps_convolution, lower_precision_fp)
  KERNEL_MPS(prelu, lower_precision_fp)
  KERNEL_MPS(addmm, lower_precision_fp)
  KERNEL_MPS(addmv, lower_precision_fp)
  KERNEL_MPS(addr, lower_precision_fp)
  KERNEL_MPS(matmul, lower_precision_fp)
  KERNEL_MPS(einsum, lower_precision_fp)
  KERNEL_MPS(mm, lower_precision_fp)
  KERNEL_MPS(mv, lower_precision_fp)
  KERNEL_MPS(linear, lower_precision_fp)
  KERNEL_MPS(addbmm, lower_precision_fp)
  KERNEL_MPS(baddbmm, lower_precision_fp)
  KERNEL_MPS(bmm, lower_precision_fp)
  KERNEL_MPS(chain_matmul, lower_precision_fp)
  KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
  KERNEL_MPS(lstm_cell, lower_precision_fp)

  // fp32
  KERNEL_MPS(acos, fp32)
  KERNEL_MPS(asin, fp32)
  KERNEL_MPS(cosh, fp32)
  KERNEL_MPS(erfinv, fp32)
  KERNEL_MPS(exp, fp32)
  KERNEL_MPS(expm1, fp32)
  KERNEL_MPS(log, fp32)
  KERNEL_MPS(log10, fp32)
  KERNEL_MPS(log2, fp32)
  KERNEL_MPS(log1p, fp32)
  KERNEL_MPS(reciprocal, fp32)
  KERNEL_MPS(rsqrt, fp32)
  KERNEL_MPS(sinh, fp32)
  KERNEL_MPS(tan, fp32)
  KERNEL_MPS2(pow, Tensor_Scalar, fp32)
  KERNEL_MPS2(pow, Tensor_Tensor, fp32)
  KERNEL_MPS2(pow, Scalar, fp32)
  KERNEL_MPS(softplus, fp32)
  KERNEL_MPS(layer_norm, fp32)
  KERNEL_MPS(native_layer_norm, fp32)
  KERNEL_MPS(group_norm, fp32)
  KERNEL_MPS2(frobenius_norm, dim, fp32)
  KERNEL_MPS(nuclear_norm, fp32)
  KERNEL_MPS2(nuclear_norm, dim, fp32)
  KERNEL_MPS(batch_norm, fp32)
  KERNEL_MPS(cosine_similarity, fp32)
  KERNEL_MPS(poisson_nll_loss, fp32)
  KERNEL_MPS(cosine_embedding_loss, fp32)
  KERNEL_MPS(nll_loss, fp32)
  KERNEL_MPS(nll_loss2d, fp32)
  KERNEL_MPS(hinge_embedding_loss, fp32)
  KERNEL_MPS(kl_div, fp32)
  KERNEL_MPS(l1_loss, fp32)
  KERNEL_MPS(smooth_l1_loss, fp32)
  KERNEL_MPS(huber_loss, fp32)
  KERNEL_MPS(mse_loss, fp32)
  KERNEL_MPS(margin_ranking_loss, fp32)
  KERNEL_MPS(multilabel_margin_loss, fp32)
  KERNEL_MPS(soft_margin_loss, fp32)
  KERNEL_MPS(triplet_margin_loss, fp32)
  KERNEL_MPS(multi_margin_loss, fp32)
  KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
  KERNEL_MPS(dist, fp32)
  KERNEL_MPS(pdist, fp32)
  KERNEL_MPS(cdist, fp32)
  KERNEL_MPS(renorm, fp32)
  KERNEL_MPS(logsumexp, fp32)

  // fp32_set_opt_dtype
  KERNEL_MPS(prod, fp32)
  KERNEL_MPS2(prod, dim_int, fp32)
  KERNEL_MPS2(prod, dim_Dimname, fp32)
  KERNEL_MPS2(softmax, int, fp32)
  KERNEL_MPS2(softmax, Dimname, fp32)
  KERNEL_MPS2(log_softmax, int, fp32)
  KERNEL_MPS2(log_softmax, Dimname, fp32)
  KERNEL_MPS(cumprod, fp32)
  KERNEL_MPS2(cumprod, dimname, fp32)
  KERNEL_MPS(cumsum, fp32)
  KERNEL_MPS2(cumsum, dimname, fp32)
  KERNEL_MPS(linalg_vector_norm, fp32)
  KERNEL_MPS(linalg_matrix_norm, fp32)
  KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
  KERNEL_MPS(sum, fp32)
  KERNEL_MPS2(sum, dim_IntList, fp32)
  KERNEL_MPS2(sum, dim_DimnameList, fp32)

  // promote
  KERNEL_MPS(addcdiv, promote)
  KERNEL_MPS(addcmul, promote)
  KERNEL_MPS(atan2, promote)
  KERNEL_MPS(bilinear, promote)
  KERNEL_MPS(cross, promote)
  KERNEL_MPS(dot, promote)
  KERNEL_MPS(grid_sampler, promote)
  KERNEL_MPS(index_put, promote)
  KERNEL_MPS(tensordot, promote)
  KERNEL_MPS(scatter_add, promote)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // lower_precision_fp cast policy
  KERNEL_CPU(conv1d, lower_precision_fp)
  KERNEL_CPU(conv1d, padding, lower_precision_fp)
  KERNEL_CPU(conv2d, lower_precision_fp)
  KERNEL_CPU(conv2d, padding, lower_precision_fp)
  KERNEL_CPU(conv3d, lower_precision_fp)
  KERNEL_CPU(conv3d, padding, lower_precision_fp)
  KERNEL_CPU(bmm, lower_precision_fp)
  KERNEL_CPU(mm, lower_precision_fp)
  KERNEL_CPU(linalg_vecdot, lower_precision_fp)
  KERNEL_CPU(baddbmm, lower_precision_fp)
  KERNEL_CPU(addmm, lower_precision_fp)
  KERNEL_CPU(addbmm, lower_precision_fp)
  KERNEL_CPU(linear, lower_precision_fp)
  KERNEL_CPU(_convolution, deprecated, lower_precision_fp)
  KERNEL_CPU(matmul, lower_precision_fp)
  KERNEL_CPU(conv_tbc, lower_precision_fp)
  KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
  KERNEL_CPU(conv_transpose1d, lower_precision_fp)
  KERNEL_CPU(conv_transpose2d, input, lower_precision_fp)
  KERNEL_CPU(conv_transpose3d, input, lower_precision_fp)
  KERNEL_CPU(prelu, lower_precision_fp)
  KERNEL_CPU(scaled_dot_product_attention, lower_precision_fp)
  KERNEL_CPU(_native_multi_head_attention, lower_precision_fp)

  // fp32 cast policy
  KERNEL_CPU(avg_pool3d, fp32)
  KERNEL_CPU(binary_cross_entropy, fp32)
  KERNEL_CPU(grid_sampler, fp32)
  KERNEL_CPU(polar, fp32)
  KERNEL_CPU(prod, fp32)
  KERNEL_CPU(prod, dim_int, fp32)
  KERNEL_CPU(prod, dim_Dimname, fp32)
  KERNEL_CPU(quantile, fp32)
  KERNEL_CPU(quantile, scalar, fp32)
  KERNEL_CPU(nanquantile, fp32)
  KERNEL_CPU(nanquantile, scalar, fp32)
  KERNEL_CPU(stft, fp32)
  KERNEL_CPU(stft, center, fp32)
  KERNEL_CPU(cdist, fp32)
  KERNEL_CPU(grid_sampler_2d, fp32)
  KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32)
  KERNEL_CPU(grid_sampler_3d, fp32)
  KERNEL_CPU(trace, fp32)
  KERNEL_CPU(view_as_complex, fp32)
  KERNEL_CPU(cholesky, fp32)
  KERNEL_CPU(cholesky_inverse, fp32)
  KERNEL_CPU(cholesky_solve, fp32)
  KERNEL_CPU(inverse, fp32)
  KERNEL_CPU(lu_solve, fp32)
  KERNEL_CPU(orgqr, fp32)
  KERNEL_CPU(ormqr, fp32)
  KERNEL_CPU(pinverse, fp32)
  KERNEL_CPU(max_pool3d, fp32)
  KERNEL_CPU(max_unpool2d, fp32)
  KERNEL_CPU(max_unpool3d, fp32)
  KERNEL_CPU(adaptive_avg_pool3d, fp32)
  KERNEL_CPU(reflection_pad1d, fp32)
  KERNEL_CPU(reflection_pad2d, fp32)
  KERNEL_CPU(replication_pad1d, fp32)
  KERNEL_CPU(replication_pad2d, fp32)
  KERNEL_CPU(replication_pad3d, fp32)
  KERNEL_CPU(mse_loss, fp32)
  KERNEL_CPU(cosine_embedding_loss, fp32)
  KERNEL_CPU(nll_loss, fp32)
  KERNEL_CPU(nll_loss2d, fp32)
  KERNEL_CPU(hinge_embedding_loss, fp32)
  KERNEL_CPU(poisson_nll_loss, fp32)
  KERNEL_CPU(smooth_l1_loss, fp32)
  KERNEL_CPU(cross_entropy_loss, fp32)
  KERNEL_CPU(l1_loss, fp32)
  KERNEL_CPU(huber_loss, fp32)
  KERNEL_CPU(margin_ranking_loss, fp32)
  KERNEL_CPU(soft_margin_loss, fp32)
  KERNEL_CPU(triplet_margin_loss, fp32)
  KERNEL_CPU(multi_margin_loss, fp32)
  KERNEL_CPU(ctc_loss, IntList, fp32)
  KERNEL_CPU(ctc_loss, Tensor, fp32)
  KERNEL_CPU(kl_div, fp32)
  KERNEL_CPU(multilabel_margin_loss, fp32)
  KERNEL_CPU(binary_cross_entropy_with_logits, fp32)
  KERNEL_CPU(fft_fft, fp32)
  KERNEL_CPU(fft_ifft, fp32)
  KERNEL_CPU(fft_fft2, fp32)
  KERNEL_CPU(fft_ifft2, fp32)
  KERNEL_CPU(fft_fftn, fp32)
  KERNEL_CPU(fft_ifftn, fp32)
  KERNEL_CPU(fft_rfft, fp32)
  KERNEL_CPU(fft_irfft, fp32)
  KERNEL_CPU(fft_rfft2, fp32)
  KERNEL_CPU(fft_irfft2, fp32)
  KERNEL_CPU(fft_rfftn, fp32)
  KERNEL_CPU(fft_irfftn, fp32)
  KERNEL_CPU(fft_hfft, fp32)
  KERNEL_CPU(fft_ihfft, fp32)
  KERNEL_CPU(linalg_cond, fp32)
  KERNEL_CPU(linalg_cond, p_str, fp32)
  KERNEL_CPU(linalg_matrix_rank, fp32)
  KERNEL_CPU(linalg_matrix_rank, tol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_float, fp32)
  KERNEL_CPU(linalg_solve, fp32)
  KERNEL_CPU(linalg_cholesky, fp32)
  KERNEL_CPU(linalg_svdvals, fp32)
  KERNEL_CPU(linalg_eigvals, fp32)
  KERNEL_CPU(linalg_eigvalsh, fp32)
  KERNEL_CPU(linalg_inv, fp32)
  KERNEL_CPU(linalg_householder_product, fp32)
  KERNEL_CPU(linalg_tensorinv, fp32)
  KERNEL_CPU(linalg_tensorsolve, fp32)
  KERNEL_CPU(fake_quantize_per_tensor_affine, fp32)
  KERNEL_CPU(geqrf, fp32)
  KERNEL_CPU(_lu_with_info, fp32)
  KERNEL_CPU(qr, fp32)
  KERNEL_CPU(svd, fp32)
  KERNEL_CPU(triangular_solve, fp32)
  KERNEL_CPU(fractional_max_pool2d, fp32)
  KERNEL_CPU(fractional_max_pool3d, fp32)
  KERNEL_CPU(adaptive_max_pool3d, fp32)
  KERNEL_CPU(multilabel_margin_loss_forward, fp32)
  KERNEL_CPU(linalg_qr, fp32)
  KERNEL_CPU(linalg_cholesky_ex, fp32)
  KERNEL_CPU(linalg_svd, fp32)
  KERNEL_CPU(linalg_eig, fp32)
  KERNEL_CPU(linalg_eigh, fp32)
  KERNEL_CPU(linalg_lstsq, fp32)
  KERNEL_CPU(linalg_inv_ex, fp32)

  // promote
  KERNEL_CPU(stack, promote)
  KERNEL_CPU(cat, promote)
  KERNEL_CPU(index_copy, promote)
  KERNEL_CPU(index_copy, dimname, promote)
}
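
// Illustrative example (added commentary, not part of the original source):
// ops under the fp32 policy are widened even when their inputs are lower
// precision; for instance, under CPU autocast a bf16 loss input runs in fp32:
//
//   at::Tensor input  = at::randn({8}).to(at::kBFloat16);
//   at::Tensor target = at::randn({8}).to(at::kBFloat16);
//   // inside an autocast-enabled CPU region:
//   auto loss = at::mse_loss(input, target);   // computed in fp32 per the policy above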

TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
  // lower_precision_fp
#define _KERNEL_XPU_LOW_PRECISION_FP(...) \
  KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)

  // fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_XPU_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
  KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU)

  // promote
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

} // namespace
} // namespace at::autocast