1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/BlasBackend.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/CPUGeneratorImpl.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/DeviceAccelerator.h>
6*da0073e9SAndroid Build Coastguard Worker #include <ATen/LinalgBackend.h>
7*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/ATenGeneral.h>
8*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/DeprecatedTypeProperties.h>
9*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Generator.h>
10*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/LegacyTypeDispatch.h>
11*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/AcceleratorHooksInterface.h>
12*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/CUDAHooksInterface.h>
13*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/HIPHooksInterface.h>
14*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/IPUHooksInterface.h>
15*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/MAIAHooksInterface.h>
16*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/MPSHooksInterface.h>
17*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/MTIAHooksInterface.h>
18*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/PrivateUse1HooksInterface.h>
19*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/XPUHooksInterface.h>
20*da0073e9SAndroid Build Coastguard Worker #include <c10/core/QEngine.h>
21*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/DeviceGuardImplInterface.h>
22*da0073e9SAndroid Build Coastguard Worker #include <c10/util/CallOnce.h>
23*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
24*da0073e9SAndroid Build Coastguard Worker #include <c10/util/env.h>
25*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
28*da0073e9SAndroid Build Coastguard Worker #include <mutex>
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker namespace at {
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker class Tensor;
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker class TORCH_API Context {
37*da0073e9SAndroid Build Coastguard Worker public:
38*da0073e9SAndroid Build Coastguard Worker Context();
39*da0073e9SAndroid Build Coastguard Worker
defaultGenerator(Device device)40*da0073e9SAndroid Build Coastguard Worker const Generator& defaultGenerator(Device device) {
41*da0073e9SAndroid Build Coastguard Worker c10::DeviceType device_type = device.type();
42*da0073e9SAndroid Build Coastguard Worker initCUDAIfNeeded(device_type);
43*da0073e9SAndroid Build Coastguard Worker initHIPIfNeeded(device_type);
44*da0073e9SAndroid Build Coastguard Worker if (device_type == at::kCPU) {
45*da0073e9SAndroid Build Coastguard Worker return at::detail::getDefaultCPUGenerator();
46*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kCUDA) {
47*da0073e9SAndroid Build Coastguard Worker return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
48*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kMPS) {
49*da0073e9SAndroid Build Coastguard Worker return at::detail::getMPSHooks().getDefaultMPSGenerator();
50*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kXPU) {
51*da0073e9SAndroid Build Coastguard Worker return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
52*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kIPU) {
53*da0073e9SAndroid Build Coastguard Worker return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
54*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kPrivateUse1) {
55*da0073e9SAndroid Build Coastguard Worker return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
56*da0073e9SAndroid Build Coastguard Worker device.index());
57*da0073e9SAndroid Build Coastguard Worker } else {
58*da0073e9SAndroid Build Coastguard Worker AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
59*da0073e9SAndroid Build Coastguard Worker }
60*da0073e9SAndroid Build Coastguard Worker }
61*da0073e9SAndroid Build Coastguard Worker const AcceleratorHooksInterface& getAcceleratorHooksInterface(
62*da0073e9SAndroid Build Coastguard Worker std::optional<c10::DeviceType> opt_device_type = c10::nullopt) {
63*da0073e9SAndroid Build Coastguard Worker c10::DeviceType device_type = opt_device_type.has_value()
64*da0073e9SAndroid Build Coastguard Worker ? opt_device_type.value()
65*da0073e9SAndroid Build Coastguard Worker : at::getAccelerator(true).value();
66*da0073e9SAndroid Build Coastguard Worker if (device_type == at::kCUDA) {
67*da0073e9SAndroid Build Coastguard Worker return at::detail::getCUDAHooks();
68*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kMPS) {
69*da0073e9SAndroid Build Coastguard Worker return at::detail::getMPSHooks();
70*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kPrivateUse1) {
71*da0073e9SAndroid Build Coastguard Worker return at::detail::getPrivateUse1Hooks();
72*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kMTIA) {
73*da0073e9SAndroid Build Coastguard Worker return at::detail::getMTIAHooks();
74*da0073e9SAndroid Build Coastguard Worker } else {
75*da0073e9SAndroid Build Coastguard Worker AT_ERROR(
76*da0073e9SAndroid Build Coastguard Worker c10::DeviceTypeName(device_type), " device type not an accelerator.");
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker }
getDeviceFromPtr(void * data,c10::DeviceType device_type)79*da0073e9SAndroid Build Coastguard Worker Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
80*da0073e9SAndroid Build Coastguard Worker initCUDAIfNeeded(device_type);
81*da0073e9SAndroid Build Coastguard Worker initHIPIfNeeded(device_type);
82*da0073e9SAndroid Build Coastguard Worker initXPUIfNeeded(device_type);
83*da0073e9SAndroid Build Coastguard Worker if (device_type == at::kCPU) {
84*da0073e9SAndroid Build Coastguard Worker return c10::DeviceType::CPU;
85*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kCUDA) {
86*da0073e9SAndroid Build Coastguard Worker return at::detail::getCUDAHooks().getDeviceFromPtr(data);
87*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kXPU) {
88*da0073e9SAndroid Build Coastguard Worker return at::detail::getXPUHooks().getDeviceFromPtr(data);
89*da0073e9SAndroid Build Coastguard Worker } else if (device_type == at::kPrivateUse1) {
90*da0073e9SAndroid Build Coastguard Worker return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data);
91*da0073e9SAndroid Build Coastguard Worker } else {
92*da0073e9SAndroid Build Coastguard Worker AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker }
isPinnedPtr(const void * data)95*da0073e9SAndroid Build Coastguard Worker static bool isPinnedPtr(const void* data) {
96*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().isPinnedPtr(data);
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker static bool hasOpenMP();
99*da0073e9SAndroid Build Coastguard Worker static bool hasMKL();
100*da0073e9SAndroid Build Coastguard Worker static bool hasLAPACK();
101*da0073e9SAndroid Build Coastguard Worker static bool hasMKLDNN();
hasMAGMA()102*da0073e9SAndroid Build Coastguard Worker static bool hasMAGMA() {
103*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasMAGMA();
104*da0073e9SAndroid Build Coastguard Worker }
hasCUDA()105*da0073e9SAndroid Build Coastguard Worker static bool hasCUDA() {
106*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasCUDA();
107*da0073e9SAndroid Build Coastguard Worker }
hasMTIA()108*da0073e9SAndroid Build Coastguard Worker static bool hasMTIA() {
109*da0073e9SAndroid Build Coastguard Worker return detail::getMTIAHooks().hasMTIA();
110*da0073e9SAndroid Build Coastguard Worker }
hasCUDART()111*da0073e9SAndroid Build Coastguard Worker static bool hasCUDART() {
112*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasCUDART();
113*da0073e9SAndroid Build Coastguard Worker }
versionCUDART()114*da0073e9SAndroid Build Coastguard Worker static long versionCUDART() {
115*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().versionCUDART();
116*da0073e9SAndroid Build Coastguard Worker }
hasCuDNN()117*da0073e9SAndroid Build Coastguard Worker static bool hasCuDNN() {
118*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasCuDNN();
119*da0073e9SAndroid Build Coastguard Worker }
versionCuDNN()120*da0073e9SAndroid Build Coastguard Worker static long versionCuDNN() {
121*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().versionCuDNN();
122*da0073e9SAndroid Build Coastguard Worker }
hasCuSOLVER()123*da0073e9SAndroid Build Coastguard Worker static bool hasCuSOLVER() {
124*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasCuSOLVER();
125*da0073e9SAndroid Build Coastguard Worker }
hasCuBLASLt()126*da0073e9SAndroid Build Coastguard Worker static bool hasCuBLASLt() {
127*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().hasCuBLASLt();
128*da0073e9SAndroid Build Coastguard Worker }
hasHIP()129*da0073e9SAndroid Build Coastguard Worker static bool hasHIP() {
130*da0073e9SAndroid Build Coastguard Worker return detail::getHIPHooks().hasHIP();
131*da0073e9SAndroid Build Coastguard Worker }
hasMPS()132*da0073e9SAndroid Build Coastguard Worker static bool hasMPS() {
133*da0073e9SAndroid Build Coastguard Worker return detail::getMPSHooks().hasMPS();
134*da0073e9SAndroid Build Coastguard Worker }
hasIPU()135*da0073e9SAndroid Build Coastguard Worker static bool hasIPU() {
136*da0073e9SAndroid Build Coastguard Worker return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
137*da0073e9SAndroid Build Coastguard Worker }
hasXLA()138*da0073e9SAndroid Build Coastguard Worker static bool hasXLA() {
139*da0073e9SAndroid Build Coastguard Worker return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
140*da0073e9SAndroid Build Coastguard Worker }
hasXPU()141*da0073e9SAndroid Build Coastguard Worker static bool hasXPU() {
142*da0073e9SAndroid Build Coastguard Worker return detail::getXPUHooks().hasXPU();
143*da0073e9SAndroid Build Coastguard Worker }
hasLazy()144*da0073e9SAndroid Build Coastguard Worker static bool hasLazy() {
145*da0073e9SAndroid Build Coastguard Worker return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
146*da0073e9SAndroid Build Coastguard Worker }
hasMAIA()147*da0073e9SAndroid Build Coastguard Worker static bool hasMAIA() {
148*da0073e9SAndroid Build Coastguard Worker return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
149*da0073e9SAndroid Build Coastguard Worker }
150*da0073e9SAndroid Build Coastguard Worker // defined in header so that getNonVariableType has ability to inline
151*da0073e9SAndroid Build Coastguard Worker // call_once check. getNonVariableType is called fairly frequently
lazyInitCUDA()152*da0073e9SAndroid Build Coastguard Worker void lazyInitCUDA() {
153*da0073e9SAndroid Build Coastguard Worker c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
154*da0073e9SAndroid Build Coastguard Worker }
lazyInitHIP()155*da0073e9SAndroid Build Coastguard Worker void lazyInitHIP() {
156*da0073e9SAndroid Build Coastguard Worker c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
157*da0073e9SAndroid Build Coastguard Worker }
lazyInitXPU()158*da0073e9SAndroid Build Coastguard Worker void lazyInitXPU() {
159*da0073e9SAndroid Build Coastguard Worker c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
160*da0073e9SAndroid Build Coastguard Worker }
lazyInitMTIA()161*da0073e9SAndroid Build Coastguard Worker void lazyInitMTIA() {
162*da0073e9SAndroid Build Coastguard Worker c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
163*da0073e9SAndroid Build Coastguard Worker }
lazyInitPrivateUse1()164*da0073e9SAndroid Build Coastguard Worker void lazyInitPrivateUse1() {
165*da0073e9SAndroid Build Coastguard Worker c10::call_once(thp_init, [&] {
166*da0073e9SAndroid Build Coastguard Worker if (isPrivateUse1HooksRegistered()) {
167*da0073e9SAndroid Build Coastguard Worker at::GetPrivateUse1HooksInterface()->initPrivateUse1();
168*da0073e9SAndroid Build Coastguard Worker }
169*da0073e9SAndroid Build Coastguard Worker });
170*da0073e9SAndroid Build Coastguard Worker }
getNVRTC()171*da0073e9SAndroid Build Coastguard Worker static const at::cuda::NVRTC& getNVRTC() {
172*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().nvrtc();
173*da0073e9SAndroid Build Coastguard Worker }
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker static bool setFlushDenormal(bool on);
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker // NB: This method is *purely* whether or not a user requested
178*da0073e9SAndroid Build Coastguard Worker // that CuDNN was enabled, it doesn't actually say anything about
179*da0073e9SAndroid Build Coastguard Worker // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
180*da0073e9SAndroid Build Coastguard Worker // to test this instead
181*da0073e9SAndroid Build Coastguard Worker bool userEnabledCuDNN() const;
182*da0073e9SAndroid Build Coastguard Worker void setUserEnabledCuDNN(bool e);
183*da0073e9SAndroid Build Coastguard Worker bool userEnabledMkldnn() const;
184*da0073e9SAndroid Build Coastguard Worker void setUserEnabledMkldnn(bool e);
185*da0073e9SAndroid Build Coastguard Worker bool benchmarkCuDNN() const;
186*da0073e9SAndroid Build Coastguard Worker void setBenchmarkCuDNN(bool);
187*da0073e9SAndroid Build Coastguard Worker int benchmarkLimitCuDNN() const;
188*da0073e9SAndroid Build Coastguard Worker void setBenchmarkLimitCuDNN(int);
189*da0073e9SAndroid Build Coastguard Worker bool deterministicCuDNN() const;
190*da0073e9SAndroid Build Coastguard Worker void setDeterministicCuDNN(bool);
191*da0073e9SAndroid Build Coastguard Worker bool userEnabledNNPACK() const;
192*da0073e9SAndroid Build Coastguard Worker void setUserEnabledNNPACK(bool e);
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker // Note [Disabling Fused SDP Kernels]
195*da0073e9SAndroid Build Coastguard Worker // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
196*da0073e9SAndroid Build Coastguard Worker // Flash and Memory Efficient SDP kernels are enabled by default.
197*da0073e9SAndroid Build Coastguard Worker // However, they can be disabled by setting
198*da0073e9SAndroid Build Coastguard Worker // at::globalContext().setUserEnabledFlashSDP(false) flag.
199*da0073e9SAndroid Build Coastguard Worker // This is useful for debugging purposes. For example, if you want to
200*da0073e9SAndroid Build Coastguard Worker // compare the performance of the flash SDP kernels with the unfused
201*da0073e9SAndroid Build Coastguard Worker // kernel, you can disable the flash SDP kernels. By disabling
202*da0073e9SAndroid Build Coastguard Worker // the math SDP kernel, you can force your code to use flash kernels.
203*da0073e9SAndroid Build Coastguard Worker // The math SDP kernel can be disabled by setting
204*da0073e9SAndroid Build Coastguard Worker // at::globalContext().setUserEnabledMathSDP(false) flag.
205*da0073e9SAndroid Build Coastguard Worker void setSDPUseFlash(bool);
206*da0073e9SAndroid Build Coastguard Worker bool userEnabledFlashSDP() const;
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker void setSDPUseMemEfficient(bool);
209*da0073e9SAndroid Build Coastguard Worker bool userEnabledMemEfficientSDP() const;
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker void setSDPUseMath(bool);
212*da0073e9SAndroid Build Coastguard Worker bool userEnabledMathSDP() const;
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker void setSDPUseCuDNN(bool);
215*da0073e9SAndroid Build Coastguard Worker bool userEnabledCuDNNSDP() const;
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker at::LinalgBackend linalgPreferredBackend() const;
218*da0073e9SAndroid Build Coastguard Worker void setLinalgPreferredBackend(at::LinalgBackend);
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker at::BlasBackend blasPreferredBackend();
221*da0073e9SAndroid Build Coastguard Worker void setBlasPreferredBackend(at::BlasBackend);
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker // Note [Enabling Deterministic Operations]
224*da0073e9SAndroid Build Coastguard Worker // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
225*da0073e9SAndroid Build Coastguard Worker // Operations in PyTorch that normally act nondeterministically, but have an
226*da0073e9SAndroid Build Coastguard Worker // alternate deterministic implementation, should satisfy the following
227*da0073e9SAndroid Build Coastguard Worker // requirements:
228*da0073e9SAndroid Build Coastguard Worker //
229*da0073e9SAndroid Build Coastguard Worker // * Include this comment: "See Note [Enabling Deterministic Operations]"
230*da0073e9SAndroid Build Coastguard Worker //
231*da0073e9SAndroid Build Coastguard Worker // * Check the value of `at::globalContext().deterministicAlgorithms()` to
232*da0073e9SAndroid Build Coastguard Worker // toggle
233*da0073e9SAndroid Build Coastguard Worker // between nondeterministic and deterministic implementations.
234*da0073e9SAndroid Build Coastguard Worker //
235*da0073e9SAndroid Build Coastguard Worker // * Have an entry in the list of PyTorch operations that toggle between
236*da0073e9SAndroid Build Coastguard Worker // nondeterministic
237*da0073e9SAndroid Build Coastguard Worker // and deterministic implementations, in the docstring of
238*da0073e9SAndroid Build Coastguard Worker // `use_deterministic_algorithms()` in torch/__init__.py
239*da0073e9SAndroid Build Coastguard Worker //
240*da0073e9SAndroid Build Coastguard Worker // `example_func()` below shows an example of toggling between
241*da0073e9SAndroid Build Coastguard Worker // nondeterministic and deterministic implementations:
242*da0073e9SAndroid Build Coastguard Worker //
243*da0073e9SAndroid Build Coastguard Worker // void example_func() {
244*da0073e9SAndroid Build Coastguard Worker // // See Note [Enabling Deterministic Operations]
245*da0073e9SAndroid Build Coastguard Worker // if (at::globalContext().deterministicAlgorithms()) {
246*da0073e9SAndroid Build Coastguard Worker // example_func_deterministic();
247*da0073e9SAndroid Build Coastguard Worker // } else {
248*da0073e9SAndroid Build Coastguard Worker // example_func_nondeterministic();
249*da0073e9SAndroid Build Coastguard Worker // }
250*da0073e9SAndroid Build Coastguard Worker // }
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker bool deterministicAlgorithms() const;
253*da0073e9SAndroid Build Coastguard Worker bool deterministicAlgorithmsWarnOnly() const;
254*da0073e9SAndroid Build Coastguard Worker void setDeterministicAlgorithms(bool, bool);
255*da0073e9SAndroid Build Coastguard Worker bool deterministicFillUninitializedMemory() const;
256*da0073e9SAndroid Build Coastguard Worker void setDeterministicFillUninitializedMemory(bool);
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker // Note [Writing Nondeterministic Operations]
259*da0073e9SAndroid Build Coastguard Worker // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
260*da0073e9SAndroid Build Coastguard Worker // Operations in PyTorch that act nondeterministically and do not have an
261*da0073e9SAndroid Build Coastguard Worker // alternate deterministic implementation should satisfy the following
262*da0073e9SAndroid Build Coastguard Worker // requirements:
263*da0073e9SAndroid Build Coastguard Worker //
264*da0073e9SAndroid Build Coastguard Worker // * Include this comment: "See Note [Writing Nondeterministic Operations]"
265*da0073e9SAndroid Build Coastguard Worker //
266*da0073e9SAndroid Build Coastguard Worker // * Include a comment explaining why the operation is nondeterministic.
267*da0073e9SAndroid Build Coastguard Worker //
268*da0073e9SAndroid Build Coastguard Worker // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
269*da0073e9SAndroid Build Coastguard Worker // of the time, this should be accomplished by calling
270*da0073e9SAndroid Build Coastguard Worker // `at::globalContext().alertNotDeterminstic()`. However, if the
271*da0073e9SAndroid Build Coastguard Worker // nondeterministic behavior is caused by the CuBLAS workspace
272*da0073e9SAndroid Build Coastguard Worker // configuration in CUDA >= 10.2,
273*da0073e9SAndroid Build Coastguard Worker // `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
274*da0073e9SAndroid Build Coastguard Worker // called instead (in this case, a comment explaining why the operation is
275*da0073e9SAndroid Build Coastguard Worker // nondeterministic is not necessary). See below for details on these
276*da0073e9SAndroid Build Coastguard Worker // methods.
277*da0073e9SAndroid Build Coastguard Worker //
278*da0073e9SAndroid Build Coastguard Worker // * Have an entry in the list of nondeterministic PyTorch operations in the
279*da0073e9SAndroid Build Coastguard Worker // docstring of `use_deterministic_algorithms()` in torch/__init__.py
280*da0073e9SAndroid Build Coastguard Worker //
281*da0073e9SAndroid Build Coastguard Worker // * Have a test function in `test/test_torch.py` whose name begins with
282*da0073e9SAndroid Build Coastguard Worker // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
283*da0073e9SAndroid Build Coastguard Worker // configuration is the reason for nondeterminism, the operation should be
284*da0073e9SAndroid Build Coastguard Worker // included in the `test_cublas_config_nondeterministic_alert` test. Any new
285*da0073e9SAndroid Build Coastguard Worker // tests should ideally follow a pattern similar to the existing ones.
286*da0073e9SAndroid Build Coastguard Worker //
287*da0073e9SAndroid Build Coastguard Worker // `example_func()` below shows an example of the comments and error-throwing
288*da0073e9SAndroid Build Coastguard Worker // code for a nondeterministic operation:
289*da0073e9SAndroid Build Coastguard Worker //
290*da0073e9SAndroid Build Coastguard Worker // void example_func() {
291*da0073e9SAndroid Build Coastguard Worker // // See Note [Writing Nondeterministic Operations]
292*da0073e9SAndroid Build Coastguard Worker // // Nondeterministic because <reason>
293*da0073e9SAndroid Build Coastguard Worker // at::globalContext().alertNondeterministic("example_func");
294*da0073e9SAndroid Build Coastguard Worker // ...
295*da0073e9SAndroid Build Coastguard Worker // }
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker // Throws an error if `Context::deterministicAlgorithms()` is true
298*da0073e9SAndroid Build Coastguard Worker static void alertNotDeterministic(c10::string_view const& caller);
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
301*da0073e9SAndroid Build Coastguard Worker // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
302*da0073e9SAndroid Build Coastguard Worker // ":4096:8". For more details:
303*da0073e9SAndroid Build Coastguard Worker // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
304*da0073e9SAndroid Build Coastguard Worker void alertCuBLASConfigNotDeterministic() const;
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker void setFloat32MatmulPrecision(const std::string& s);
307*da0073e9SAndroid Build Coastguard Worker bool allowTF32CuDNN() const;
308*da0073e9SAndroid Build Coastguard Worker void setAllowTF32CuDNN(bool);
309*da0073e9SAndroid Build Coastguard Worker bool allowTF32CuBLAS() const;
310*da0073e9SAndroid Build Coastguard Worker void setAllowTF32CuBLAS(bool);
311*da0073e9SAndroid Build Coastguard Worker Float32MatmulPrecision float32MatmulPrecision() const;
312*da0073e9SAndroid Build Coastguard Worker void setFloat32MatmulPrecision(Float32MatmulPrecision p);
313*da0073e9SAndroid Build Coastguard Worker bool allowFP16ReductionCuBLAS() const;
314*da0073e9SAndroid Build Coastguard Worker void setAllowFP16ReductionCuBLAS(bool);
315*da0073e9SAndroid Build Coastguard Worker bool allowBF16ReductionCuBLAS() const;
316*da0073e9SAndroid Build Coastguard Worker void setAllowBF16ReductionCuBLAS(bool);
317*da0073e9SAndroid Build Coastguard Worker at::QEngine qEngine() const;
318*da0073e9SAndroid Build Coastguard Worker void setQEngine(at::QEngine e);
319*da0073e9SAndroid Build Coastguard Worker static const std::vector<at::QEngine>& supportedQEngines();
320*da0073e9SAndroid Build Coastguard Worker static bool isXNNPACKAvailable();
321*da0073e9SAndroid Build Coastguard Worker void setCheckSparseTensorInvariants(bool e);
322*da0073e9SAndroid Build Coastguard Worker bool checkSparseTensorInvariants() const;
323*da0073e9SAndroid Build Coastguard Worker // This method is used to release the original weight after pre-packing.
324*da0073e9SAndroid Build Coastguard Worker // It should be called once before loading/running the model.
325*da0073e9SAndroid Build Coastguard Worker // NB: By default it is set to true for mobile builds.
326*da0073e9SAndroid Build Coastguard Worker void setReleaseWeightsWhenPrepacking(bool e);
327*da0073e9SAndroid Build Coastguard Worker bool releaseWeightsWhenPrepacking() const;
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker void setDisplayVmapFallbackWarnings(bool enabled);
330*da0073e9SAndroid Build Coastguard Worker bool areVmapFallbackWarningsEnabled() const;
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker void setDefaultMobileCPUAllocator();
333*da0073e9SAndroid Build Coastguard Worker void unsetDefaultMobileCPUAllocator();
334*da0073e9SAndroid Build Coastguard Worker bool allowFP16ReductionCPU() const;
335*da0073e9SAndroid Build Coastguard Worker void setAllowFP16ReductionCPU(bool);
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker private:
initCUDAIfNeeded(c10::DeviceType p)338*da0073e9SAndroid Build Coastguard Worker void initCUDAIfNeeded(c10::DeviceType p) {
339*da0073e9SAndroid Build Coastguard Worker if (p == c10::DeviceType::CUDA) {
340*da0073e9SAndroid Build Coastguard Worker lazyInitCUDA();
341*da0073e9SAndroid Build Coastguard Worker }
342*da0073e9SAndroid Build Coastguard Worker }
initHIPIfNeeded(c10::DeviceType p)343*da0073e9SAndroid Build Coastguard Worker void initHIPIfNeeded(c10::DeviceType p) {
344*da0073e9SAndroid Build Coastguard Worker if (p == c10::DeviceType::HIP) {
345*da0073e9SAndroid Build Coastguard Worker lazyInitHIP();
346*da0073e9SAndroid Build Coastguard Worker }
347*da0073e9SAndroid Build Coastguard Worker }
initXPUIfNeeded(c10::DeviceType p)348*da0073e9SAndroid Build Coastguard Worker void initXPUIfNeeded(c10::DeviceType p) {
349*da0073e9SAndroid Build Coastguard Worker if (p == c10::DeviceType::XPU) {
350*da0073e9SAndroid Build Coastguard Worker lazyInitXPU();
351*da0073e9SAndroid Build Coastguard Worker }
352*da0073e9SAndroid Build Coastguard Worker }
353*da0073e9SAndroid Build Coastguard Worker static bool checkCuBLASConfigDeterministic();
354*da0073e9SAndroid Build Coastguard Worker c10::once_flag thc_init;
355*da0073e9SAndroid Build Coastguard Worker c10::once_flag thh_init;
356*da0073e9SAndroid Build Coastguard Worker c10::once_flag thx_init;
357*da0073e9SAndroid Build Coastguard Worker c10::once_flag th_mtia_init;
358*da0073e9SAndroid Build Coastguard Worker c10::once_flag thp_init;
359*da0073e9SAndroid Build Coastguard Worker bool enabled_cudnn = true;
360*da0073e9SAndroid Build Coastguard Worker bool deterministic_cudnn = false;
361*da0073e9SAndroid Build Coastguard Worker bool _deterministic_algorithms = false;
362*da0073e9SAndroid Build Coastguard Worker bool _deterministic_algorithms_warn_only = false;
363*da0073e9SAndroid Build Coastguard Worker bool _deterministic_fill_uninitialized_memory = true;
364*da0073e9SAndroid Build Coastguard Worker bool enabled_flashSDP = true;
365*da0073e9SAndroid Build Coastguard Worker bool enabled_mem_efficientSDP = true;
366*da0073e9SAndroid Build Coastguard Worker bool enabled_mathSDP = true;
367*da0073e9SAndroid Build Coastguard Worker bool enabled_cudnnSDP = false;
368*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
369*da0073e9SAndroid Build Coastguard Worker bool benchmark_cudnn = true;
370*da0073e9SAndroid Build Coastguard Worker #else
371*da0073e9SAndroid Build Coastguard Worker bool benchmark_cudnn = false;
372*da0073e9SAndroid Build Coastguard Worker #endif
373*da0073e9SAndroid Build Coastguard Worker Float32MatmulPrecision float32_matmul_precision =
374*da0073e9SAndroid Build Coastguard Worker c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
375*da0073e9SAndroid Build Coastguard Worker ? at::Float32MatmulPrecision::HIGH
376*da0073e9SAndroid Build Coastguard Worker : at::Float32MatmulPrecision::HIGHEST;
377*da0073e9SAndroid Build Coastguard Worker int benchmark_limit_cudnn = 10;
378*da0073e9SAndroid Build Coastguard Worker bool allow_tf32_cudnn = true;
379*da0073e9SAndroid Build Coastguard Worker bool allow_fp16_reduction_cublas = true;
380*da0073e9SAndroid Build Coastguard Worker bool allow_bf16_reduction_cublas = true;
381*da0073e9SAndroid Build Coastguard Worker bool enabled_mkldnn = true;
382*da0073e9SAndroid Build Coastguard Worker bool enabled_nnpack = true;
383*da0073e9SAndroid Build Coastguard Worker at::LinalgBackend linalg_preferred_backend =
384*da0073e9SAndroid Build Coastguard Worker c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
385*da0073e9SAndroid Build Coastguard Worker ? at::LinalgBackend::Cusolver
386*da0073e9SAndroid Build Coastguard Worker : at::LinalgBackend::Default;
387*da0073e9SAndroid Build Coastguard Worker at::BlasBackend blas_preferred_backend =
388*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
389*da0073e9SAndroid Build Coastguard Worker (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
390*da0073e9SAndroid Build Coastguard Worker #else
391*da0073e9SAndroid Build Coastguard Worker (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
392*da0073e9SAndroid Build Coastguard Worker #endif
393*da0073e9SAndroid Build Coastguard Worker ? at::BlasBackend::Cublaslt
394*da0073e9SAndroid Build Coastguard Worker : at::BlasBackend::Cublas;
395*da0073e9SAndroid Build Coastguard Worker #ifdef C10_MOBILE
396*da0073e9SAndroid Build Coastguard Worker bool release_original_weights = true;
397*da0073e9SAndroid Build Coastguard Worker #else
398*da0073e9SAndroid Build Coastguard Worker bool release_original_weights = false;
399*da0073e9SAndroid Build Coastguard Worker #endif
400*da0073e9SAndroid Build Coastguard Worker bool display_vmap_fallback_warnings_ = false;
401*da0073e9SAndroid Build Coastguard Worker std::optional<at::QEngine> quantized_engine = c10::nullopt;
402*da0073e9SAndroid Build Coastguard Worker bool enable_sparse_tensor_invariant_checks = false;
403*da0073e9SAndroid Build Coastguard Worker bool allow_fp16_reduction_cpu = false;
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker Allocator* prev_allocator_ptr_{nullptr};
406*da0073e9SAndroid Build Coastguard Worker };
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker TORCH_API Context& globalContext();
409*da0073e9SAndroid Build Coastguard Worker
init()410*da0073e9SAndroid Build Coastguard Worker static inline void init() {
411*da0073e9SAndroid Build Coastguard Worker globalContext();
412*da0073e9SAndroid Build Coastguard Worker }
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker TORCH_API Allocator* getCPUAllocator();
415*da0073e9SAndroid Build Coastguard Worker
getDeprecatedTypeProperties(Backend p,ScalarType s)416*da0073e9SAndroid Build Coastguard Worker static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
417*da0073e9SAndroid Build Coastguard Worker Backend p,
418*da0073e9SAndroid Build Coastguard Worker ScalarType s) {
419*da0073e9SAndroid Build Coastguard Worker return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
420*da0073e9SAndroid Build Coastguard Worker p, s);
421*da0073e9SAndroid Build Coastguard Worker }
422*da0073e9SAndroid Build Coastguard Worker
CPU(ScalarType s)423*da0073e9SAndroid Build Coastguard Worker static inline DeprecatedTypeProperties& CPU(ScalarType s) {
424*da0073e9SAndroid Build Coastguard Worker return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
425*da0073e9SAndroid Build Coastguard Worker Backend::CPU, s);
426*da0073e9SAndroid Build Coastguard Worker }
427*da0073e9SAndroid Build Coastguard Worker
CUDA(ScalarType s)428*da0073e9SAndroid Build Coastguard Worker static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
429*da0073e9SAndroid Build Coastguard Worker return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
430*da0073e9SAndroid Build Coastguard Worker Backend::CUDA, s);
431*da0073e9SAndroid Build Coastguard Worker }
432*da0073e9SAndroid Build Coastguard Worker
HIP(ScalarType s)433*da0073e9SAndroid Build Coastguard Worker static inline DeprecatedTypeProperties& HIP(ScalarType s) {
434*da0073e9SAndroid Build Coastguard Worker return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
435*da0073e9SAndroid Build Coastguard Worker Backend::HIP, s);
436*da0073e9SAndroid Build Coastguard Worker }
437*da0073e9SAndroid Build Coastguard Worker
MPS(ScalarType s)438*da0073e9SAndroid Build Coastguard Worker static inline DeprecatedTypeProperties& MPS(ScalarType s) {
439*da0073e9SAndroid Build Coastguard Worker return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
440*da0073e9SAndroid Build Coastguard Worker Backend::MPS, s);
441*da0073e9SAndroid Build Coastguard Worker }
442*da0073e9SAndroid Build Coastguard Worker
hasCUDA()443*da0073e9SAndroid Build Coastguard Worker static inline bool hasCUDA() {
444*da0073e9SAndroid Build Coastguard Worker return globalContext().hasCUDA();
445*da0073e9SAndroid Build Coastguard Worker }
446*da0073e9SAndroid Build Coastguard Worker
hasMTIA()447*da0073e9SAndroid Build Coastguard Worker static inline bool hasMTIA() {
448*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMTIA();
449*da0073e9SAndroid Build Coastguard Worker }
450*da0073e9SAndroid Build Coastguard Worker
hasHIP()451*da0073e9SAndroid Build Coastguard Worker static inline bool hasHIP() {
452*da0073e9SAndroid Build Coastguard Worker return globalContext().hasHIP();
453*da0073e9SAndroid Build Coastguard Worker }
454*da0073e9SAndroid Build Coastguard Worker
hasIPU()455*da0073e9SAndroid Build Coastguard Worker static inline bool hasIPU() {
456*da0073e9SAndroid Build Coastguard Worker return globalContext().hasIPU();
457*da0073e9SAndroid Build Coastguard Worker }
458*da0073e9SAndroid Build Coastguard Worker
hasXLA()459*da0073e9SAndroid Build Coastguard Worker static inline bool hasXLA() {
460*da0073e9SAndroid Build Coastguard Worker return globalContext().hasXLA();
461*da0073e9SAndroid Build Coastguard Worker }
462*da0073e9SAndroid Build Coastguard Worker
hasMPS()463*da0073e9SAndroid Build Coastguard Worker static inline bool hasMPS() {
464*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMPS();
465*da0073e9SAndroid Build Coastguard Worker }
466*da0073e9SAndroid Build Coastguard Worker
hasMAIA()467*da0073e9SAndroid Build Coastguard Worker static inline bool hasMAIA() {
468*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMAIA();
469*da0073e9SAndroid Build Coastguard Worker }
470*da0073e9SAndroid Build Coastguard Worker
hasXPU()471*da0073e9SAndroid Build Coastguard Worker static inline bool hasXPU() {
472*da0073e9SAndroid Build Coastguard Worker return globalContext().hasXPU();
473*da0073e9SAndroid Build Coastguard Worker }
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker // Despite its name, this function returns the number of *CUDA* GPUs.
getNumGPUs()476*da0073e9SAndroid Build Coastguard Worker static inline size_t getNumGPUs() {
477*da0073e9SAndroid Build Coastguard Worker // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
478*da0073e9SAndroid Build Coastguard Worker // FUNCTION. If you are interested in interrogating the number of
479*da0073e9SAndroid Build Coastguard Worker // devices for a specific device type, add that function to the
480*da0073e9SAndroid Build Coastguard Worker // relevant library (e.g., similar to at::cuda::device_count())
481*da0073e9SAndroid Build Coastguard Worker if (hasCUDA() && hasHIP()) {
482*da0073e9SAndroid Build Coastguard Worker throw std::runtime_error(
483*da0073e9SAndroid Build Coastguard Worker "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
484*da0073e9SAndroid Build Coastguard Worker "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
485*da0073e9SAndroid Build Coastguard Worker "means HIP. Rebuild PyTorch with one or the other disabled.");
486*da0073e9SAndroid Build Coastguard Worker } else if (hasCUDA()) {
487*da0073e9SAndroid Build Coastguard Worker return detail::getCUDAHooks().getNumGPUs();
488*da0073e9SAndroid Build Coastguard Worker } else if (hasHIP()) {
489*da0073e9SAndroid Build Coastguard Worker return detail::getHIPHooks().getNumGPUs();
490*da0073e9SAndroid Build Coastguard Worker } else {
491*da0073e9SAndroid Build Coastguard Worker return 0;
492*da0073e9SAndroid Build Coastguard Worker }
493*da0073e9SAndroid Build Coastguard Worker }
494*da0073e9SAndroid Build Coastguard Worker
hasOpenMP()495*da0073e9SAndroid Build Coastguard Worker static inline bool hasOpenMP() {
496*da0073e9SAndroid Build Coastguard Worker return globalContext().hasOpenMP();
497*da0073e9SAndroid Build Coastguard Worker }
498*da0073e9SAndroid Build Coastguard Worker
hasMKL()499*da0073e9SAndroid Build Coastguard Worker static inline bool hasMKL() {
500*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMKL();
501*da0073e9SAndroid Build Coastguard Worker }
502*da0073e9SAndroid Build Coastguard Worker
hasLAPACK()503*da0073e9SAndroid Build Coastguard Worker static inline bool hasLAPACK() {
504*da0073e9SAndroid Build Coastguard Worker return globalContext().hasLAPACK();
505*da0073e9SAndroid Build Coastguard Worker }
506*da0073e9SAndroid Build Coastguard Worker
hasMAGMA()507*da0073e9SAndroid Build Coastguard Worker static inline bool hasMAGMA() {
508*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMAGMA();
509*da0073e9SAndroid Build Coastguard Worker }
510*da0073e9SAndroid Build Coastguard Worker
hasMKLDNN()511*da0073e9SAndroid Build Coastguard Worker static inline bool hasMKLDNN() {
512*da0073e9SAndroid Build Coastguard Worker return globalContext().hasMKLDNN();
513*da0073e9SAndroid Build Coastguard Worker }
514*da0073e9SAndroid Build Coastguard Worker
manual_seed(uint64_t seed)515*da0073e9SAndroid Build Coastguard Worker static inline void manual_seed(uint64_t seed) {
516*da0073e9SAndroid Build Coastguard Worker auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
517*da0073e9SAndroid Build Coastguard Worker {
518*da0073e9SAndroid Build Coastguard Worker // See Note [Acquire lock when using random generators]
519*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(gen.mutex());
520*da0073e9SAndroid Build Coastguard Worker gen.set_current_seed(seed);
521*da0073e9SAndroid Build Coastguard Worker }
522*da0073e9SAndroid Build Coastguard Worker // NB: Sometimes we build with CUDA, but we don't have any GPUs
523*da0073e9SAndroid Build Coastguard Worker // available. In that case, we must not seed CUDA; it will fail!
524*da0073e9SAndroid Build Coastguard Worker const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
525*da0073e9SAndroid Build Coastguard Worker if (hasCUDA() && cuda_num_gpus > 0) {
526*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(cuda_num_gpus)) {
527*da0073e9SAndroid Build Coastguard Worker auto cuda_gen = globalContext().defaultGenerator(
528*da0073e9SAndroid Build Coastguard Worker Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
529*da0073e9SAndroid Build Coastguard Worker {
530*da0073e9SAndroid Build Coastguard Worker // See Note [Acquire lock when using random generators]
531*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(cuda_gen.mutex());
532*da0073e9SAndroid Build Coastguard Worker cuda_gen.set_current_seed(seed);
533*da0073e9SAndroid Build Coastguard Worker }
534*da0073e9SAndroid Build Coastguard Worker }
535*da0073e9SAndroid Build Coastguard Worker }
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
538*da0073e9SAndroid Build Coastguard Worker if (hasXPU() && xpu_num_gpus) {
539*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(xpu_num_gpus)) {
540*da0073e9SAndroid Build Coastguard Worker auto xpu_gen = globalContext().defaultGenerator(
541*da0073e9SAndroid Build Coastguard Worker Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
542*da0073e9SAndroid Build Coastguard Worker {
543*da0073e9SAndroid Build Coastguard Worker // See Note [Acquire lock when using random generators]
544*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(xpu_gen.mutex());
545*da0073e9SAndroid Build Coastguard Worker xpu_gen.set_current_seed(seed);
546*da0073e9SAndroid Build Coastguard Worker }
547*da0073e9SAndroid Build Coastguard Worker }
548*da0073e9SAndroid Build Coastguard Worker }
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker if (hasMPS()) {
551*da0073e9SAndroid Build Coastguard Worker auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
552*da0073e9SAndroid Build Coastguard Worker // See Note [Acquire lock when using random generators]
553*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(mps_gen.mutex());
554*da0073e9SAndroid Build Coastguard Worker mps_gen.set_current_seed(seed);
555*da0073e9SAndroid Build Coastguard Worker }
556*da0073e9SAndroid Build Coastguard Worker }
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker // When the global flag `allow_tf32` is set to true, cuBLAS handles are
559*da0073e9SAndroid Build Coastguard Worker // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
560*da0073e9SAndroid Build Coastguard Worker // For some operators, such as addmv, TF32 offers no performance improvement
561*da0073e9SAndroid Build Coastguard Worker // but causes precision loss. To help this case, this class implements
562*da0073e9SAndroid Build Coastguard Worker // a RAII guard that can be used to quickly disable TF32 within its scope.
563*da0073e9SAndroid Build Coastguard Worker //
564*da0073e9SAndroid Build Coastguard Worker // Usage:
565*da0073e9SAndroid Build Coastguard Worker // NoTF32Guard disable_tf32;
566*da0073e9SAndroid Build Coastguard Worker struct TORCH_API NoTF32Guard {
567*da0073e9SAndroid Build Coastguard Worker NoTF32Guard();
568*da0073e9SAndroid Build Coastguard Worker ~NoTF32Guard();
569*da0073e9SAndroid Build Coastguard Worker static bool should_disable_tf32();
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker private:
572*da0073e9SAndroid Build Coastguard Worker bool changed = false;
573*da0073e9SAndroid Build Coastguard Worker };
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker struct TORCH_API ROCmBackwardPassGuard {
576*da0073e9SAndroid Build Coastguard Worker ROCmBackwardPassGuard();
577*da0073e9SAndroid Build Coastguard Worker ~ROCmBackwardPassGuard();
578*da0073e9SAndroid Build Coastguard Worker static bool is_backward_pass();
579*da0073e9SAndroid Build Coastguard Worker };
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker } // namespace at
582