#include <ATen/Config.h>

#include <ATen/Context.h>

#include <c10/core/CPUAllocator.h>

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <stdexcept>
#include <string>

#include <ATen/cpu/FlushDenormal.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif // USE_FBGEMM
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif

namespace at {

Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
Context& globalContext() {
  static Context globalContext_;
  return globalContext_;
}

// NB: This method reflects *purely* whether or not a user requested that
// CuDNN be enabled; it doesn't say anything about whether CuDNN is actually
// usable.
bool Context::userEnabledCuDNN() const {
  return enabled_cudnn;
}

void Context::setUserEnabledCuDNN(bool e) {
  enabled_cudnn = e;
}

bool Context::userEnabledMkldnn() const {
  return enabled_mkldnn;
}

void Context::setUserEnabledMkldnn(bool e) {
  enabled_mkldnn = e;
}

bool Context::deterministicCuDNN() const {
  return deterministic_cudnn;
}

void Context::setDeterministicCuDNN(bool b) {
  deterministic_cudnn = b;
}

bool Context::deterministicAlgorithms() const {
  return _deterministic_algorithms;
}

bool Context::deterministicAlgorithmsWarnOnly() const {
  return _deterministic_algorithms_warn_only;
}

void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
  _deterministic_algorithms = b;
  _deterministic_algorithms_warn_only = warn_only;
}
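
// Example (illustrative): enabling strict determinism from C++; this is roughly
// what torch.use_deterministic_algorithms(True, warn_only=True) configures from
// the Python side:
//
//   at::globalContext().setDeterministicAlgorithms(/*b=*/true, /*warn_only=*/true);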

bool Context::deterministicFillUninitializedMemory() const {
  return _deterministic_fill_uninitialized_memory;
}

void Context::setDeterministicFillUninitializedMemory(bool b) {
  _deterministic_fill_uninitialized_memory = b;
}

void Context::alertNotDeterministic(c10::string_view const& caller) {
  if (globalContext().deterministicAlgorithms()) {
    if (globalContext().deterministicAlgorithmsWarnOnly()) {
      TORCH_WARN(
        caller, " does not have a deterministic implementation, but you set "
        "'torch.use_deterministic_algorithms(True, warn_only=True)'. "
        "You can file an issue at https://github.com/pytorch/pytorch/issues "
        "to help us prioritize adding deterministic support for this operation.");
    } else {
      TORCH_CHECK(false,
        caller, " does not have a deterministic implementation, but you set "
        "'torch.use_deterministic_algorithms(True)'. You can turn off "
        "determinism just for this operation, or you can use the "
        "'warn_only=True' option, if that's acceptable for your application. "
        "You can also file an issue at https://github.com/pytorch/pytorch/issues "
        "to help us prioritize adding deterministic support for this operation.");
    }
  }
}
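
// Example (illustrative, hypothetical op name): a kernel that lacks a
// deterministic implementation can guard itself with the alert above; it only
// throws (or warns, in warn-only mode) when deterministic algorithms were
// requested:
//
//   void my_nondeterministic_kernel(const at::Tensor& t) {
//     at::globalContext().alertNotDeterministic("my_nondeterministic_kernel");
//     // ... nondeterministic work ...
//   }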

bool Context::userEnabledNNPACK() const {
  return enabled_nnpack;
}

void Context::setUserEnabledNNPACK(bool e) {
  enabled_nnpack = e;
}

bool Context::allowTF32CuDNN() const {
  return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
  allow_tf32_cudnn = b;
}

bool Context::userEnabledFlashSDP() const {
  return enabled_flashSDP;
}

void Context::setSDPUseFlash(bool e) {
  enabled_flashSDP = e;
}

bool Context::userEnabledMemEfficientSDP() const {
  return enabled_mem_efficientSDP;
}

void Context::setSDPUseMemEfficient(bool e) {
  enabled_mem_efficientSDP = e;
}

bool Context::userEnabledMathSDP() const {
  return enabled_mathSDP;
}

void Context::setSDPUseMath(bool e) {
  enabled_mathSDP = e;
}

bool Context::userEnabledCuDNNSDP() const {
  return enabled_cudnnSDP;
}

void Context::setSDPUseCuDNN(bool e) {
  enabled_cudnnSDP = e;
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };

bool Context::checkCuBLASConfigDeterministic() {
  bool cublas_config_deterministic = true;
  // If using CUDA 10.2 or greater, we need to make sure the CuBLAS workspace
  // config is set to a deterministic setting.
  if (hasCUDART() && (versionCUDART() >= 10020)) {
    char* workspace_config = std::getenv(cublas_config_var_name);
    cublas_config_deterministic = (workspace_config != nullptr) && (
      (std::strcmp(workspace_config, cublas_deterministic_configs[0]) == 0)
      || (std::strcmp(workspace_config, cublas_deterministic_configs[1]) == 0)
    );
  }
  return cublas_config_deterministic;
}

void Context::alertCuBLASConfigNotDeterministic() const {
  static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
  if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) {
    return;
  }

  auto msg = c10::str(
    "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
    "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
    "it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
    "case, you must set an environment variable before running your PyTorch application: ",
    cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
    cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ",
    "https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility"
  );

  if (deterministicAlgorithmsWarnOnly()) {
    TORCH_WARN(msg);
  } else {
    TORCH_CHECK(false, msg);
  }
}
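
// Example (illustrative): to satisfy the check above on CUDA >= 10.2, export the
// cuBLAS workspace configuration before launching the process, e.g. from a shell:
//
//   export CUBLAS_WORKSPACE_CONFIG=:4096:8
//   # or, for a smaller workspace:
//   export CUBLAS_WORKSPACE_CONFIG=:16:8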

bool Context::benchmarkCuDNN() const {
  return benchmark_cudnn;
}

void Context::setBenchmarkCuDNN(bool b) {
  benchmark_cudnn = b;
}

int Context::benchmarkLimitCuDNN() const {
  return benchmark_limit_cudnn;
}

void Context::setBenchmarkLimitCuDNN(int b) {
  benchmark_limit_cudnn = b;
}

bool Context::allowTF32CuBLAS() const {
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}

void Context::setAllowTF32CuBLAS(bool b) {
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
}

Float32MatmulPrecision Context::float32MatmulPrecision() const {
  return float32_matmul_precision;
}

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
}

void Context::setFloat32MatmulPrecision(const std::string& s) {
  auto match = [this](const std::string& s_) {
    // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
    if (s_ == "highest") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
      return true;
    } else if (s_ == "high") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
      return true;
    } else if (s_ == "medium") {
      float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
      return true;
    }
    return false;
  };
  if (match(s)) { return; }
  // Lower-case the input and try again.
  std::string sl;
  std::transform(s.begin(), s.end(), std::back_inserter(sl),
      [](unsigned char c) -> unsigned char { return std::tolower(c); });
  if (match(sl)) { return; }
  TORCH_WARN(s, " is not one of 'highest', 'high', or 'medium'; the current "
      "setFloat32MatmulPrecision call has no effect.");
}
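
// Example (illustrative): the string overload accepts "highest", "high", or
// "medium" (case-insensitively); any other value only warns and leaves the
// current setting unchanged:
//
//   at::globalContext().setFloat32MatmulPrecision("high");    // allowTF32CuBLAS() now reports true
//   at::globalContext().setFloat32MatmulPrecision("highest"); // allowTF32CuBLAS() now reports false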

at::LinalgBackend Context::linalgPreferredBackend() const {
  return linalg_preferred_backend;
}

void Context::setLinalgPreferredBackend(at::LinalgBackend b) {
  linalg_preferred_backend = b;
  TORCH_CHECK((b != at::LinalgBackend::Cusolver) || hasCuSOLVER(),
      "Cannot set preferred backend to cuSOLVER if PyTorch has not been compiled with cuSOLVER.");
  TORCH_CHECK((b != at::LinalgBackend::Magma) || hasMAGMA(),
      "Cannot set preferred backend to MAGMA if PyTorch has not been compiled with MAGMA.");
  if (b != at::LinalgBackend::Default) {
    TORCH_WARN_ONCE(
      "torch.backends.cuda.preferred_linalg_library is an experimental feature. "
      "If you see any error or unexpected behavior when this flag is set "
      "please file an issue on GitHub."
    );
  }
}

at::BlasBackend Context::blasPreferredBackend() {
#ifdef USE_ROCM
  if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
    static const bool hipblaslt_unsupported = []() {
      static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
      for (auto index = 0; index < at::getNumGPUs(); index++) {
        if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
          TORCH_WARN_ONCE(
            "Attempting to use hipBLASLt on an unsupported architecture! "
            "Overriding blas backend to hipblas");
          return true;
        }
      }
      return false;
    }();
    if (hipblaslt_unsupported) {
      blas_preferred_backend = at::BlasBackend::Cublas;
    }
  }
#endif
  return blas_preferred_backend;
}

void Context::setBlasPreferredBackend(at::BlasBackend b) {
#ifdef _MSC_VER
  TORCH_WARN_ONCE(
    "torch.backends.cuda.preferred_blas_library is an experimental feature. "
    "It is not supported on Windows."
  );
#else
  TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
      "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
  if (b != at::BlasBackend::Cublas) {
    TORCH_WARN_ONCE(
      "torch.backends.cuda.preferred_blas_library is an experimental feature. "
      "If you see any error or unexpected behavior when this flag is set "
      "please file an issue on GitHub."
    );
  }
  blas_preferred_backend = b;
#endif
}

bool Context::allowFP16ReductionCuBLAS() const {
  return allow_fp16_reduction_cublas;
}

void Context::setAllowFP16ReductionCuBLAS(bool b) {
  allow_fp16_reduction_cublas = b;
}

bool Context::allowBF16ReductionCuBLAS() const {
  return allow_bf16_reduction_cublas;
}

void Context::setAllowBF16ReductionCuBLAS(bool b) {
  allow_bf16_reduction_cublas = b;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasMKLDNN() {
#if AT_MKLDNN_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasOpenMP() {
#ifdef _OPENMP
  return true;
#else
  return false;
#endif
}

bool Context::hasLAPACK() {
#if AT_BUILD_WITH_LAPACK()
  return true;
#else
  return false;
#endif
}

at::QEngine Context::qEngine() const {
  static auto _quantized_engine = []() {
    at::QEngine qengine = at::kNoQEngine;
#if defined(C10_MOBILE) && defined(USE_PYTORCH_QNNPACK)
    qengine = at::kQNNPACK;
#endif

#if AT_MKLDNN_ENABLED()
    qengine = at::kONEDNN;
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      /* The X86 qengine is available if and only if FBGEMM is available.
       * It combines the strengths of FBGEMM and oneDNN by dispatching between
       * them; when oneDNN is not available, it always dispatches to FBGEMM.
       * Make it the default qengine on x86 CPU platforms.
       */
      qengine = at::kX86;
    }
#endif
    return qengine;
  }();
  return quantized_engine.value_or(_quantized_engine);
}

void Context::setQEngine(at::QEngine e) {
  const auto& qengines = supportedQEngines();
  if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) {
    quantized_engine = e;
    return;
  }
  TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
}

const std::vector<at::QEngine>& Context::supportedQEngines() {
  static auto supported_qengines = []() {
    std::vector<at::QEngine> engines = {};
    // Engines are listed in priority order: the last one wins.
    // By default we prefer FBGEMM when running on the server side;
    // QNNPACK has some issues on the server side, so we disable it by default there.
#ifdef C10_MOBILE
    engines.push_back(at::kNoQEngine);
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
#else // C10_MOBILE
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
    engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE

#if AT_MKLDNN_ENABLED()
    engines.push_back(at::kONEDNN);
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      engines.push_back(at::kX86);
      // The X86 qengine is available if and only if FBGEMM is available
      engines.push_back(at::kFBGEMM);
    }
#endif

    return engines;
  }();
  return supported_qengines;
}
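
// Example (illustrative): selecting a quantization engine at runtime, guarded
// by the supported-engine list above (setQEngine throws for unsupported engines):
//
//   const auto& engines = at::globalContext().supportedQEngines();
//   if (std::find(engines.begin(), engines.end(), at::kFBGEMM) != engines.end()) {
//     at::globalContext().setQEngine(at::kFBGEMM);
//   }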

bool Context::isXNNPACKAvailable() {
#ifdef USE_XNNPACK
  return true;
#else
  return false;
#endif
}

void Context::setCheckSparseTensorInvariants(bool e) {
  enable_sparse_tensor_invariant_checks = e;
}

bool Context::checkSparseTensorInvariants() const {
  return enable_sparse_tensor_invariant_checks;
}

bool Context::releaseWeightsWhenPrepacking() const {
  return release_original_weights;
}

void Context::setReleaseWeightsWhenPrepacking(bool e) {
  release_original_weights = e;
}

bool Context::setFlushDenormal(bool on) {
  return at::cpu::set_flush_denormal(on);
}

Allocator* getCPUAllocator() {
  return c10::GetCPUAllocator();
}

// override_allow_tf32_flag = true
//   means the allow_tf32 flags are overridden and TF32 is force-disabled
// override_allow_tf32_flag = false
//   means the original allow_tf32 flags are followed
thread_local bool override_allow_tf32_flag = false;

NoTF32Guard::NoTF32Guard() {
  if (!override_allow_tf32_flag) {
    changed = true;
    override_allow_tf32_flag = true;
  }
}

NoTF32Guard::~NoTF32Guard() {
  if (changed) {
    override_allow_tf32_flag = false;
  }
}

bool NoTF32Guard::should_disable_tf32() {
  return override_allow_tf32_flag;
}
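
// Example (illustrative): NoTF32Guard is an RAII guard; while an instance is
// alive on the current thread, ops that consult should_disable_tf32() run with
// TF32 disabled, and the previous state is restored when the guard is destroyed:
//
//   {
//     at::NoTF32Guard disable_tf32;
//     // ... TF32-sensitive computation ...
//   }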

// Ops can query this flag to know they are in the backward pass.
// This information can be used, for example, to select implementations
// with different numerical or performance characteristics.
// See https://pytorch.org/docs/stable/notes/numerical_accuracy.html for details.
thread_local bool rocm_is_backward_pass;

ROCmBackwardPassGuard::ROCmBackwardPassGuard() {
  rocm_is_backward_pass = true;
}

ROCmBackwardPassGuard::~ROCmBackwardPassGuard() {
  rocm_is_backward_pass = false;
}

bool ROCmBackwardPassGuard::is_backward_pass() {
  return rocm_is_backward_pass;
}

bool Context::areVmapFallbackWarningsEnabled() const {
  return display_vmap_fallback_warnings_;
}

void Context::setDisplayVmapFallbackWarnings(bool enabled) {
  display_vmap_fallback_warnings_ = enabled;
}

void Context::setDefaultMobileCPUAllocator() {
  TORCH_CHECK(prev_allocator_ptr_ == nullptr,
      "Already within the scope of another non-default cpu allocator. "
      "Cannot set another allocator.");
  // Set the priority high to make sure no other allocator gets used instead of this one.
  prev_allocator_ptr_ = c10::GetCPUAllocator();
  c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100);
}

void Context::unsetDefaultMobileCPUAllocator() {
  TORCH_CHECK(prev_allocator_ptr_ != nullptr,
      "setDefaultMobileCPUAllocator must have been called "
      "before unsetDefaultMobileCPUAllocator.");
  // Restore the previous allocator, again at high priority so nothing else overrides it.
  c10::SetCPUAllocator(prev_allocator_ptr_, /*priority*/ 100);
  prev_allocator_ptr_ = nullptr;
}
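
// Example (illustrative): the two calls above are meant to be used as a pair;
// nesting is rejected by the TORCH_CHECKs:
//
//   at::globalContext().setDefaultMobileCPUAllocator();
//   // ... allocations here use the default mobile CPU allocator ...
//   at::globalContext().unsetDefaultMobileCPUAllocator();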

bool Context::allowFP16ReductionCPU() const {
  return allow_fp16_reduction_cpu;
}

void Context::setAllowFP16ReductionCPU(bool b) {
  if (b && !allow_fp16_reduction_cpu) {
    // Check that CPU supports fp16 reductions
#if defined(__aarch64__) && !defined(C10_MOBILE)
    if (!cpuinfo_initialize() || !cpuinfo_has_arm_fp16_arith())
#else
    if (true)
#endif
      throw std::runtime_error("Float16 arithmetic is not supported by the CPU!");
  }
  allow_fp16_reduction_cpu = b;
}

} // namespace at