// Functions that fill Tensors with constants.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/Fill.h>
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <cstring>  // std::memset (used by zero_cpu_)

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>  // at::empty_like used by fill()
#include <ATen/ops/fill_diagonal_native.h>
#include <ATen/ops/fill_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zero_native.h>
#endif

namespace at::native {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
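// fill_out: single-element CPU tensors take the scalar_fill fast path;
// everything else goes through a TensorIterator that has only an output (no
// inputs), so the fill_stub kernel for the tensor's device writes `value` into
// every element. Memory-overlap checking is skipped because writing the same
// constant to an element twice is harmless.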
Tensor& fill_out(Tensor& self, const Scalar& value) {
  if (self.device() == at::kCPU && self.numel() == 1) {
    return at::detail::scalar_fill(self, value);
  }
  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)  // Fill is idempotent, so overlap is okay
    .check_all_same_dtype(false)
    .add_output(self)
    .resize_outputs(false)
    .build();
  fill_stub(iter.device_type(), iter, value);
  return self;
}

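// fill_out_quantized: materialize a dense float tensor of self's shape filled
// with `value`, move it to self's device and suggested memory format, and rely
// on copy_ to quantize it into self.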
static Tensor& fill_out_quantized(Tensor& self, const Scalar& value) {
  at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
  out = out.to(self.device()).to(self.suggest_memory_format());
  // Trust the `copy_` to handle the quantization and the boundary checks.
  self.copy_(out);
  return self;
}

Tensor& fill_(Tensor& self, const Scalar& value) {
  return fill_out(self, value);
}

Tensor& fill_quantized_(Tensor& self, const Scalar& value) {
  return fill_out_quantized(self, value);
}

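// Tensor-valued overload: `value` must be 0-dimensional. If it lives on a
// different device than self, it is pulled out as a Scalar via item() (a host
// sync for GPU values) and the scalar path is used. On the same device the
// value is broadcast-copied into self, cloning first when `value` aliases self
// (e.g. t.fill_(t[0][0])) so the source is not overwritten mid-copy.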
Tensor& fill_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
  if (self.device() != value.device()) {
    return fill_out(self, value.item());
  }
  // Check if value is a view of self and if it is we clone
  // it to avoid overwriting self prematurely
  if (self.is_alias_of(value)) {
    self.copy_(value.clone());
  } else {
    self.copy_(value);
  }
  return self;
}

Tensor& fill_quantized_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
  return fill_out_quantized(self, value.item());
}

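// Meta backend: meta tensors carry shape/dtype/stride metadata but no data, so
// fill_ here only validates the arguments and returns self.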
Tensor& fill_meta_(Tensor& self, const Scalar& value) {
  return self;
}

Tensor& fill_meta_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
  return self;
}

Tensor fill(const Tensor& self, const Scalar& value) {
  return at::empty_like(self).fill_(value);
}

Tensor fill(const Tensor& self, const Tensor& value) {
  return at::empty_like(self).fill_(value);
}
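
// A minimal usage sketch (tensor names `t` and `u` are illustrative):
//   at::Tensor t = at::zeros({2, 3});
//   at::Tensor u = at::fill(t, 1.5);  // out-of-place: u is a new tensor, t is untouched
//   t.fill_(1.5);                     // in-place: every element of t becomes 1.5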

DEFINE_DISPATCH(fill_stub);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

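// fill_diagonal_ builds a 1-D strided view of the main diagonal: for an n-dim
// tensor whose dimensions all have equal length (required when n > 2), stepping
// by the sum of all strides moves one step along the diagonal, and the diagonal
// has min(size(0), size(1)) elements. With wrap=true on a tall 2-D matrix
// (height > width + 1), the fill restarts below the main diagonal every
// width + 1 rows, mirroring NumPy's fill_diagonal(..., wrap=True).
//
// A minimal sketch (a hypothetical 5x2 matrix):
//   at::Tensor m = at::zeros({5, 2});
//   m.fill_diagonal_(1.0, /*wrap=*/true);  // sets (0,0), (1,1), (3,0), (4,1)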
Tensor& fill_diagonal_(Tensor& self, const Scalar& fill_value, bool wrap) {
  int64_t nDims = self.dim();
  TORCH_CHECK(nDims >= 2, "dimensions must be larger than 1");

  int64_t height = self.size(0);
  int64_t width = self.size(1);

  if (nDims > 2) {
    int64_t dim1 = height;
    for (const auto i : c10::irange(1, nDims)) {
      if (self.size(i) != dim1) {
        AT_ERROR("all dimensions of input must be of equal length");
      }
    }
  }

  int64_t storage_offset = self.storage_offset();
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t size = std::min(height, width);

  int64_t stride = 0;
  for (const auto i : c10::irange(nDims)) {
    stride += self.stride(i);
  }
  strides.push_back(stride);
  sizes.push_back(size);

  auto main_diag = self.as_strided(sizes, strides, storage_offset);
  main_diag.fill_(fill_value);

  if (wrap && nDims == 2 && height > width + 1) {
    std::vector<int64_t> wrap_sizes;

    int64_t step = width + 1;
    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
    wrap_sizes.push_back(wrap_size);

    int64_t offset = self.stride(0) * (width + 1);

    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
    wrap_diag.fill_(fill_value);
  }

  return self;
}

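// zero_cpu_: contiguous-memory fast path that clears nelements * itemsize bytes
// with a single memset. Tensors whose storage has no data pointer fall back to
// fill_(0).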
static Tensor& zero_cpu_(Tensor& self, int64_t nelements) {
  void* ptr = self.data_ptr();
  if (nullptr == ptr) {
    return self.fill_(0);
  }
  auto size_bytes = nelements * self.dtype().itemsize();
  if (size_bytes > 0) {
    std::memset(ptr, 0, size_bytes);
  }
  return self;
}

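// zero_: dispatch between the memset fast path and the generic fill kernel.
// The memset path is only taken for CPU tensors that are non-overlapping and
// dense (so their elements occupy one contiguous block of memory) and small
// enough (below internal::GRAIN_SIZE) that a plain memset is expected to beat
// the parallelized fill_stub; everything else uses fill_(0).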
Tensor& zero_(Tensor& self) {
  int64_t nelements = c10::multiply_integers(self.sizes());
  if (self.device() == at::kCPU &&
      self.is_non_overlapping_and_dense() &&
      nelements < internal::GRAIN_SIZE) {
    return zero_cpu_(self, nelements);
  }
  return self.fill_(0);
}

Tensor& zero_meta_(Tensor& self) {
  return self;
}

} // namespace at::native