#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::detail {
// When filling a 1-element CPU tensor with a number, we want to skip
// everything else and manipulate the data pointer directly.
// Ideally this fast path should be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
    const Scalar& s,
    std::optional<ScalarType> dtype_opt,
    std::optional<Device> device_opt);
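
// Illustrative usage (a sketch; the dtype and values are hypothetical):
// build a 1-element CPU tensor and fill it through the fast path above.
//   at::Tensor t = at::detail::scalar_tensor_static(
//       /*s=*/1.0, /*dtype_opt=*/at::kFloat, /*device_opt=*/at::kCPU);
//   at::detail::scalar_fill(t, /*value=*/2.0); // writes via the data pointer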
} // namespace at::detail

// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {

// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not
// part of core).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  }
  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
}
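
// Illustrative usage (a sketch; the scalar value and CUDA device are
// assumptions for the example):
//   at::Tensor a = c10::scalar_to_tensor(3.5);            // CPU fast track
//   at::Tensor b = c10::scalar_to_tensor(3.5, at::kCUDA); // generic path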

} // namespace c10

namespace at::native {

inline Tensor wrapped_scalar_tensor(
    const Scalar& scalar,
    const Device device = at::kCPU) {
  auto tensor = scalar_to_tensor(scalar, device);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
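
// Illustrative usage (a sketch; the scalar value is hypothetical): wrap a
// Scalar so type promotion treats it as a number rather than as an
// ordinary 0-dim tensor.
//   at::Tensor other = at::native::wrapped_scalar_tensor(/*scalar=*/2.5);
//   // other.unsafeGetTensorImpl()->is_wrapped_number() == true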

} // namespace at::native