#pragma once

#include <ATen/BlasBackend.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/DeviceAccelerator.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/IPUHooksInterface.h>
#include <ATen/detail/MAIAHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <mutex>

namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    c10::DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
    } else if (device_type == at::kIPU) {
      return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
          device.index());
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
      std::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
    c10::DeviceType device_type = opt_device_type.has_value()
        ? opt_device_type.value()
        : at::getAccelerator(true).value();
    if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks();
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks();
    } else if (device_type == at::kPrivateUse1) {
      return at::detail::getPrivateUse1Hooks();
    } else if (device_type == at::kMTIA) {
      return at::detail::getMTIAHooks();
    } else {
      AT_ERROR(
          c10::DeviceTypeName(device_type), " device type not an accelerator.");
    }
  }
  Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    initXPUIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return c10::DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kXPU) {
      return at::detail::getXPUHooks().getDeviceFromPtr(data);
    } else if (device_type == at::kPrivateUse1) {
      return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data);
    } else {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  static bool isPinnedPtr(const void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasMTIA() {
    return detail::getMTIAHooks().hasMTIA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasCuBLASLt() {
    return detail::getCUDAHooks().hasCuBLASLt();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
  }
  static bool hasXPU() {
    return detail::getXPUHooks().hasXPU();
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  }
  static bool hasMAIA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  }
  // Defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  void lazyInitXPU() {
    c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
  }
  void lazyInitMTIA() {
    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
  }
  void lazyInitPrivateUse1() {
    c10::call_once(thp_init, [&] {
      if (isPrivateUse1HooksRegistered()) {
        at::GetPrivateUse1HooksInterface()->initPrivateUse1();
      }
    });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);

  // NB: This method reports *purely* whether or not a user requested that
  // CuDNN be enabled; it says nothing about whether CuDNN is actually
  // usable. Use cudnn_is_acceptable to test that instead.
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);
  bool userEnabledNNPACK() const;
  void setUserEnabledNNPACK(bool e);

  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false).
  // This is useful for debugging purposes. For example, if you want to
  // compare the performance of the flash SDP kernels with the unfused
  // kernel, you can disable the flash SDP kernels. Conversely, by disabling
  // the math SDP kernel, you can force your code to use flash kernels.
  // The math SDP kernel can be disabled by calling
  // at::globalContext().setSDPUseMath(false).
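  //
  // An illustrative sketch (not an exhaustive recipe), based only on the
  // setters declared below, of forcing the unfused math fallback, e.g. to
  // obtain a numerical reference:
  //
  //   at::globalContext().setSDPUseFlash(false);
  //   at::globalContext().setSDPUseMemEfficient(false);
  //   at::globalContext().setSDPUseMath(true);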
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;

  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;

  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;

  void setSDPUseCuDNN(bool);
  bool userEnabledCuDNNSDP() const;

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  at::BlasBackend blasPreferredBackend();
  void setBlasPreferredBackend(at::BlasBackend);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //   void example_func() {
  //     // See Note [Enabling Deterministic Operations]
  //     if (at::globalContext().deterministicAlgorithms()) {
  //       example_func_deterministic();
  //     } else {
  //       example_func_nondeterministic();
  //     }
  //   }

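  // NB: setDeterministicAlgorithms takes two flags. Judging from the private
  // members further down (_deterministic_algorithms and
  // _deterministic_algorithms_warn_only), the second flag presumably selects
  // a "warn only" mode, in which nondeterministic operations warn rather
  // than error; this reading is inferred from the member names rather than
  // stated here.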
  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool, bool);
  bool deterministicFillUninitializedMemory() const;
  void setDeterministicFillUninitializedMemory(bool);

  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any
  //   new tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and
  // error-throwing code for a nondeterministic operation:
  //
  //   void example_func() {
  //     // See Note [Writing Nondeterministic Operations]
  //     // Nondeterministic because <reason>
  //     at::globalContext().alertNotDeterministic("example_func");
  //     ...
  //   }

  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
  void alertCuBLASConfigNotDeterministic() const;
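  //
  // Illustrative note (derived from the check described above, not an
  // additional guarantee): setting the environment variable before the
  // process starts, e.g.
  //   CUBLAS_WORKSPACE_CONFIG=:4096:8
  // is one way to satisfy this configuration check.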

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);
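  //
  // A minimal sketch of the matmul-precision knobs above, assuming the
  // string overload accepts the lower-case names of the
  // Float32MatmulPrecision enum ("highest", "high", "medium"):
  //
  //   at::globalContext().setFloat32MatmulPrecision(
  //       at::Float32MatmulPrecision::HIGH);            // enum overload
  //   at::globalContext().setFloat32MatmulPrecision("high");  // string overload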
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();
  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;
  // This method is used to release the original weights after pre-packing.
  // It should be called once before loading/running the model.
  // NB: By default it is set to true for mobile builds.
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;
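  //
  // Illustrative only (the exact loading flow depends on the frontend in
  // use): a non-mobile build might opt in explicitly before loading a
  // pre-packed model, e.g.
  //   at::globalContext().setReleaseWeightsWhenPrepacking(true);
  //   // ... then load and run the model ...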

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();
  bool allowFP16ReductionCPU() const;
  void setAllowFP16ReductionCPU(bool);

 private:
  void initCUDAIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  void initXPUIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::XPU) {
      lazyInitXPU();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  c10::once_flag thx_init;
  c10::once_flag th_mtia_init;
  c10::once_flag thp_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool _deterministic_fill_uninitialized_memory = true;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
  bool enabled_cudnnSDP = false;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  bool enabled_nnpack = true;
  at::LinalgBackend linalg_preferred_backend =
      c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
      ? at::LinalgBackend::Cusolver
      : at::LinalgBackend::Default;
  at::BlasBackend blas_preferred_backend =
#ifdef USE_ROCM
      (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
#else
      (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
#endif
      ? at::BlasBackend::Cublaslt
      : at::BlasBackend::Cublas;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  std::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

  Allocator* prev_allocator_ptr_{nullptr};
};

TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}

TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}

static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasMTIA() {
  return globalContext().hasMTIA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasMAIA() {
  return globalContext().hasMAIA();
}

static inline bool hasXPU() {
  return globalContext().hasXPU();
}

// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION. If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count())
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
        "as CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
        "means HIP). Rebuild PyTorch with one or the other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}

static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}

static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && cuda_num_gpus > 0) {
    for (const auto i : c10::irange(cuda_num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }

  const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
  if (hasXPU() && xpu_num_gpus) {
    for (const auto i : c10::irange(xpu_num_gpus)) {
      auto xpu_gen = globalContext().defaultGenerator(
          Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(xpu_gen.mutex());
        xpu_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}

// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. To address such cases, this class implements
// a RAII guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
//   NoTF32Guard disable_tf32;
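//
// or, as an illustrative scoped variant (the guarded call is a placeholder):
//   {
//     at::NoTF32Guard disable_tf32;
//     // ... call the TF32-sensitive op here; the guard's destructor
//     // restores the previous TF32 setting on scope exit ...
//   }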
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};

struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();
};

} // namespace at