// Functions that fill Tensors with constants.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/Fill.h>
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <cstring>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/fill_diagonal_native.h>
#include <ATen/ops/fill_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zero_native.h>
#endif

namespace at::native {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
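// In-place fill with a Scalar. Single-element CPU tensors take a direct scalar
// write; everything else builds a TensorIterator over self and dispatches to
// the per-device fill_stub kernel.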
Tensor& fill_out(Tensor& self, const Scalar& value) {
  if (self.device() == at::kCPU && self.numel() == 1) {
    return at::detail::scalar_fill(self, value);
  }
  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)  // Fill is idempotent, so overlap is okay
      .check_all_same_dtype(false)
      .add_output(self)
      .resize_outputs(false)
      .build();
  fill_stub(iter.device_type(), iter, value);
  return self;
}

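// Quantized fill: materialize the fill value as a dense float tensor with
// self's shape, then let copy_ requantize it into self.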
static Tensor& fill_out_quantized(Tensor& self, const Scalar& value) {
  at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
  out = out.to(self.device()).to(self.suggest_memory_format());
  // Trust the `copy_` to handle the quantization and the boundary checks.
  self.copy_(out);
  return self;
}

Tensor& fill_(Tensor& self, const Scalar& value) {
  return fill_out(self, value);
}

Tensor& fill_quantized_(Tensor& self, const Scalar& value) {
  return fill_out_quantized(self, value);
}

Tensor& fill_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
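  // A value living on a different device is read back as a Scalar (via item(),
  // which synchronizes if needed) and filled through the scalar path instead of
  // a cross-device copy_.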
  if (self.device() != value.device()) {
    return fill_out(self, value.item());
  }
  // Check if value is a view of self and if it is we clone
  // it to avoid overwriting self prematurely
  if (self.is_alias_of(value)) {
    self.copy_(value.clone());
  } else {
    self.copy_(value);
  }
  return self;
}

Tensor& fill_quantized_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
  return fill_out_quantized(self, value.item());
}

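// Meta variants only validate arguments; meta tensors carry no storage, so
// there is nothing to write.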
Tensor& fill_meta_(Tensor& self, const Scalar& value) {
  return self;
}

Tensor& fill_meta_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
  return self;
}

Tensor fill(const Tensor& self, const Scalar& value) {
  return at::empty_like(self).fill_(value);
}

Tensor fill(const Tensor& self, const Tensor& value) {
  return at::empty_like(self).fill_(value);
}

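// The per-device fill kernels are registered against this stub by the CPU and
// CUDA backends.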
DEFINE_DISPATCH(fill_stub);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

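// Fills the main diagonal of self in place. The diagonal is exposed as a 1-D
// as_strided view whose single stride is the sum of self's strides; with
// wrap=true on a tall 2-D matrix, the fill continues past the main block,
// matching NumPy's fill_diagonal(wrap=True) behavior.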
Tensor& fill_diagonal_(Tensor& self, const Scalar& fill_value, bool wrap) {
  int64_t nDims = self.dim();
  TORCH_CHECK(nDims >= 2, "dimensions must be larger than 1");

  int64_t height = self.size(0);
  int64_t width = self.size(1);

  if (nDims > 2) {
    int64_t dim1 = height;
    for (const auto i : c10::irange(1, nDims)) {
      if (self.size(i) != dim1) {
        AT_ERROR("all dimensions of input must be of equal length");
      }
    }
  }

  int64_t storage_offset = self.storage_offset();
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t size = std::min(height, width);

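  // Stepping every dimension by one element at a time walks the main diagonal,
  // so a single stride equal to the sum of all strides addresses it.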
  int64_t stride = 0;
  for (const auto i : c10::irange(nDims)) {
    stride += self.stride(i);
  }
  strides.push_back(stride);
  sizes.push_back(size);

  auto main_diag = self.as_strided(sizes, strides, storage_offset);
  main_diag.fill_(fill_value);

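  // When wrap is requested for a tall 2-D matrix, the diagonal continues past
  // the main block, wrapping back to column 0: in flat order that is one
  // element every (width + 1) entries, starting stride(0) * (width + 1)
  // elements past self's storage offset.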
  if (wrap && nDims == 2 && height > width + 1) {
    std::vector<int64_t> wrap_sizes;

    int64_t step = width + 1;
    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
    wrap_sizes.push_back(wrap_size);

    int64_t offset = self.stride(0) * (width + 1);

    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
    wrap_diag.fill_(fill_value);
  }

  return self;
}

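// Zeroes a dense CPU tensor by memset-ing its backing buffer directly; tensors
// without an allocated data pointer fall back to fill_(0).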
static Tensor& zero_cpu_(Tensor& self, int64_t nelements) {
  void* ptr = self.data_ptr();
  if (nullptr == ptr) {
    return self.fill_(0);
  }
  auto size_bytes = nelements * self.dtype().itemsize();
  if (size_bytes > 0) {
    std::memset(ptr, 0, size_bytes);
  }
  return self;
}

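// Small (< GRAIN_SIZE elements) non-overlapping, dense CPU tensors take the
// memset fast path; everything else goes through the generic fill_(0).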
Tensor& zero_(Tensor& self) {
  int64_t nelements = c10::multiply_integers(self.sizes());
  if (self.device() == at::kCPU &&
      self.is_non_overlapping_and_dense() &&
      nelements < internal::GRAIN_SIZE) {
    return zero_cpu_(self, nelements);
  }
  return self.fill_(0);
}

Tensor& zero_meta_(Tensor& self) {
  return self;
}

} // namespace at::native