// File: aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>

namespace at::native {

DEFINE_DISPATCH(nested_dense_elementwise_stub);
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub);

std::pair<NestedTensorImpl*, NestedTensorImpl*>
static get_elementwise_nested_tensor_impl(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name) {
  if (self.is_nested() && !(other.is_nested())) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other");
  } else if (!(self.is_nested()) && other.is_nested()) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a non-nested self and nested other");
  } else if (!(self.is_nested()) || !(other.is_nested())) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a non-nested self and non-nested other");
  }

  auto self_ptr = get_nested_tensor_impl(self);
  auto other_ptr = get_nested_tensor_impl(other);

  TORCH_CHECK(
      self.dim() == other.dim(),
      op_name,
      " does not support broadcasting when given a NestedTensor");
  TORCH_CHECK(
      at::equal(
          self_ptr->get_nested_sizes(),
          other_ptr->get_nested_sizes()),
      op_name,
      " does not support broadcasting when given a NestedTensor");
  TORCH_CHECK(
      at::equal(
          self_ptr->get_nested_strides(),
          other_ptr->get_nested_strides()),
      op_name,
      " requires strides to match when given NestedTensors");
  const auto self_offsets = self_ptr->get_storage_offsets();
  int64_t *self_offsets_ptr = self_offsets.data_ptr<int64_t>();
  int64_t *other_offsets_ptr = other_ptr->get_storage_offsets().data_ptr<int64_t>();
  bool offsets_match = true;
  for (auto i = 0; i < self_offsets.size(0); i++) {
    offsets_match = offsets_match && (self_offsets_ptr[i] == other_offsets_ptr[i]);
  }
  TORCH_CHECK(
      offsets_match,
      op_name,
      " requires offsets to match when given NestedTensors");
  return std::make_pair(self_ptr, other_ptr);
}
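
// Illustrative note (not part of the upstream source): the checks above mean
// that element-wise ops between two nested tensors require identical nested
// metadata (sizes, strides, and storage offsets), e.g.
//
//   auto a = at::_nested_tensor_from_tensor_list({at::randn({2, 3}), at::randn({4, 3})});
//   auto b = at::_nested_tensor_from_tensor_list({at::randn({2, 3}), at::randn({4, 3})});
//   at::add(a, b);  // OK: same nested sizes, strides, and offsets
//
// whereas a `b` built from components of shape {3, 3} and {4, 3} would trip
// the "does not support broadcasting" check.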

template <typename Func>
Tensor NestedTensor_elementwise_Tensor(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    bool supports_striding,
    Func f) {
  Tensor self_contiguous = self;
  Tensor other_contiguous = other;
  // self is a scalar
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    return wrap_buffer(
      f(self, other_impl->get_unsafe_storage_as_tensor()),
      other_impl->get_nested_sizes().clone(),
      other_impl->get_nested_strides().clone(),
      other_impl->get_storage_offsets()
    );
  }
  // other is a scalar
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(), other),
      self_impl->get_nested_sizes().clone(),
      self_impl->get_nested_strides().clone(),
      self_impl->get_storage_offsets()
    );
  }
  // special case when other is dense (CUDA only for now)
  if (self.is_nested() && !other.is_nested() && self.is_cuda() && other.is_cuda()) {
    auto self_ptr = get_nested_tensor_impl(self);
    auto other_ = other;
    // check for the [B, *, D], [B, 1, D] case -> use custom kernel
    // TODO: this if statement is ugly and hopefully we will remove this in the near future
    bool is_broadcastable_3d = (
        self_ptr->dim() == 3 &&
        other.dim() == 3 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1 &&
        self_ptr->opt_size(2).has_value() &&
        self_ptr->opt_size(2).value() == other.size(2));
    // check for the [B, *], [B, 1] case -> treat as 3D with [B, *, 1], [B, 1, 1]
    bool is_broadcastable_2d = (
        self_ptr->dim() == 2 &&
        other.dim() == 2 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1);
    if (is_broadcastable_2d) {
        other_ = other.unsqueeze(-1);
        is_broadcastable_3d = true;
    }

    if (is_broadcastable_3d) {
      self_contiguous = self.contiguous();
      self_ptr = get_nested_tensor_impl(self_contiguous);
      const auto self_buffer = self_ptr->get_buffer();
      const auto self_sizes = self_ptr->get_nested_sizes();
      auto result_buffer = at::empty_like(self_buffer);
      auto result = wrap_buffer(result_buffer, self_sizes);
      if (op_name == "add") {
        nested_dense_elementwise_stub(self.device().type(), result, self, other_, NESTED_DENSE_OP::ADD);
      } else if (op_name == "mul") {
        nested_dense_elementwise_stub(self.device().type(), result, self, other_, NESTED_DENSE_OP::MUL);
      } else {
        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
      }
      return result;
    }

    // check for the [B, C, *, *], [C, 1, 1] case
    bool is_broadcastable_4d_3d = (
        self_ptr->dim() == 4 &&
        other.dim() == 3 &&
        self_ptr->opt_size(1).has_value() &&
        self_ptr->size(1) == other.size(0) &&
        other.size(1) == 1 &&
        other.size(2) == 1);
    if (is_broadcastable_4d_3d) {
      std::vector<Tensor> results;
      for (const auto& t : self.unbind()) {
        results.push_back(f(t, other));
      }
      return at::_nested_tensor_from_tensor_list(results);
    }

    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
        op_name,
        ".");
  }
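
  // Illustrative sketch (assumption, not in the upstream source): the branch
  // above covers, e.g., a CUDA nested self of logical shape [B, *, D] with a
  // dense [B, 1, D] (or [B, *] with [B, 1]) via the custom kernel, and a
  // [B, C, *, *] nested self with a [C, 1, 1] dense other via the unbind loop:
  //
  //   auto nt = at::_nested_tensor_from_tensor_list(
  //       {at::randn({5, 8}, at::kCUDA), at::randn({7, 8}, at::kCUDA)});  // [B=2, *, D=8]
  //   auto dense = at::randn({2, 1, 8}, at::kCUDA);                       // [B, 1, D]
  //   auto out = at::add(nt, dense);  // handled by nested_dense_elementwise_stub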

  self_contiguous = supports_striding ? self.contiguous() : self;
  other_contiguous = supports_striding ? other.contiguous() : other;

  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self_contiguous, other_contiguous, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(),
        other_impl->get_unsafe_storage_as_tensor()),
      self_impl->get_nested_sizes(),
      self_impl->get_nested_strides(),
      self_impl->get_storage_offsets());
}
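
// Usage sketch (assumption, not in the upstream source): the operators below
// are thin wrappers around the helper above; for two identically-shaped nested
// tensors the callback runs once over the underlying buffers and the result is
// re-wrapped with self's nested sizes/strides/offsets:
//
//   auto a = at::_nested_tensor_from_tensor_list({at::ones({2, 3}), at::ones({4, 3})});
//   auto b = at::_nested_tensor_from_tensor_list({at::full({2, 3}, 2.0), at::full({4, 3}, 2.0)});
//   auto c = at::mul(a, b);  // every component of c equals 2.0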

Tensor NestedTensor_add_Tensor(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise_Tensor(
      self, other, "add", true /* supports_striding*/, [alpha](const Tensor& b1, const Tensor& b2) {
        return at::add(b1, b2, alpha);
      });
}

Tensor NestedTensor_sub_Tensor(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise_Tensor(
      self, other, "sub", true /* supports_striding*/, [alpha](const Tensor& b1, const Tensor& b2) {
        return at::sub(b1, b2, alpha);
      });
}

Tensor NestedTensor_mul_Tensor(const Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise_Tensor(
      self, other, "mul", false /* supports_striding*/, [](const Tensor& b1, const Tensor& b2) {
        return at::mul(b1, b2);
      });
}

// Only usable on the C++ side; scalars are converted to tensors coming from Python.
Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) {
  return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other));
}

Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise_Tensor(
      self, other, "div", false /* supports_striding*/, [](const Tensor& b1, const Tensor& b2) {
        return at::div(b1, b2);
      });
}

// Only usable on the C++ side; scalars are converted to tensors coming from Python.
Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
  return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
}
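
// Illustrative call (assumption): from C++, passing a plain Scalar such as
//   auto halved = at::div(nt, 2.0);  // nt assumed to be a nested tensor
// reaches the Scalar overload above, which wraps 2.0 into a 0-dim tensor via
// wrapped_scalar_tensor and reuses the Tensor-Tensor path.
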
Tensor NestedTensor_masked_fill(
    const Tensor& self,
    const Tensor& mask,
    const Scalar& value) {
  return NestedTensor_elementwise_Tensor(
      self, mask, "masked_fill", false /* supports_striding*/, [value](const Tensor& b1, const Tensor& b2) {
        return at::masked_fill(b1, b2, value);
      });
}


template <typename Func>
Tensor& NestedTensor_elementwise__Tensor(
    Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    Func f) {
  // self is a scalar
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    f(self, other_impl->get_buffer());
    return self;
  }
  // other is a scalar
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    f(self_impl->get_buffer(), other);
    return self;
  }
  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self, other, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  const auto& nt_self = *self_impl;
  const auto& nt_other = *other_impl;
  f(nt_self.get_buffer().view({-1}), nt_other.get_buffer().view({-1}));
  return self;
}
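
// Sketch of the in-place path (assumption, not in the upstream source): the
// callback mutates the flattened buffer directly, so self's nested metadata
// (sizes/strides) is left untouched:
//
//   auto a = at::_nested_tensor_from_tensor_list({at::ones({2, 3}), at::ones({4, 3})});
//   auto b = at::_nested_tensor_from_tensor_list({at::ones({2, 3}), at::ones({4, 3})});
//   a.add_(b);  // routed to NestedTensor_add__Tensor below; components become 2.0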

Tensor& NestedTensor_add__Tensor(
    Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise__Tensor(
      self, other, "add_", [alpha](const Tensor& b1, const Tensor& b2) {
        return b1.add_(b2, alpha);
      });
}

Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise__Tensor(
      self, other, "mul_", [](const Tensor& b1, const Tensor& b2) {
        return b1.mul_(b2);
      });
}

// Only usable on the C++ side; scalars are converted to tensors coming from Python.
Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) {
  return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other));
}

Tensor& fill_nested_(Tensor& self, const Scalar& value) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.fill_(value);
  return self;
}

Tensor& fill_nested_(Tensor& self, const Tensor& value) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.fill_(value);
  return self;
}

Tensor ge_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "ge", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.ge(b2);
      });
}

Tensor gt_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "gt", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.gt(b2);
      });
}

Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "eq", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
  return NestedTensor_elementwise_Tensor(
      self, other, "eq", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

} // namespace at::native