/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

/**
 * Check whether or not the broadcast_from_shape can be broadcast onto the
 * broadcast_to_shape.
 *
 * @param[in] broadcast_from_shape The tensor shape which we want to broadcast.
 * @param[in] broadcast_to_shape The tensor shape which we want to broadcast to.
 * @returns A bool to indicate whether or not the shape can be broadcast.
 */
bool tensor_is_broadcastable_to(
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape);
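
// Illustrative usage sketch (hypothetical helper, not part of this header):
// shape {1, 3} can be broadcast to {2, 3} because the leading size-1
// dimension can be expanded, so this returns true.
inline bool example_can_broadcast_shapes() {
  const Tensor::SizesType from_sizes[] = {1, 3};
  const Tensor::SizesType to_sizes[] = {2, 3};
  return tensor_is_broadcastable_to(
      exec_aten::ArrayRef<Tensor::SizesType>(from_sizes, 2),
      exec_aten::ArrayRef<Tensor::SizesType>(to_sizes, 2));
}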

/**
 * Check whether or not the broadcast_from tensor should and can be broadcast
 * onto the broadcast_to tensor. broadcast_tensor should only be called if this
 * returns true.
 *
 * @param[in] broadcast_from The tensor which we want to broadcast from.
 * @param[in] broadcast_to The tensor which we want to broadcast to.
 * @returns A bool to indicate whether or not the tensor can be broadcast.
 */
bool tensor_is_broadcastable_to(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to);

/**
 * Returns true if the two tensor shapes can both be broadcast to a common
 * shape.
 *
 * @param[in] a_shape The sizes of the first tensor to be tested.
 * @param[in] b_shape The sizes of the second tensor to be tested.
 * @returns true if the tensors are broadcastable, false otherwise.
 */
bool tensors_are_broadcastable_between(
    const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> b_shape);

/**
 * Convenience overload of the above function to accept Tensor inputs.
 *
 * @param[in] a The first tensor to be tested.
 * @param[in] b The second tensor to be tested.
 * @returns true if the tensors are broadcastable, false otherwise.
 */
bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
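
// Illustrative usage sketch (hypothetical helper): gate an elementwise kernel
// on mutual broadcastability. For example, tensors of sizes {3, 1} and {1, 4}
// are broadcastable to a common {3, 4} shape, while {3} and {4} are not.
inline bool example_inputs_are_compatible(const Tensor& a, const Tensor& b) {
  return tensors_are_broadcastable_between(a, b);
}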

/**
 * DEPRECATED: Use `delinearize_index()` and `linearize_access_indexes()` for
 * index remapping to avoid memory allocation.
 *
 * The smaller tensor broadcast_from is “broadcast” across the larger tensor
 * broadcast_to so that they have compatible shapes.
 * broadcast_to_shape.size() >= broadcast_from_shape.size() must hold in order
 * for this to work.
 *
 * @param[in] broadcast_from The tensor which we want to broadcast from.
 * @param[in] broadcast_to The tensor which we want to broadcast to.
 * @returns A new tensor with the same shape as broadcast_to and the data
 * repeated as appropriate. This tensor contains dynamically allocated memory
 * and must be freed using free_broadcast_tensor.
 */
ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
    const exec_aten::Tensor& broadcast_from,
    const exec_aten::Tensor& broadcast_to);

/**
 * Get the size of the target tensor that two input tensors would be broadcast
 * to.
 *
 * This function is especially useful for operators that support both
 * broadcasting and dynamic shapes, where no tensor with the size of the final
 * output may exist yet, so that size must be computed.
 *
 * @param[in] a_size The size of the first tensor to be broadcast.
 * @param[in] b_size The size of the second tensor to be broadcast.
 * @param[out] out_sizes The memory space storing the size of the broadcast
 * target tensor.
 * @param[in] out_sizes_len The maximum number of elements that out_sizes can
 * hold.
 * @param[out] out_dim The dimension of the broadcast target tensor.
 */
ET_NODISCARD Error get_broadcast_target_size(
    const exec_aten::ArrayRef<Tensor::SizesType> a_size,
    const exec_aten::ArrayRef<Tensor::SizesType> b_size,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim);
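
// Illustrative usage sketch (hypothetical helper, not part of this header):
// computing the broadcast target size of {2, 1} and {1, 3} writes {2, 3} to
// out_sizes and sets out_dim to 2. The caller-provided out_sizes buffer is
// assumed to hold at least kTensorDimensionLimit entries.
inline Error example_broadcast_target_size(
    Tensor::SizesType* out_sizes,
    size_t* out_dim) {
  const Tensor::SizesType a_size[] = {2, 1};
  const Tensor::SizesType b_size[] = {1, 3};
  return get_broadcast_target_size(
      exec_aten::ArrayRef<Tensor::SizesType>(a_size, 2),
      exec_aten::ArrayRef<Tensor::SizesType>(b_size, 2),
      out_sizes,
      kTensorDimensionLimit,
      out_dim);
}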

/**
 * Convenience overload of the above function to accept Tensor inputs.
 *
 * @param[in] a The first tensor to be broadcast.
 * @param[in] b The second tensor to be broadcast.
 * @param[out] out_sizes The memory space storing the size of the broadcast
 * target tensor.
 * @param[in] out_sizes_len The maximum number of elements that out_sizes can
 * hold.
 * @param[out] out_dim The dimension of the broadcast target tensor.
 */
ET_NODISCARD Error get_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim);

/**
 * Get the size that two input tensors will be broadcast to, and resize an
 * output tensor to the resulting broadcast size.
 *
 * @param[in] a The first tensor to be broadcast.
 * @param[in] b The second tensor to be broadcast.
 * @param[out] out The output tensor that will be resized.
 */
ET_NODISCARD inline Error
resize_to_broadcast_target_size(const Tensor& a, const Tensor& b, Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  size_t expected_output_dim = 0;

  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          a,
          b,
          expected_output_size,
          kTensorDimensionLimit,
          &expected_output_dim),
      "Failed to get broadcast target size");

  return resize_tensor(out, {expected_output_size, expected_output_dim});
}
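
// Illustrative usage sketch (hypothetical helper): the typical call pattern
// inside an operator with a dynamically shaped output. Broadcasting e.g.
// {4, 1} with {1, 5} resizes `out` to {4, 5}; any failure is propagated to
// the caller.
inline Error example_prepare_binary_out(
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  ET_CHECK_OK_OR_RETURN_ERROR(
      resize_to_broadcast_target_size(a, b, out),
      "Failed to resize output to broadcast target size");
  return Error::Ok;
}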

/**
 * Get the size that three input tensors will be broadcast to, and resize an
 * output tensor to the resulting broadcast size.
 *
 * @param[in] a The first tensor to be broadcast.
 * @param[in] b The second tensor to be broadcast.
 * @param[in] c The third tensor to be broadcast.
 * @param[out] out The output tensor that will be resized.
 */
ET_NODISCARD inline Error resize_to_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    Tensor& out) {
  Tensor::SizesType interim_output_size[kTensorDimensionLimit];
  size_t interim_output_dim = 0;

  // Obtain the broadcast size of the first two input tensors
  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          a,
          b,
          interim_output_size,
          kTensorDimensionLimit,
          &interim_output_dim),
      "Failed to get broadcast target size");

  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  size_t expected_output_dim = 0;

  // Apply broadcasting to the intermediate broadcast size and the third input
  // tensor
  ET_CHECK_OK_OR_RETURN_ERROR(
      get_broadcast_target_size(
          {interim_output_size, interim_output_dim},
          c.sizes(),
          expected_output_size,
          kTensorDimensionLimit,
          &expected_output_dim),
      "Failed to get broadcast target size");

  return resize_tensor(out, {expected_output_size, expected_output_dim});
}
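
// Illustrative usage sketch (hypothetical helper): the three-tensor overload
// chains two pairwise broadcasts, so e.g. shapes {2, 1, 3}, {1, 4, 1} and
// {1, 1, 3} would resize `out` to {2, 4, 3}.
inline Error example_prepare_ternary_out(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    Tensor& out) {
  return resize_to_broadcast_target_size(a, b, c, out);
}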

/**
 * DEPRECATED: Use `delinearize_index()` and `linearize_access_indexes()` for
 * index remapping to avoid memory allocation.
 *
 * Free the dynamically allocated memory in broadcast_tensor. This should only
 * be used on a tensor returned by broadcast_tensor.
 *
 * @param[in] broadcast_tensor The tensor that was previously returned by a
 * call to broadcast_tensor.
 * @returns void
 */
ET_DEPRECATED void free_broadcast_tensor(
    const exec_aten::Tensor& broadcast_tensor);

/**
 * Delinearize a flattened index to per-dimension indexes.
 *
 * @param[in] linear_index The flattened index
 * @param[in] shape The tensor shape
 * @param[out] out_indexes The per-dimension indexes
 * @param[in] out_indexes_len The maximum size of the out_indexes array
 * @returns void
 */
void delinearize_index(
    size_t linear_index,
    exec_aten::ArrayRef<Tensor::SizesType> shape,
    size_t* out_indexes,
    const size_t out_indexes_len);

/**
 * Delinearize a flattened index to per-dimension indexes.
 *
 * @param[in] linear_index The flattened index
 * @param[in] t The tensor object
 * @param[out] out_indexes The per-dimension indexes
 * @param[in] out_indexes_len The maximum size of the out_indexes array
 * @returns void
 */
void delinearize_index(
    size_t linear_index,
    const Tensor& t,
    size_t* out_indexes,
    const size_t out_indexes_len);
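
// Illustrative usage sketch (hypothetical helper, not part of this header):
// delinearizing flat index 5 over the row-major shape {2, 3} yields the
// per-dimension indexes {1, 2}, since 5 == 1 * 3 + 2. The out_indexes buffer
// is assumed to hold at least kTensorDimensionLimit entries.
inline void example_delinearize_index(size_t* out_indexes) {
  const Tensor::SizesType shape[] = {2, 3};
  delinearize_index(
      /*linear_index=*/5,
      exec_aten::ArrayRef<Tensor::SizesType>(shape, 2),
      out_indexes,
      kTensorDimensionLimit);
}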

/**
 * Return the linear index for the broadcast_from tensor, given the indexes and
 * number of dimensions of the broadcast_to tensor, and the shape and strides
 * of the broadcast_from tensor.
 *
 * @param[in] indexes_broadcast_to The access indexes of the broadcast_to
 * tensor.
 * @param[in] broadcast_to_ndim The number of dims of the broadcast_to tensor.
 * @param[in] broadcast_from_shape The shape of the broadcast_from tensor.
 * @param[in] broadcast_from_strides The strides of the broadcast_from tensor.
 * @returns The flattened index for the broadcast_from tensor.
 */
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides);

/**
 * Return the linear index for the broadcast_from tensor, given the indexes of
 * the broadcast_to tensor and the broadcast_from tensor itself.
 *
 * @param[in] indexes_broadcast_to The access indexes of the broadcast_to
 * tensor.
 * @param[in] broadcast_to_ndim The number of dims of the broadcast_to tensor.
 * @param[in] broadcast_from The tensor to be broadcast.
 * @returns The flattened index for the broadcast_from tensor.
 */
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    const Tensor& broadcast_from);
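
// Illustrative usage sketch (hypothetical helper): map output indexes back
// into a broadcast input. If broadcast_from has contiguous shape {1, 3} and
// the output indexes are {1, 2} in a 2-D output, the size-1 dimension is
// pinned to index 0, so the returned linear index is 0 * 3 + 2 == 2.
inline size_t example_linearize_access_indexes(const Tensor& broadcast_from) {
  const size_t indexes_broadcast_to[2] = {1, 2};
  return linearize_access_indexes(
      ArrayRef<size_t>(indexes_broadcast_to, 2),
      /*broadcast_to_ndim=*/2,
      broadcast_from);
}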

//
// Mapping with broadcasting
//

/**
 * Useful for binary elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT, typename Op>
inline void apply_binary_elementwise_fn(
    const Op& compute_fun,
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < out.numel(); ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
    }

    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
  }
}
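
// Illustrative usage sketch (hypothetical kernel fragment): elementwise
// addition over float tensors, with a and b broadcast to the shape of out as
// needed. out is assumed to already have the broadcast target size (see
// resize_to_broadcast_target_size above).
inline void example_broadcast_add(
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  apply_binary_elementwise_fn<float, float, float>(
      [](const float val_a, const float val_b) { return val_a + val_b; },
      a,
      b,
      out);
}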

/**
 * Useful for ternary elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 */
template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_C,
    typename CTYPE_OUT,
    typename Op>
inline void apply_ternary_elementwise_fn(
    const Op& compute_fun,
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& out) {
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
  const bool any_is_broadcasted =
      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

  for (size_t i = 0; i < out.numel(); ++i) {
    size_t a_linear_index = i;
    size_t b_linear_index = i;
    size_t c_linear_index = i;

    if (any_is_broadcasted) {
      size_t out_indexes[kTensorDimensionLimit];
      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

      if (a_is_broadcasted) {
        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
      }
      if (b_is_broadcasted) {
        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
      }
      if (c_is_broadcasted) {
        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
      }
    }

    data_out[i] = compute_fun(
        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
  }
}
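
// Illustrative usage sketch (hypothetical kernel fragment): a fused
// multiply-add out = a + b * c over float tensors, with each input broadcast
// to the shape of out as needed.
inline void example_broadcast_addcmul(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& out) {
  apply_ternary_elementwise_fn<float, float, float, float>(
      [](const float val_a, const float val_b, const float val_c) {
        return val_a + val_b * val_c;
      },
      a,
      b,
      c,
      out);
}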

} // namespace executor
} // namespace torch