xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/CUDAScalar.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch_v2.h>
4 #include <ATen/EmptyTensor.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_local_scalar_dense_native.h>
10 #endif
11 
12 #include <ATen/cuda/CUDAContext.h>
13 
14 #if defined(USE_ROCM)
15 // TODO(lufang): Tensor.item() on AMD HIP is not synced in the Recsys models.
16 // This is just a short term workaround. Issue is tracked as FBA-388 on the AMD side.
namespace {
  // Returns true when the HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE environment
  // variable is set. The result is computed once and cached for the lifetime
  // of the process.
  bool use_sync_mode() {
    static const bool kSyncMode = [] {
      return c10::utils::check_env("HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE") == true;
    }();
    return kSyncMode;
  }
} // anonymous namespace
23 #endif
24 
namespace at::native {

// Copies the single element of a one-element CUDA tensor to the host and
// wraps it in a Scalar (the backend implementation behind Tensor::item()).
//
// The copy destination is a pinned (page-locked) host buffer so that
// memcpy_and_sync does not go through pageable memory, which can trigger
// extra implicit locking/synchronization inside the CUDA library.
Scalar _local_scalar_dense_cuda(const Tensor& self) {
  Scalar r;
#if defined(USE_ROCM)
  // On ROCm, optionally bypass the pinned-memory path entirely and use a
  // fully synchronous Tensor::cpu() transfer instead — short-term workaround
  // for item() not being synced in some workloads (tracked as FBA-388).
  if (!use_sync_mode()){
#endif
    AT_DISPATCH_V2(
      self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] {
          // Create pinned memory for the scalar value to avoid implicit
          // locking/sync in cuda library due to pageable memory
          auto value = at::detail::empty_cpu(
            {1}, /* size */
            c10::CppTypeToScalarType<scalar_t>(), /* dtype */
            std::nullopt, /* layout */
            std::nullopt, /* device */
            true, /* pin_memory */
            std::nullopt /* memory format */
          );
          cudaStream_t stream = at::cuda::getCurrentCUDAStream();
          // The destination buffer is written to, so take a mutable pointer
          // rather than casting away const from const_data_ptr().
          at::cuda::memcpy_and_sync(
              value.mutable_data_ptr<scalar_t>(),
              self.const_data_ptr<scalar_t>(),
              sizeof(scalar_t),
              cudaMemcpyDeviceToHost,
              stream);
          r = Scalar(*value.const_data_ptr<scalar_t>());
        }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
#if defined(USE_ROCM)
  } else {
    // Synchronous fallback: copy the whole (one-element) tensor to the CPU,
    // then read the value out of the host copy.
    auto cpu_self = self.cpu();
    AT_DISPATCH_V2(
      self.scalar_type(), "_local_scalar_dense_hip", AT_WRAP([&] {
          r = Scalar(*cpu_self.const_data_ptr<scalar_t>());
        }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
  }
#endif
  return r;
}

} // at::native
61