xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ScalarOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/Dispatch.h>
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Dispatch_v2.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/EmptyTensor.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ScalarOps.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace at {
8*da0073e9SAndroid Build Coastguard Worker namespace {
9*da0073e9SAndroid Build Coastguard Worker template <typename scalar_t>
fill_inplace(Tensor & self,const Scalar & value_scalar)10*da0073e9SAndroid Build Coastguard Worker inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {
11*da0073e9SAndroid Build Coastguard Worker   auto value = value_scalar.to<scalar_t>();
12*da0073e9SAndroid Build Coastguard Worker   scalar_t* dptr = static_cast<scalar_t*>(self.data_ptr());
13*da0073e9SAndroid Build Coastguard Worker   *dptr = value;
14*da0073e9SAndroid Build Coastguard Worker }
15*da0073e9SAndroid Build Coastguard Worker }
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker namespace detail {
scalar_fill(Tensor & self,const Scalar & value)18*da0073e9SAndroid Build Coastguard Worker Tensor& scalar_fill(Tensor& self, const Scalar& value) {
19*da0073e9SAndroid Build Coastguard Worker   AT_DISPATCH_V2(
20*da0073e9SAndroid Build Coastguard Worker       self.scalar_type(), "fill_out", AT_WRAP([&]() {
21*da0073e9SAndroid Build Coastguard Worker         fill_inplace<scalar_t>(self, value);
22*da0073e9SAndroid Build Coastguard Worker       }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
23*da0073e9SAndroid Build Coastguard Worker   return self;
24*da0073e9SAndroid Build Coastguard Worker }
25*da0073e9SAndroid Build Coastguard Worker 
scalar_tensor_static(const Scalar & s,std::optional<ScalarType> dtype_opt,std::optional<Device> device_opt)26*da0073e9SAndroid Build Coastguard Worker Tensor scalar_tensor_static(const Scalar& s, std::optional<ScalarType> dtype_opt, std::optional<Device> device_opt) {
27*da0073e9SAndroid Build Coastguard Worker   at::tracer::impl::NoTracerDispatchMode tracer_guard;
28*da0073e9SAndroid Build Coastguard Worker   at::AutoDispatchBelowAutograd mode;
29*da0073e9SAndroid Build Coastguard Worker   Tensor result = at::detail::empty_cpu(
30*da0073e9SAndroid Build Coastguard Worker       {}, dtype_opt, std::nullopt, device_opt, std::nullopt, std::nullopt);
31*da0073e9SAndroid Build Coastguard Worker   scalar_fill(result, s);
32*da0073e9SAndroid Build Coastguard Worker   return result;
33*da0073e9SAndroid Build Coastguard Worker }
34*da0073e9SAndroid Build Coastguard Worker } // namespace detail
35*da0073e9SAndroid Build Coastguard Worker } // namespace at
36