#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>

namespace at::native {

DEFINE_DISPATCH(nested_dense_elementwise_stub);
REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub);

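// Shared validation for binary elementwise ops: checks that `self` and
// `other` are both nested and have identical dims, nested sizes, nested
// strides, and storage offsets (these ops do not broadcast between two
// NestedTensors), then returns the two underlying impls.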
static std::pair<NestedTensorImpl*, NestedTensorImpl*>
get_elementwise_nested_tensor_impl(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name) {
  if (self.is_nested() && !(other.is_nested())) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other");
  } else if (!(self.is_nested()) && other.is_nested()) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a non-nested self and nested other");
  } else if (!(self.is_nested()) || !(other.is_nested())) {
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a non-nested self and non-nested other");
  }

  auto self_ptr = get_nested_tensor_impl(self);
  auto other_ptr = get_nested_tensor_impl(other);

  TORCH_CHECK(
      self.dim() == other.dim(),
      op_name,
      " does not support broadcasting when given a NestedTensor");
  TORCH_CHECK(
      at::equal(
          self_ptr->get_nested_sizes(),
          other_ptr->get_nested_sizes()),
      op_name,
      " does not support broadcasting when given a NestedTensor");
  TORCH_CHECK(
      at::equal(
          self_ptr->get_nested_strides(),
          other_ptr->get_nested_strides()),
      op_name,
      " requires strides to match when given NestedTensors");
  const auto self_offsets = self_ptr->get_storage_offsets();
  int64_t* self_offsets_ptr = self_offsets.data_ptr<int64_t>();
  int64_t* other_offsets_ptr = other_ptr->get_storage_offsets().data_ptr<int64_t>();
  bool offsets_match = true;
  for (int64_t i = 0; i < self_offsets.size(0); i++) {
    offsets_match = offsets_match && (self_offsets_ptr[i] == other_offsets_ptr[i]);
  }
  TORCH_CHECK(
      offsets_match,
      op_name,
      " requires offsets to match when given NestedTensors");
  return std::make_pair(self_ptr, other_ptr);
}

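// Core implementation shared by the out-of-place binary ops below. Three
// paths: (1) one operand is a 0-dim wrapped scalar -> apply f directly to the
// other operand's storage; (2) nested self with a dense CUDA other in a
// supported broadcast layout -> custom kernel or unbind fallback; (3) both
// operands nested with matching metadata -> apply f storage-to-storage.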
template <typename Func>
Tensor NestedTensor_elementwise_Tensor(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    bool supports_striding,
    Func f) {
  Tensor self_contiguous = self;
  Tensor other_contiguous = other;
  // self is a scalar
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    return wrap_buffer(
        f(self, other_impl->get_unsafe_storage_as_tensor()),
        other_impl->get_nested_sizes().clone(),
        other_impl->get_nested_strides().clone(),
        other_impl->get_storage_offsets());
  }
  // other is a scalar
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    return wrap_buffer(
        f(self_impl->get_unsafe_storage_as_tensor(), other),
        self_impl->get_nested_sizes().clone(),
        self_impl->get_nested_strides().clone(),
        self_impl->get_storage_offsets());
  }
  // special case when other is dense (CUDA only for now)
  if (self.is_nested() && !other.is_nested() && self.is_cuda() && other.is_cuda()) {
    auto self_ptr = get_nested_tensor_impl(self);
    auto other_ = other;
    // check for the [B, *, D], [B, 1, D] case -> use custom kernel
    // TODO: this if statement is ugly and hopefully we will remove this in the near future
    bool is_broadcastable_3d = (
        self_ptr->dim() == 3 &&
        other.dim() == 3 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1 &&
        self_ptr->opt_size(2).has_value() &&
        self_ptr->opt_size(2).value() == other.size(2));
    // check for the [B, *], [B, 1] case -> treat as 3D with [B, *, 1], [B, 1, 1]
    bool is_broadcastable_2d = (
        self_ptr->dim() == 2 &&
        other.dim() == 2 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1);
    if (is_broadcastable_2d) {
      other_ = other.unsqueeze(-1);
      is_broadcastable_3d = true;
    }

    if (is_broadcastable_3d) {
      self_contiguous = self.contiguous();
      self_ptr = get_nested_tensor_impl(self_contiguous);
      const auto self_buffer = self_ptr->get_buffer();
      const auto self_sizes = self_ptr->get_nested_sizes();
      auto result_buffer = at::empty_like(self_buffer);
      auto result = wrap_buffer(result_buffer, self_sizes);
      // pass the contiguous self so the kernel's layout matches the result buffer
      if (op_name == "add") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::ADD);
      } else if (op_name == "mul") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::MUL);
      } else {
        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
      }
      return result;
    }

    // check for the [B, C, *, *], [C, 1, 1] case
    bool is_broadcastable_4d_3d = (
        self_ptr->dim() == 4 &&
        other.dim() == 3 &&
        self_ptr->opt_size(1).has_value() &&
        self_ptr->size(1) == other.size(0) &&
        other.size(1) == 1 &&
        other.size(2) == 1);
    if (is_broadcastable_4d_3d) {
      std::vector<Tensor> results;
      for (const auto& t : self.unbind()) {
        results.push_back(f(t, other));
      }
      return at::_nested_tensor_from_tensor_list(results);
    }

    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
        op_name,
        ".");
  }

  self_contiguous = supports_striding ? self.contiguous() : self;
  other_contiguous = supports_striding ? other.contiguous() : other;

  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self_contiguous, other_contiguous, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(),
        other_impl->get_unsafe_storage_as_tensor()),
      self_impl->get_nested_sizes(),
      self_impl->get_nested_strides(),
      self_impl->get_storage_offsets());
}

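// Minimal usage sketch (illustrative only; assumes the standard dispatcher
// wiring for the NestedTensor key):
//
//   at::Tensor a = at::randn({2, 3});
//   at::Tensor b = at::randn({4, 3});
//   auto nt = at::_nested_tensor_from_tensor_list({a, b});
//   auto out = at::add(nt, nt);  // routed to NestedTensor_add_Tensor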
Tensor NestedTensor_add_Tensor(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise_Tensor(
      self, other, "add", true /* supports_striding */, [alpha](const Tensor& b1, const Tensor& b2) {
        return at::add(b1, b2, alpha);
      });
}

Tensor NestedTensor_sub_Tensor(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise_Tensor(
      self, other, "sub", true /* supports_striding */, [alpha](const Tensor& b1, const Tensor& b2) {
        return at::sub(b1, b2, alpha);
      });
}

Tensor NestedTensor_mul_Tensor(const Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise_Tensor(
      self, other, "mul", false /* supports_striding */, [](const Tensor& b1, const Tensor& b2) {
        return at::mul(b1, b2);
      });
}

// Only usable on the C++ side; when coming from Python, scalars are converted to tensors.
Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) {
  return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other));
}

Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise_Tensor(
      self, other, "div", false /* supports_striding */, [](const Tensor& b1, const Tensor& b2) {
        return at::div(b1, b2);
      });
}

// Only usable on the C++ side; when coming from Python, scalars are converted to tensors.
Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
  return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
}
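
// masked_fill with a nested boolean mask: the mask must itself be a nested
// tensor whose nested sizes, strides, and offsets match `self` (no
// broadcasting between NestedTensors).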
Tensor NestedTensor_masked_fill(
    const Tensor& self,
    const Tensor& mask,
    const Scalar& value) {
  return NestedTensor_elementwise_Tensor(
      self, mask, "masked_fill", false /* supports_striding */, [value](const Tensor& b1, const Tensor& b2) {
        return at::masked_fill(b1, b2, value);
      });
}

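// In-place counterpart of NestedTensor_elementwise_Tensor: f mutates the
// underlying buffer(s) directly, so `self` keeps its nested sizes, strides,
// and offsets. The general path requires both operands' metadata to match.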
template <typename Func>
Tensor& NestedTensor_elementwise__Tensor(
    Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    Func f) {
  // self is a scalar
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    f(self, other_impl->get_buffer());
    return self;
  }
  // other is a scalar
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    f(self_impl->get_buffer(), other);
    return self;
  }
  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self, other, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  const auto& nt_self = *self_impl;
  const auto& nt_other = *other_impl;
  f(nt_self.get_buffer().view({-1}), nt_other.get_buffer().view({-1}));
  return self;
}

Tensor& NestedTensor_add__Tensor(
    Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  return NestedTensor_elementwise__Tensor(
      self, other, "add_", [alpha](const Tensor& b1, const Tensor& b2) {
        return b1.add_(b2, alpha);
      });
}

Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) {
  return NestedTensor_elementwise__Tensor(
      self, other, "mul_", [](const Tensor& b1, const Tensor& b2) {
        return b1.mul_(b2);
      });
}

// Only usable on the C++ side; when coming from Python, scalars are converted to tensors.
Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) {
  return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other));
}

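// fill_nested_ overwrites every element by filling the underlying buffer in
// place; the nested structure (sizes/strides/offsets) is left untouched.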
Tensor& fill_nested_(Tensor& self, const Scalar& value) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.fill_(value);
  return self;
}

Tensor& fill_nested_(Tensor& self, const Tensor& value) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.fill_(value);
  return self;
}

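// Scalar comparison ops reuse the elementwise helper by wrapping the scalar
// in a 0-dim tensor, which takes the "other is a scalar" fast path above.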
Tensor ge_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "ge", false /* supports_striding */,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.ge(b2);
      });
}

Tensor gt_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "gt", false /* supports_striding */,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.gt(b2);
      });
}

Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
  return NestedTensor_elementwise_Tensor(
      self, wrapped_scalar_tensor(other), "eq", false /* supports_striding */,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
  return NestedTensor_elementwise_Tensor(
      self, other, "eq", false /* supports_striding */,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

} // namespace at::native