1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch_v2.h>
4 #include <ATen/EmptyTensor.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_local_scalar_dense_native.h>
10 #endif
11
12 #include <ATen/cuda/CUDAContext.h>
13
14 #if defined(USE_ROCM)
15 // TODO(lufang): Tensor.item() on AMD HIP is not synced in the Recsys models.
16 // This is just a short term workaround. Issue is tracked as FBA-388 on the AMD side.
17 namespace {
use_sync_mode()18 bool use_sync_mode() {
19 static const bool sync_mode = c10::utils::check_env("HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE") == true;
20 return sync_mode;
21 }
22 }
23 #endif
24
25 namespace at::native {
26
_local_scalar_dense_cuda(const Tensor & self)27 Scalar _local_scalar_dense_cuda(const Tensor& self) {
28 Scalar r;
29 #if defined(USE_ROCM)
30 if (!use_sync_mode()){
31 #endif
32 AT_DISPATCH_V2(
33 self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] {
34 // Create pinned memory for the scalar value to avoid implicit
35 // locking/sync in cuda library due to pageable memory
36 auto value = at::detail::empty_cpu(
37 {1}, /* size */
38 c10::CppTypeToScalarType<scalar_t>(), /* dtype */
39 std::nullopt, /* layout */
40 std::nullopt, /* device */
41 true, /* pin_memory */
42 std::nullopt /* memory format */
43 );
44 cudaStream_t stream = at::cuda::getCurrentCUDAStream();
45 at::cuda::memcpy_and_sync((void *)value.const_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
46 r = Scalar(*value.const_data_ptr<scalar_t>());
47 }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
48 #if defined(USE_ROCM)
49 } else {
50 auto cpu_self = self.cpu();
51 AT_DISPATCH_V2(
52 self.scalar_type(), "_local_scalar_dense_hip", AT_WRAP([&] {
53 r = Scalar(*cpu_self.const_data_ptr<scalar_t>());
54 }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
55 }
56 #endif
57 return r;
58 }
59
60 } // at::native
61