#pragma once

#include <ATen/BlasBackend.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/DeviceAccelerator.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/IPUHooksInterface.h>
#include <ATen/detail/MAIAHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <mutex>

namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    c10::DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
    } else if (device_type == at::kIPU) {
      return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
          device.index());
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
      std::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
    c10::DeviceType device_type = opt_device_type.has_value()
        ? opt_device_type.value()
        : at::getAccelerator(true).value();
    if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks();
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks();
    } else if (device_type == at::kPrivateUse1) {
      return at::detail::getPrivateUse1Hooks();
    } else if (device_type == at::kMTIA) {
      return at::detail::getMTIAHooks();
    } else {
      AT_ERROR(
          c10::DeviceTypeName(device_type), " device type not an accelerator.");
    }
  }
  Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    initXPUIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return c10::DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data);
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  static bool isPinnedPtr(const void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasMTIA() {
    return detail::getMTIAHooks().hasMTIA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasCuBLASLt() {
    return detail::getCUDAHooks().hasCuBLASLt();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
  }
  static bool hasXPU() {
    return detail::getXPUHooks().hasXPU();
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  }
  static bool hasMAIA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  }
  // Defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  void lazyInitXPU() {
    c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
  }
  void lazyInitMTIA() {
    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
  }
  void lazyInitPrivateUse1() {
    c10::call_once(thp_init, [&] {
      if (isPrivateUse1HooksRegistered()) {
        at::GetPrivateUse1HooksInterface()->initPrivateUse1();
      }
    });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);

  // NB: This method reports *purely* whether or not a user requested
  // that CuDNN be enabled; it doesn't say anything about whether or
  // not CuDNN is actually usable. Use cudnn_is_acceptable to test
  // this instead.
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);
  bool userEnabledNNPACK() const;
  void setUserEnabledNNPACK(bool e);

  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false).
  // This is useful for debugging. For example, if you want to compare the
  // performance of the flash SDP kernels with the unfused math kernel, you
  // can disable the flash SDP kernels. Conversely, disabling the math SDP
  // kernel forces your code to use the flash kernels.
  // The math SDP kernel can be disabled by calling
  // at::globalContext().setSDPUseMath(false).
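  //
  // As a rough usage sketch (using only the setters declared below), forcing
  // the unfused math path for a debugging run could look like:
  //
  //    at::globalContext().setSDPUseFlash(false);
  //    at::globalContext().setSDPUseMemEfficient(false);
  //    at::globalContext().setSDPUseCuDNN(false);
  //    // only the math SDP kernel remains enabled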
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;

  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;

  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;

  void setSDPUseCuDNN(bool);
  bool userEnabledCuDNNSDP() const;

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  at::BlasBackend blasPreferredBackend();
  void setBlasPreferredBackend(at::BlasBackend);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //    void example_func() {
  //      // See Note [Enabling Deterministic Operations]
  //      if (at::globalContext().deterministicAlgorithms()) {
  //        example_func_deterministic();
  //      } else {
  //        example_func_nondeterministic();
  //      }
  //    }

  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool, bool);
  bool deterministicFillUninitializedMemory() const;
  void setDeterministicFillUninitializedMemory(bool);
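
  // A minimal usage sketch (assuming the second boolean argument of
  // setDeterministicAlgorithms() is the warn-only flag, as the
  // deterministicAlgorithmsWarnOnly() getter above suggests): globally opt
  // into deterministic implementations, warning rather than erroring when an
  // operation has no deterministic variant:
  //
  //    at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/true);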

  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any new
  //   tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and error-throwing
  // code for a nondeterministic operation:
  //
  //    void example_func() {
  //      // See Note [Writing Nondeterministic Operations]
  //      // Nondeterministic because <reason>
  //      at::globalContext().alertNotDeterministic("example_func");
  //      ...
  //    }
  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
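  //
  // For example, one way to satisfy that requirement is to set the environment
  // variable before launching the process:
  //
  //    export CUBLAS_WORKSPACE_CONFIG=:4096:8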
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
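  // For instance (a sketch; the accepted strings are expected to mirror the
  // Float32MatmulPrecision enum above, e.g. "highest", "high", "medium"):
  //
  //    at::globalContext().setFloat32MatmulPrecision("high");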
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();
  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;
  // This method is used to release the original weight after pre-packing.
  // It should be called once before loading/running the model.
  // NB: By default it is set to true for mobile builds.
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();
  bool allowFP16ReductionCPU() const;
  void setAllowFP16ReductionCPU(bool);

 private:
  void initCUDAIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  void initXPUIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::XPU) {
      lazyInitXPU();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  c10::once_flag thx_init;
  c10::once_flag th_mtia_init;
  c10::once_flag thp_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool _deterministic_fill_uninitialized_memory = true;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
  bool enabled_cudnnSDP = false;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  bool enabled_nnpack = true;
  at::LinalgBackend linalg_preferred_backend =
      c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
      ? at::LinalgBackend::Cusolver
      : at::LinalgBackend::Default;
  at::BlasBackend blas_preferred_backend =
#ifdef USE_ROCM
      (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
#else
      (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
#endif
      ? at::BlasBackend::Cublaslt
      : at::BlasBackend::Cublas;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  std::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

  Allocator* prev_allocator_ptr_{nullptr};
};

TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}

TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}

static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasMTIA() {
  return globalContext().hasMTIA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasMAIA() {
  return globalContext().hasMAIA();
}

static inline bool hasXPU() {
  return globalContext().hasXPU();
}

// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION.  If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count())
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
        "as CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
        "means HIP). Rebuild PyTorch with one or the other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}

static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}

static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && cuda_num_gpus > 0) {
    for (const auto i : c10::irange(cuda_num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }

  const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
  if (hasXPU() && xpu_num_gpus) {
    for (const auto i : c10::irange(xpu_num_gpus)) {
      auto xpu_gen = globalContext().defaultGenerator(
          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
        xpu_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}

// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. To address this case, this class implements
// a RAII guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
//     NoTF32Guard disable_tf32;
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};
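
// For example (a sketch): run a TF32-sensitive operator with TF32 disabled
// for just the enclosing scope:
//
//    {
//      at::NoTF32Guard disable_tf32;
//      // ... dispatch the precision-sensitive kernel here ...
//    }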

struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();
};

} // namespace at