#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::detail {
// When filling a number into a 1-element CPU tensor, we want to skip
// everything else and manipulate the data pointer directly.
// Ideally this fast path should be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
    const Scalar& s,
    std::optional<ScalarType> dtype_opt,
    std::optional<Device> device_opt);
} // namespace at::detail
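// Example (illustrative sketch, not part of this header): constructing a CPU
// scalar tensor via the static fast path and then overwriting its single
// element in place with scalar_fill. The concrete values and the idea of
// pairing the two calls are assumptions about typical usage, not code taken
// from the ATen sources.
//
//   at::Tensor t = at::detail::scalar_tensor_static(
//       /*s=*/1.5, /*dtype_opt=*/at::kFloat, /*device_opt=*/at::kCPU);
//   at::detail::scalar_fill(t, 2.5); // rewrite the single element directly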
// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {

// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not
// part of core).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  if (device == at::kCPU) {
    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  }
  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
}

} // namespace c10

namespace at::native {

inline Tensor wrapped_scalar_tensor(
    const Scalar& scalar,
    const Device device = at::kCPU) {
  auto tensor = scalar_to_tensor(scalar, device);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}

} // namespace at::native
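// Example (illustrative sketch, not part of this header): wrapping a Scalar
// operand so that ATen's type promotion treats it as a "wrapped number",
// i.e. like a Python scalar that does not by itself drive the result dtype
// upward. The function name `add_scalar_via_tensor` is hypothetical.
//
//   at::Tensor add_scalar_via_tensor(
//       const at::Tensor& self, const c10::Scalar& other) {
//     at::Tensor wrapped = at::native::wrapped_scalar_tensor(other);
//     return self.add(wrapped);
//   }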