/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

/**
 * Check whether broadcast_from_shape can be broadcasted onto
 * broadcast_to_shape.
 *
 * @param[in] broadcast_from_shape The tensor shape which we want to broadcast.
 * @param[in] broadcast_to_shape The tensor shape to which we want to broadcast.
 * @returns true if the shape can be broadcasted, false otherwise.
 */
bool tensor_is_broadcastable_to(
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape);
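
// A minimal usage sketch (illustrative only, not part of this header):
// checking whether one raw shape can be broadcast to another. The literal
// shapes below are assumptions for the example.
//
//   Tensor::SizesType from_shape[] = {1, 4};
//   Tensor::SizesType to_shape[] = {3, 2, 4};
//   bool ok = tensor_is_broadcastable_to(
//       /*broadcast_from_shape=*/{from_shape, 2},
//       /*broadcast_to_shape=*/{to_shape, 3});
//   // ok is true: {1, 4} right-aligns against {3, 2, 4}, and the size-1
//   // dim expands.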

/**
 * Check whether the broadcast_from tensor should and can be broadcasted onto
 * the broadcast_to tensor. broadcast_tensor should only be called if this
 * returns true.
 *
 * @param[in] broadcast_from The tensor which we want to broadcast from.
 * @param[in] broadcast_to The tensor to which we want to broadcast.
 * @returns true if the tensor can be broadcasted, false otherwise.
 */
bool tensor_is_broadcastable_to(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to);

/**
 * Returns true if the two tensor shapes can both be broadcasted to a common
 * shape.
 *
 * @param[in] a_shape The sizes of the first tensor to be tested.
 * @param[in] b_shape The sizes of the second tensor to be tested.
 * @returns true if the tensors are broadcastable, false otherwise.
 */
bool tensors_are_broadcastable_between(
    const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> b_shape);

/**
 * Convenience overload of the above function to accept Tensor inputs.
 *
 * @param[in] a The first tensor to be tested.
 * @param[in] b The second tensor to be tested.
 * @returns true if the tensors are broadcastable, false otherwise.
 */
bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
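
// A minimal usage sketch (illustrative only): unlike
// tensor_is_broadcastable_to, this check is symmetric, so either input may
// expand. The shapes below are assumptions for the example.
//
//   Tensor::SizesType a_shape[] = {2, 1, 4};
//   Tensor::SizesType b_shape[] = {3, 1};
//   bool ok = tensors_are_broadcastable_between({a_shape, 3}, {b_shape, 2});
//   // ok is true: {2, 1, 4} and {3, 1} broadcast together to {2, 3, 4}.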

/**
 * DEPRECATED: Use `delinearize_index()` and `linearize_access_indexes()` for
 * index remapping to avoid memory allocation.
 *
 * The smaller tensor broadcast_from is “broadcast” across the larger tensor
 * broadcast_to so that they have compatible shapes.
 * broadcast_to_shape.size() >= broadcast_from_shape.size() must hold for this
 * to work.
 *
 * @param[in] broadcast_from The tensor which we want to broadcast from.
 * @param[in] broadcast_to The tensor to which we want to broadcast.
 * @returns A new tensor with the same shape as broadcast_to and the data
 * repeated as appropriate. This tensor contains dynamically allocated memory
 * and must be freed using free_broadcast_tensor.
 */
ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
    const exec_aten::Tensor& broadcast_from,
    const exec_aten::Tensor& broadcast_to);

/**
 * Get the size of the target tensor that two input tensors would be
 * broadcasted to.
 *
 * This function is especially useful for operators that support both
 * broadcasting and dynamic shapes, where there may not yet be a tensor with
 * the final output size, so it must be calculated.
 *
 * @param[in] a_size The size of the first tensor to be broadcasted.
 * @param[in] b_size The size of the second tensor to be broadcasted.
 * @param[out] out_sizes The memory space storing the size of the broadcasted
 * target tensor.
 * @param[in] out_sizes_len The maximum number of elements that out_sizes can
 * hold.
 * @param[out] out_dim The number of dimensions of the broadcasted target
 * tensor.
 */
ET_NODISCARD Error get_broadcast_target_size(
    const exec_aten::ArrayRef<Tensor::SizesType> a_size,
    const exec_aten::ArrayRef<Tensor::SizesType> b_size,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim);
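
// A minimal usage sketch (illustrative only): computing the broadcasted output
// size for two raw shapes, e.g. when no tensor with the final size exists yet.
// The shapes below are assumptions for the example.
//
//   Tensor::SizesType a_size[] = {2, 1, 4};
//   Tensor::SizesType b_size[] = {3, 4};
//   Tensor::SizesType out_sizes[kTensorDimensionLimit];
//   size_t out_dim = 0;
//   Error err = get_broadcast_target_size(
//       {a_size, 3}, {b_size, 2}, out_sizes, kTensorDimensionLimit, &out_dim);
//   // On Error::Ok, out_sizes holds {2, 3, 4} and out_dim == 3.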

/**
 * Convenience overload of the above function to accept Tensor inputs.
 *
 * @param[in] a The first tensor to be broadcasted.
 * @param[in] b The second tensor to be broadcasted.
 * @param[out] out_sizes The memory space storing the size of the broadcasted
 * target tensor.
 * @param[in] out_sizes_len The maximum number of elements that out_sizes can
 * hold.
 * @param[out] out_dim The number of dimensions of the broadcasted target
 * tensor.
 */
ET_NODISCARD Error get_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim);

/**
 * Get the size that two input tensors will be broadcasted to, and resize an
 * output tensor to the resulting broadcasted size.
 *
 * @param[in] a The first tensor to be broadcasted.
 * @param[in] b The second tensor to be broadcasted.
 * @param[out] out The output tensor that will be resized.
 */
ET_NODISCARD inline Error
resize_to_broadcast_target_size(const Tensor& a, const Tensor& b, Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  size_t expected_output_dim = 0;

  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          a,
          b,
          expected_output_size,
          kTensorDimensionLimit,
          &expected_output_dim),
      "Failed to get broadcast target size");

  return resize_tensor(out, {expected_output_size, expected_output_dim});
}
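
// A minimal usage sketch (illustrative only) for a binary op kernel with a
// dynamically shaped output: resize out before writing to it. The surrounding
// error handling is an assumption for the example.
//
//   Error err = resize_to_broadcast_target_size(a, b, out);
//   if (err != Error::Ok) {
//     return err; // out could not be resized to the broadcasted size
//   }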

/**
 * Get the size that three input tensors will be broadcasted to, and resize an
 * output tensor to the resulting broadcasted size.
 *
 * @param[in] a The first tensor to be broadcasted.
 * @param[in] b The second tensor to be broadcasted.
 * @param[in] c The third tensor to be broadcasted.
 * @param[out] out The output tensor that will be resized.
 */
ET_NODISCARD inline Error resize_to_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    Tensor& out) {
  Tensor::SizesType interim_output_size[kTensorDimensionLimit];
  size_t interim_output_dim = 0;

  // Obtain the broadcast size of the first two input tensors
  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          a,
          b,
          interim_output_size,
          kTensorDimensionLimit,
          &interim_output_dim),
      "Failed to get broadcast target size");

  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  size_t expected_output_dim = 0;

  // Apply broadcasting to the intermediate broadcast size and the third input
  // tensor
  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          {interim_output_size, interim_output_dim},
          c.sizes(),
          expected_output_size,
          kTensorDimensionLimit,
          &expected_output_dim),
      "Failed to get broadcast target size");

  return resize_tensor(out, {expected_output_size, expected_output_dim});
}

/**
 * DEPRECATED: Use `delinearize_index()` and `linearize_access_indexes()` for
 * index remapping to avoid memory allocation.
 *
 * Free the dynamically allocated memory in broadcast_tensor. This should only
 * be used on a tensor returned by broadcast_tensor.
 *
 * @param[in] broadcast_tensor The tensor that was previously returned by a
 * call to broadcast_tensor.
 * @returns void
 */
ET_DEPRECATED void free_broadcast_tensor(
    const exec_aten::Tensor& broadcast_tensor);

/**
 * Delinearize a flattened index to per-dimension indexes.
 *
 * @param[in] linear_index The flattened index
 * @param[in] shape The tensor shape
 * @param[out] out_indexes The per-dimension indexes
 * @param[in] out_indexes_len The maximum size of the out_indexes array
 * @returns void
 */
void delinearize_index(
    size_t linear_index,
    exec_aten::ArrayRef<Tensor::SizesType> shape,
    size_t* out_indexes,
    const size_t out_indexes_len);
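
// Worked example (illustrative only): in the row-major shape {2, 3}, the
// flattened index 4 corresponds to per-dimension indexes {1, 1}, since
// 4 == 1 * 3 + 1.
//
//   Tensor::SizesType shape[] = {2, 3};
//   size_t out_indexes[kTensorDimensionLimit];
//   delinearize_index(4, {shape, 2}, out_indexes, kTensorDimensionLimit);
//   // out_indexes[0] == 1, out_indexes[1] == 1.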

/**
 * Delinearize a flattened index to per-dimension indexes.
 *
 * @param[in] linear_index The flattened index
 * @param[in] t The tensor object
 * @param[out] out_indexes The per-dimension indexes
 * @param[in] out_indexes_len The maximum size of the out_indexes array
 * @returns void
 */
void delinearize_index(
    size_t linear_index,
    const Tensor& t,
    size_t* out_indexes,
    const size_t out_indexes_len);

/**
 * Return the linear index for the broadcast_from tensor, given the indexes and
 * number of dimensions of the broadcast_to tensor, and the shape and strides
 * of the broadcast_from tensor.
 *
 * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
 * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
 * @param[in] broadcast_from_shape The shape of the broadcast_from tensor.
 * @param[in] broadcast_from_strides The strides of the broadcast_from tensor.
 * @returns The flattened index for broadcast_from tensor.
 */
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides);
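
// Worked example (illustrative only): mapping an output access index back into
// a broadcasted input. Suppose broadcast_to has shape {2, 3} and
// broadcast_from has shape {1, 3} with contiguous strides {3, 1}; the values
// are assumptions for the example.
//
//   size_t out_indexes[] = {1, 2};
//   Tensor::SizesType from_shape[] = {1, 3};
//   Tensor::StridesType from_strides[] = {3, 1};
//   size_t idx = linearize_access_indexes(
//       {out_indexes, 2},
//       /*broadcast_to_ndim=*/2,
//       {from_shape, 2},
//       {from_strides, 2});
//   // idx == 2: the index in the size-1 dim collapses to 0, giving
//   // 0 * 3 + 2 * 1.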

/**
 * Return the linear index for the broadcast_from tensor, given the access
 * indexes and number of dims of the broadcast_to tensor, and the
 * broadcast_from tensor itself.
 *
 * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
 * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
 * @param[in] broadcast_from The tensor to be broadcasted.
 * @returns The flattened index for broadcast_from tensor.
 */
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    const Tensor& broadcast_from);

//
// Mapping with broadcasting
//

/**
 * Useful for binary elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT, typename Op>
inline void apply_binary_elementwise_fn(
    const Op& compute_fun,
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < out.numel(); ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
    }

    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
  }
}
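
// A minimal usage sketch (illustrative only): a broadcasting add over float
// tensors. Real kernels select the CTYPEs by dispatching on the runtime
// dtypes; float is an assumption for the example.
//
//   apply_binary_elementwise_fn<float, float, float>(
//       [](const float val_a, const float val_b) { return val_a + val_b; },
//       a,
//       b,
//       out);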

/**
 * Useful for ternary elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_C,
    typename CTYPE_OUT,
    typename Op>
inline void apply_ternary_elementwise_fn(
    const Op& compute_fun,
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& out) {
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
  const bool any_is_broadcasted =
      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < out.numel(); ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;
    size_t c_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
      if (c_is_broadcasted) {
        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
      }
    }

    data_out[i] = compute_fun(
        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
  }
}
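
// A minimal usage sketch (illustrative only): a broadcasting fused
// multiply-add, out = a * b + c, with float as an assumed CTYPE.
//
//   apply_ternary_elementwise_fn<float, float, float, float>(
//       [](const float val_a, const float val_b, const float val_c) {
//         return val_a * val_b + val_c;
//       },
//       a,
//       b,
//       c,
//       out);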

} // namespace executor
} // namespace torch