/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <array> // std::array
#include <cinttypes> // PRId64
#include <cmath>
#include <cstddef> // size_t
#include <limits>

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/compiler.h>

/// All assertion messages should begin with this prefix.
#define ET_TENSOR_CHECK_PREFIX__ "Tensors do not match"
#define ET_MIN2(a, b) (std::min(a, b))
#define ET_MIN3(a, b, c) (std::min(a, std::min(b, c)))

#define ET_NORMALIZE_IX(IX, UPPER_BOUND) IX < 0 ? IX + UPPER_BOUND : IX

#define ET_CHECK_VALID_IX(IX, UPPER_BOUND)                  \
  ET_CHECK_MSG(                                             \
      IX >= -static_cast<int64_t>(UPPER_BOUND) &&           \
          IX < static_cast<int64_t>(UPPER_BOUND),           \
      "index %" PRId64 " must be within range [-%zd, %zd)", \
      IX,                                                   \
      UPPER_BOUND,                                          \
      UPPER_BOUND)

#define ET_CHECK_VALID_DIM(DIM, UPPER_BOUND)              \
  ET_CHECK_MSG(                                           \
      DIM >= -static_cast<int64_t>(UPPER_BOUND) &&        \
          DIM < static_cast<int64_t>(UPPER_BOUND),        \
      "dim %" PRId64 " must be within range [-%zd, %zd)", \
      DIM,                                                \
      UPPER_BOUND,                                        \
      UPPER_BOUND)

#define ET_CHECK_NON_ZERO_DIM_SIZE(DIM, T)           \
  const size_t udim = ET_NORMALIZE_IX(DIM, T.dim()); \
  ET_CHECK_MSG(                                      \
      T.size(udim) != 0, "Expected dim %zd to have non-zero size.", udim);

/**
 * Asserts that all tensors have the same shape.
 *
 * This also handles the edge case where every tensor being compared has
 * exactly one element but the tensors differ in their number of dimensions.
 * The loop over the dimensions uses the smallest dimensionality of all the
 * tensors as its upper bound.
 */
#define ET_CHECK_SAME_SHAPE2(a__, b__)                                    \
  ({                                                                      \
    const size_t a_numel__ = (a__).numel();                               \
    const size_t b_numel__ = (b__).numel();                               \
    const size_t a_dim__ = (a__).dim();                                   \
    const size_t b_dim__ = (b__).dim();                                   \
    ET_CHECK_MSG(                                                         \
        a_numel__ == b_numel__ &&                                         \
            ((a_numel__ == 1 && b_numel__ == 1) || (a_dim__ == b_dim__)), \
        ET_TENSOR_CHECK_PREFIX__ ": numel={%zu, %zu}, dim={%zu, %zu}",    \
        a_numel__,                                                        \
        b_numel__,                                                        \
        a_dim__,                                                          \
        b_dim__);                                                         \
    for (size_t dim__ = 0; dim__ < ET_MIN2(a_dim__, b_dim__); ++dim__) {  \
      size_t a_size__ = (a__).size(dim__);                                \
      size_t b_size__ = (b__).size(dim__);                                \
      ET_CHECK_MSG(                                                       \
          a_size__ == b_size__,                                           \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}",           \
          dim__,                                                          \
          a_size__,                                                       \
          b_size__);                                                      \
    }                                                                     \
  })

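/*
 * Example: a minimal usage sketch of the shape check (the tensor names are
 * hypothetical):
 *
 *   void check_binary_inputs(
 *       const executorch::aten::Tensor& a,
 *       const executorch::aten::Tensor& b) {
 *     // Aborts with a detailed message if the shapes differ. Tensors with a
 *     // single element always compare equal, regardless of dim().
 *     ET_CHECK_SAME_SHAPE2(a, b);
 *   }
 */
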
#define ET_CHECK_SAME_SHAPE3(a__, b__, c__)                            \
  ({                                                                   \
    const size_t a_numel__ = (a__).numel();                            \
    const size_t b_numel__ = (b__).numel();                            \
    const size_t c_numel__ = (c__).numel();                            \
    const size_t a_dim__ = (a__).dim();                                \
    const size_t b_dim__ = (b__).dim();                                \
    const size_t c_dim__ = (c__).dim();                                \
    ET_CHECK_MSG(                                                      \
        a_numel__ == b_numel__ && b_numel__ == c_numel__ &&            \
            ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) ||   \
             (a_dim__ == b_dim__ && b_dim__ == c_dim__)),              \
        ET_TENSOR_CHECK_PREFIX__                                       \
        ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}",                \
        a_numel__,                                                     \
        b_numel__,                                                     \
        c_numel__,                                                     \
        a_dim__,                                                       \
        b_dim__,                                                       \
        c_dim__);                                                      \
    for (size_t dim__ = 0; dim__ < ET_MIN3(a_dim__, b_dim__, c_dim__); \
         ++dim__) {                                                    \
      size_t a_size__ = (a__).size(dim__);                             \
      size_t b_size__ = (b__).size(dim__);                             \
      size_t c_size__ = (c__).size(dim__);                             \
      ET_CHECK_MSG(                                                    \
          a_size__ == b_size__ && b_size__ == c_size__,                \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}",   \
          dim__,                                                       \
          a_size__,                                                    \
          b_size__,                                                    \
          c_size__);                                                   \
    }                                                                  \
  })

/// Asserts that all tensors have the same dtype.
#define ET_CHECK_SAME_DTYPE2(a__, b__)                                   \
  ({                                                                     \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    ET_CHECK_MSG(                                                        \
        a_type__ == b_type__,                                            \
        ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 "}",     \
        static_cast<int8_t>(a_type__),                                   \
        static_cast<int8_t>(b_type__));                                  \
  })

#define ET_CHECK_SAME_DTYPE3(a__, b__, c__)                                 \
  ({                                                                        \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type();    \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type();    \
    const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type();    \
    ET_CHECK_MSG(                                                           \
        a_type__ == b_type__ && b_type__ == c_type__,                       \
        ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 ", %" PRId8 \
                                 "}",                                       \
        static_cast<int8_t>(a_type__),                                      \
        static_cast<int8_t>(b_type__),                                      \
        static_cast<int8_t>(c_type__));                                     \
  })

/**
 * Asserts that all tensors have the same shape and dtype.
 *
 * This macro should produce less code/data than calling the SHAPE and DTYPE
 * macros independently, because it only calls ET_CHECK_MSG once.
 */
#define ET_CHECK_SAME_SHAPE_AND_DTYPE2(a__, b__)                              \
  ({                                                                          \
    const size_t a_numel__ = (a__).numel();                                   \
    const size_t b_numel__ = (b__).numel();                                   \
    const size_t a_dim__ = (a__).dim();                                       \
    const size_t b_dim__ = (b__).dim();                                       \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type();      \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type();      \
                                                                              \
    ET_CHECK_MSG(                                                             \
        a_numel__ == b_numel__ &&                                             \
            ((a_numel__ == 1 && b_numel__ == 1) || a_dim__ == b_dim__) &&     \
            a_type__ == b_type__,                                             \
        ET_TENSOR_CHECK_PREFIX__                                              \
        ": numel={%zu, %zu}, dim={%zu, %zu}, dtype={%" PRId8 ", %" PRId8 "}", \
        a_numel__,                                                            \
        b_numel__,                                                            \
        a_dim__,                                                              \
        b_dim__,                                                              \
        static_cast<int8_t>(a_type__),                                        \
        static_cast<int8_t>(b_type__));                                       \
    for (size_t dim__ = 0; dim__ < ET_MIN2(a_dim__, b_dim__); ++dim__) {      \
      size_t a_size__ = (a__).size(dim__);                                    \
      size_t b_size__ = (b__).size(dim__);                                    \
      ET_CHECK_MSG(                                                           \
          a_size__ == b_size__,                                               \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}",               \
          dim__,                                                              \
          a_size__,                                                           \
          b_size__);                                                          \
    }                                                                         \
  })

#define ET_CHECK_SAME_SHAPE_AND_DTYPE3(a__, b__, c__)                    \
  ({                                                                     \
    const size_t a_numel__ = (a__).numel();                              \
    const size_t b_numel__ = (b__).numel();                              \
    const size_t c_numel__ = (c__).numel();                              \
    const size_t a_dim__ = (a__).dim();                                  \
    const size_t b_dim__ = (b__).dim();                                  \
    const size_t c_dim__ = (c__).dim();                                  \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type(); \
                                                                         \
    ET_CHECK_MSG(                                                        \
        a_numel__ == b_numel__ && b_numel__ == c_numel__ &&              \
            ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) ||     \
             (a_dim__ == b_dim__ && b_dim__ == c_dim__)) &&              \
            a_type__ == b_type__ && b_type__ == c_type__,                \
        ET_TENSOR_CHECK_PREFIX__                                         \
        ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}, "                 \
        "dtype={%" PRId8 ", %" PRId8 ", %" PRId8 "}",                    \
        a_numel__,                                                       \
        b_numel__,                                                       \
        c_numel__,                                                       \
        a_dim__,                                                         \
        b_dim__,                                                         \
        c_dim__,                                                         \
        static_cast<int8_t>(a_type__),                                   \
        static_cast<int8_t>(b_type__),                                   \
        static_cast<int8_t>(c_type__));                                  \
    for (size_t dim__ = 0; dim__ < ET_MIN3(a_dim__, b_dim__, c_dim__);   \
         ++dim__) {                                                      \
      size_t a_size__ = (a__).size(dim__);                               \
      size_t b_size__ = (b__).size(dim__);                               \
      size_t c_size__ = (c__).size(dim__);                               \
      ET_CHECK_MSG(                                                      \
          a_size__ == b_size__ && b_size__ == c_size__,                  \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}",     \
          dim__,                                                         \
          a_size__,                                                      \
          b_size__,                                                      \
          c_size__);                                                     \
    }                                                                    \
  })

/**
 * Asserts that the input tensor is a contiguous tensor.
 */
#define ET_CHECK_CONTIGUOUS(a__)                                              \
  ({                                                                          \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>         \
        strides = a__.strides();                                              \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType> sizes = \
        a__.sizes();                                                          \
    ET_CHECK_MSG(                                                             \
        strides[strides.size() - 1] == 1,                                     \
        "The stride of the last dimension must be 1 for a contiguous "       \
        "tensor, not %d",                                                     \
        strides[strides.size() - 1]);                                         \
    for (size_t i = strides.size() - 1; i > 0; i--) {                         \
      ET_CHECK_MSG(                                                           \
          strides[i - 1] == strides[i] * sizes[i],                            \
          "The stride of the %zu-th dimension must equal "                    \
          "strides[%zu] * sizes[%zu], but got %d and %d",                     \
          i - 1,                                                              \
          i,                                                                  \
          i,                                                                  \
          strides[i - 1],                                                     \
          strides[i] * sizes[i]);                                             \
    }                                                                         \
  })

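/*
 * Example: for a contiguous tensor, each stride is the product of all sizes
 * to its right. A hypothetical tensor of sizes {2, 3, 4} has strides
 * {12, 4, 1}:
 *
 *   strides[2] == 1
 *   strides[1] == strides[2] * sizes[2] == 4
 *   strides[0] == strides[1] * sizes[1] == 12
 *
 * ET_CHECK_CONTIGUOUS verifies exactly this recurrence, from the last
 * dimension backwards.
 */
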
/**
 * Asserts that the two input tensors share the same strides.
 * Note that this macro does not make any check or promise about the
 * contiguity of the input tensors.
 */
#define ET_CHECK_SAME_STRIDES2(a__, b__)                                      \
  ({                                                                          \
    ET_CHECK_MSG(                                                             \
        a__.dim() == b__.dim(),                                               \
        "Two tensors must have the same number of strides, but got %zu "     \
        "and %zu.",                                                           \
        a__.dim(),                                                            \
        b__.dim());                                                           \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>         \
        a_strides = a__.strides();                                            \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>         \
        b_strides = b__.strides();                                            \
    for (size_t i = 0; i < a__.dim(); i++) {                                  \
      ET_CHECK_MSG(                                                           \
          a_strides[i] == b_strides[i],                                       \
          "a.strides()[%zu] must equal b.strides()[%zu], "                    \
          "but got %d and %d.",                                               \
          i,                                                                  \
          i,                                                                  \
          (int32_t)a_strides[i],                                              \
          (int32_t)b_strides[i]);                                             \
    }                                                                         \
  })

/**
 * Asserts that the three input tensors share the same strides.
 * Note that this macro does not make any check or promise about the
 * contiguity of the input tensors.
 */
#define ET_CHECK_SAME_STRIDES3(a__, b__, c__)                           \
  ({                                                                    \
    ET_CHECK_MSG(                                                       \
        a__.dim() == b__.dim() && b__.dim() == c__.dim(),               \
        "Three tensors must have the same number of strides, "          \
        "but got %zu, %zu and %zu.",                                    \
        a__.dim(),                                                      \
        b__.dim(),                                                      \
        c__.dim());                                                     \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>   \
        a_strides = a__.strides();                                      \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>   \
        b_strides = b__.strides();                                      \
    const ::executorch::aten::ArrayRef<executorch::aten::StridesType>   \
        c_strides = c__.strides();                                      \
    for (size_t i = 0; i < a__.dim(); i++) {                            \
      ET_CHECK_MSG(                                                     \
          a_strides[i] == b_strides[i] && b_strides[i] == c_strides[i], \
          "a_strides[%zu], b_strides[%zu] and c_strides[%zu] "          \
          "must share the same value, but got %d, %d and %d",           \
          i,                                                            \
          i,                                                            \
          i,                                                            \
          (int32_t)a_strides[i],                                        \
          (int32_t)b_strides[i],                                        \
          (int32_t)c_strides[i]);                                       \
    }                                                                   \
  })

#define ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(t__)           \
  ({                                                             \
    ET_CHECK_MSG(                                                \
        is_contiguous_dim_order(                                 \
            t__.dim_order().data(), t__.dim_order().size()) ||   \
            is_channels_last_dim_order(                          \
                t__.dim_order().data(), t__.dim_order().size()), \
        "Tensor must have default or channels last dim order");  \
  })

/**
 * A convenience macro to be used in utility functions that check whether
 * input tensor(s) are valid, and that are expected to return a boolean.
 * Checks whether `cond` is true; if not, logs the failed check and returns
 * false.
 *
 * @param[in] cond the condition to check
 */
#define ET_LOG_AND_RETURN_IF_FALSE(cond)           \
  do {                                             \
    if (!(cond)) {                                 \
      ET_LOG(Error, "Check failed (%s): ", #cond); \
      return false;                                \
    }                                              \
  } while (false)

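/*
 * Example: a sketch of the intended pattern for boolean validation helpers
 * (the helper name is hypothetical):
 *
 *   inline bool tensor_is_nonempty(executorch::aten::Tensor t) {
 *     ET_LOG_AND_RETURN_IF_FALSE(t.numel() > 0);
 *     return true;
 *   }
 *
 * On failure this logs "Check failed (t.numel() > 0): " and returns false
 * instead of aborting, so callers can decide how to react.
 */
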
/**
 * A convenience macro to be used in utility functions that check whether
 * input tensor(s) are valid, and that are expected to return a boolean.
 * Checks whether `cond` is true; if not, logs the failed check with
 * `message` and returns false.
 *
 * @param[in] cond the condition to check
 * @param[in] message an additional message to log with `cond`
 */
#define ET_LOG_MSG_AND_RETURN_IF_FALSE(cond, message, ...)                \
  do {                                                                    \
    if (!(cond)) {                                                        \
      ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \
      return false;                                                       \
    }                                                                     \
  } while (false)

/**
 * If `cond` is false, logs `cond` and returns from the kernel with a failure
 * state set.
 *
 * @param[in] context the runtime context
 * @param[in] cond the condition to check
 * @param[in] error torch::executor::Error enum value (e.g., `InvalidArgument`)
 * @param[in] retval return value of the kernel to allow for early exit
 */
#define ET_KERNEL_CHECK(context, cond, error, retval) \
  do {                                                \
    if (!(cond)) {                                    \
      ET_LOG(Error, "Check failed (%s): ", #cond);    \
      context.fail(torch::executor::Error::error);    \
      return retval;                                  \
    }                                                 \
  } while (false)

/**
 * If `cond` is false, logs `message` and returns from the kernel with a
 * failure state set.
 *
 * @param[in] context the runtime context
 * @param[in] cond the condition to check
 * @param[in] error torch::executor::Error enum value (e.g., `InvalidArgument`)
 * @param[in] retval return value of the kernel to allow for early exit
 */
#define ET_KERNEL_CHECK_MSG(context, cond, error, retval, message, ...)   \
  do {                                                                    \
    if (!(cond)) {                                                        \
      ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \
      context.fail(torch::executor::Error::error);                        \
      return retval;                                                      \
    }                                                                     \
  } while (false)

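/*
 * Example: a sketch of how a kernel might use ET_KERNEL_CHECK to fail
 * non-fatally. The kernel signature, context type, and argument names here
 * are hypothetical:
 *
 *   Tensor& my_op_out(
 *       KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
 *     ET_KERNEL_CHECK(
 *         ctx,
 *         tensors_have_same_shape_and_dtype(in, out),
 *         InvalidArgument,
 *         out);
 *     // ... kernel body ...
 *     return out;
 *   }
 *
 * On failure the context records Error::InvalidArgument and the kernel
 * returns `out` immediately.
 */
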
/**
 * Convenience macro to extract a scalar tensor into a value.
 */
#define ET_EXTRACT_SCALAR_TENSOR(scalar_tensor, out_val) \
  ET_CHECK_MSG(                                          \
      extract_scalar_tensor(scalar_tensor, &out_val),    \
      #scalar_tensor " could not be extracted: wrong type or out of range");

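/*
 * Example: extracting an int64_t from a scalar tensor (variable names are
 * hypothetical):
 *
 *   int64_t value = 0;
 *   ET_EXTRACT_SCALAR_TENSOR(scalar_t, value);
 *   // `value` now holds the scalar; a wrong dtype or an out-of-range value
 *   // would have aborted with a message naming `scalar_t`.
 */
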
namespace executorch {
namespace runtime {

//
// Utility functions for checking tensor attributes
//

/*
 * Returns true if the given dimension value is between -upper_bound and
 * upper_bound - 1, inclusive.
 */
inline bool dim_is_valid(int64_t dim, int64_t upper_bound) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      dim >= -upper_bound && dim < upper_bound,
      "Dimension %" PRId64
      " is out of range. Dimension should be between %" PRId64 " and %" PRId64
      ", inclusive.",
      dim,
      -upper_bound,
      upper_bound - 1);

  return true;
}

/*
 * Returns the tensor's number of dimensions, except when the tensor is
 * zero-dimensional, in which case it returns 1. This is used to properly
 * handle zero-dimensional tensors in kernels that treat them as 1-D tensors
 * with a single element.
 */
inline ssize_t nonzero_dim(const executorch::aten::Tensor& tensor) {
  return tensor.dim() == 0 ? 1 : tensor.dim();
}

/*
 * Returns the size along a dimension `dim`, except when the tensor is
 * zero-dimensional, in which case it returns 1. This is used to properly
 * handle zero-dimensional tensors in kernels that treat them as 1-D tensors
 * with a single element.
 */
inline ssize_t nonempty_size(
    const executorch::aten::Tensor& tensor,
    ssize_t dim) {
  return tensor.dim() == 0 ? 1 : tensor.size(dim);
}

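/*
 * Example: iterating over a tensor while treating a zero-dimensional tensor
 * as if it were 1-D with one element (an illustrative sketch):
 *
 *   for (ssize_t d = 0; d < nonzero_dim(t); ++d) {
 *     ssize_t size = nonempty_size(t, d); // 1 when t.dim() == 0
 *     // ... use size ...
 *   }
 */
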
inline bool tensor_can_cast_to(
    executorch::aten::Tensor a,
    executorch::aten::ScalarType dtype) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::canCast(a.scalar_type(), dtype),
      "Tensor of dtype %s cannot cast to dtype %s",
      torch::executor::toString(a.scalar_type()),
      torch::executor::toString(dtype));

  return true;
}

inline bool tensor_is_bool_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      t.scalar_type() == executorch::aten::ScalarType::Bool,
      "Expected to find bool type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_type(
    executorch::aten::Tensor t,
    executorch::aten::ScalarType dtype) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      t.scalar_type() == dtype,
      "Expected to find %s type, but tensor has type %s",
      torch::executor::toString(dtype),
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_integral_type(
    executorch::aten::Tensor t,
    bool includeBool = false) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isIntegralType(t.scalar_type(), includeBool),
      "Expected to find an integral type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_floating_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isFloatingType(t.scalar_type()),
      "Expected to find a floating type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_real_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isRealType(t.scalar_type()),
      "Expected to find a real type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_realh_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isRealHType(t.scalar_type()),
      "Expected to find a real type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_realhbf16_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      executorch::runtime::isRealHBF16Type(t.scalar_type()),
      "Expected to find a real type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_realhb_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isRealHBType(t.scalar_type()),
      "Expected to find a real type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_realhbbf16_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      executorch::runtime::isRealHBBF16Type(t.scalar_type()),
      "Expected to find a real type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_complex_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isComplexType(t.scalar_type()),
      "Expected to find a complex type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensor_is_bits_type(executorch::aten::Tensor t) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      torch::executor::isBitsType(t.scalar_type()),
      "Expected to find a bits type, but tensor has type %s",
      torch::executor::toString(t.scalar_type()));

  return true;
}

inline bool tensors_have_same_dtype(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      a.scalar_type() == b.scalar_type(),
      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s}",
      torch::executor::toString(a.scalar_type()),
      torch::executor::toString(b.scalar_type()));
  return true;
}

inline bool tensors_have_same_dtype(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b,
    executorch::aten::Tensor c) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(),
      ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s, %s}",
      torch::executor::toString(a.scalar_type()),
      torch::executor::toString(b.scalar_type()),
      torch::executor::toString(c.scalar_type()));
  return true;
}

inline bool tensor_is_rank(executorch::aten::Tensor t, size_t rank) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      t.dim() == rank,
      "Expected tensor.dim() to be %zu, but got %zu",
      static_cast<size_t>(rank),
      static_cast<size_t>(t.dim()));

  return true;
}

inline bool tensor_has_rank_greater_or_equal_to(
    executorch::aten::Tensor t,
    size_t rank) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      t.dim() >= rank,
      "Expected tensor.dim() to be >= %zu, but got %zu",
      static_cast<size_t>(rank),
      static_cast<size_t>(t.dim()));

  return true;
}

inline bool tensor_has_rank_smaller_or_equal_to(
    executorch::aten::Tensor t,
    size_t rank) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      t.dim() <= rank,
      "Expected tensor.dim() to be <= %zu, but got %zu",
      static_cast<size_t>(rank),
      static_cast<size_t>(t.dim()));

  return true;
}

inline bool tensor_has_dim(executorch::aten::Tensor t, int64_t d) {
  if (t.dim() == 0) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        d == 0 || d == -1,
        "dim must be 0 or -1 for 0-dim tensor, got %" PRId64,
        d);
  } else {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        d > 0 ? d < t.dim() : t.dim() + d >= 0,
        "%zu-dim tensor does not have dim at index %zu",
        static_cast<size_t>(t.dim()),
        static_cast<size_t>(d));
  }
  return true;
}

inline bool tensor_has_non_empty_dim(executorch::aten::Tensor t, int64_t d) {
  const size_t udim = ET_NORMALIZE_IX(d, t.dim());
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(t, d));
  ET_LOG_AND_RETURN_IF_FALSE(t.size(udim) != 0);
  return true;
}

inline bool
tensor_dim_has_index(executorch::aten::Tensor t, int64_t d, int64_t ix) {
  // Indexing ops don't support zero-dim tensors
  ET_CHECK(t.dim() != 0);
  if (d < 0) {
    d += t.dim();
  }
  // Dimension must have been already checked by tensor_has_dim
  ET_CHECK(d >= 0 && d < t.dim());

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      ix >= -t.size(d) && ix < t.size(d),
      "index %" PRId64 " out of range [-%zu, %zu) at dimension %" PRId64,
      ix,
      static_cast<size_t>(t.size(d)),
      static_cast<size_t>(t.size(d)),
      d);
  return true;
}

inline bool tensors_have_same_size_at_dims(
    executorch::aten::Tensor a,
    size_t dim_a,
    executorch::aten::Tensor b,
    size_t dim_b) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      dim_a < a.dim(),
      "Cannot retrieve dim %zu from tensor with dim %zu",
      static_cast<size_t>(dim_a),
      static_cast<size_t>(a.dim()));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      dim_b < b.dim(),
      "Cannot retrieve dim %zu from tensor with dim %zu",
      static_cast<size_t>(dim_b),
      static_cast<size_t>(b.dim()));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      a.size(dim_a) == b.size(dim_b),
      ET_TENSOR_CHECK_PREFIX__
      ": a.size(%zu) = %zu does not match b.size(%zu) = %zu",
      static_cast<size_t>(dim_a),
      static_cast<size_t>(a.size(dim_a)),
      static_cast<size_t>(dim_b),
      static_cast<size_t>(b.size(dim_b)));

  return true;
}

inline bool tensors_have_same_shape(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b) {
  if (a.numel() == 1 && b.numel() == 1) {
    // PyTorch operators treat all scalar tensors as the same shape even if
    // they have different dims.
    return true;
  }
  if (!(a.sizes() == b.sizes() && a.numel() == b.numel())) {
    ET_LOG(
        Error,
        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu), dim=(%zu, %zu)",
        static_cast<size_t>(a.numel()),
        static_cast<size_t>(b.numel()),
        static_cast<size_t>(a.dim()),
        static_cast<size_t>(b.dim()));
    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
      ET_LOG(
          Error,
          "    size(%zu): (%zu, %zu)",
          static_cast<size_t>(d),
          static_cast<size_t>(a.size(d)),
          static_cast<size_t>(b.size(d)));
    }

    return false;
  }

  return true;
}

inline bool tensors_have_same_shape(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b,
    executorch::aten::Tensor c) {
  if (a.numel() == 1 && b.numel() == 1 && c.numel() == 1) {
    // PyTorch operators treat all scalar tensors as the same shape even if
    // they have different dims.
    return true;
  }
  bool cond1 = (a.sizes() == b.sizes()) && (a.numel() == b.numel());
  bool cond2 = (b.sizes() == c.sizes()) && (b.numel() == c.numel());

  if (!(cond1 && cond2)) {
    ET_LOG(
        Error,
        ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu, %zu), dim=(%zu, %zu, %zu)",
        static_cast<size_t>(a.numel()),
        static_cast<size_t>(b.numel()),
        static_cast<size_t>(c.numel()),
        static_cast<size_t>(a.dim()),
        static_cast<size_t>(b.dim()),
        static_cast<size_t>(c.dim()));
    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
      ET_LOG(
          Error,
          "    size(%zu): (%zu, %zu, %zu)",
          static_cast<size_t>(d),
          static_cast<size_t>(a.size(d)),
          static_cast<size_t>(b.size(d)),
          static_cast<size_t>(c.size(d)));
    }

    return false;
  }

  return true;
}

inline bool tensors_have_same_shape_and_dtype(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b) {
  return tensors_have_same_shape(a, b) && tensors_have_same_dtype(a, b);
}

inline bool tensors_have_same_shape_and_dtype(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b,
    executorch::aten::Tensor c) {
  return tensors_have_same_shape(a, b, c) && tensors_have_same_dtype(a, b, c);
}

inline bool tensor_has_expected_size(
    executorch::aten::Tensor a,
    executorch::aten::ArrayRef<executorch::aten::SizesType> expected_sizes) {
  if (!(a.sizes() == expected_sizes)) {
    ET_LOG(
        Error,
        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
        static_cast<size_t>(a.dim()),
        static_cast<size_t>(expected_sizes.size()));
    size_t a_dim = static_cast<size_t>(a.dim());
    size_t expected_dim = static_cast<size_t>(expected_sizes.size());
    for (size_t d = 0; d < ET_MIN2(a_dim, expected_dim); ++d) {
      ET_LOG(
          Error,
          "    size(%zu): (%zu, %zu)",
          static_cast<size_t>(d),
          static_cast<size_t>(a.size(d)),
          static_cast<size_t>(expected_sizes[d]));
    }

    return false;
  }
  return true;
}

inline bool tensors_have_same_strides(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b) {
  if (a.strides() != b.strides()) {
    ET_LOG(
        Error,
        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
        static_cast<size_t>(a.dim()),
        static_cast<size_t>(b.dim()));
    for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
      ET_LOG(
          Error,
          "    stride(%zu): (%zu, %zu)",
          static_cast<size_t>(d),
          static_cast<size_t>(a.strides()[d]),
          static_cast<size_t>(b.strides()[d]));
    }

    return false;
  }
  return true;
}

inline bool tensors_have_same_strides(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b,
    executorch::aten::Tensor c) {
  if (!(a.strides() == b.strides() && b.strides() == c.strides())) {
    ET_LOG(
        Error,
        ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu, %zu)",
        static_cast<size_t>(a.dim()),
        static_cast<size_t>(b.dim()),
        static_cast<size_t>(c.dim()));
    for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
      ET_LOG(
          Error,
          "    stride(%zu): (%zu, %zu, %zu)",
          static_cast<size_t>(d),
          static_cast<size_t>(a.strides()[d]),
          static_cast<size_t>(b.strides()[d]),
          static_cast<size_t>(c.strides()[d]));
    }

    return false;
  }
  return true;
}

inline bool tensor_is_contiguous(executorch::aten::Tensor t) {
  const auto strides = t.strides();
  const auto sizes = t.sizes();
  // If tensor is 0-dim (i.e. a scalar tensor) it is contiguous
  if (strides.size() == 0) {
    return true;
  }
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      strides[strides.size() - 1] == 1,
      "Tensor is not contiguous; the stride of the last dimension must be 1, "
      "but got %zu",
      static_cast<size_t>(strides[strides.size() - 1]));
  for (int i = strides.size() - 1; i > 0; --i) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        strides[i - 1] == strides[i] * sizes[i],
        "Tensor is not contiguous; the stride of dim %zu should be equal to "
        "strides[%zu] * sizes[%zu] = %zu, but found %zu",
        static_cast<size_t>(i - 1),
        static_cast<size_t>(i),
        static_cast<size_t>(i),
        static_cast<size_t>(strides[i] * sizes[i]),
        static_cast<size_t>(strides[i - 1]));
  }
  return true;
}

inline bool tensors_have_same_rank(
    executorch::aten::Tensor a,
    executorch::aten::Tensor b) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      a.dim() == b.dim(),
      ET_TENSOR_CHECK_PREFIX__ ": rank={%zd, %zd}",
      ssize_t(a.dim()),
      ssize_t(b.dim()));
  return true;
}

inline bool tensor_is_scalar(executorch::aten::Tensor t) {
  return t.dim() == 0 && t.numel() == 1;
}

/**
 * When an operator supports both broadcast and dynamic shape, the expected
 * output size may not match the existing size of any of its inputs or
 * outputs, so such operators need extra space to store the calculated
 * expected output size. Dynamic allocation is troublesome in ExecuTorch, so
 * we hard-code a relatively small static limit, since users don't create
 * high-dimensional tensors.
 */
constexpr size_t kTensorDimensionLimit = 16;

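/*
 * Example: kTensorDimensionLimit is typically used to size stack buffers for
 * computed shapes, avoiding dynamic allocation (an illustrative sketch):
 *
 *   executorch::aten::SizesType expected_size[kTensorDimensionLimit];
 *   // ... fill expected_size[0..expected_dim) with the computed shape ...
 */
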
/// Returns the product of sizes[0:dim), not including dim itself.
inline size_t getLeadingDims(
    const executorch::aten::Tensor& tensor,
    int64_t dim) {
  ET_CHECK_MSG(
      dim >= 0 && dim <= tensor.dim(),
      "Ending dimension %" PRId64
      " should be in the range [0, tensor.dim() = %zd].",
      dim,
      ssize_t(tensor.dim()));
  size_t dims = 1;
  for (size_t i = 0; i < dim; ++i) {
    dims *= static_cast<size_t>(tensor.size(i));
  }
  return dims;
}

/// Returns the product of sizes[dim+1:].
inline size_t getTrailingDims(
    const executorch::aten::Tensor& tensor,
    int64_t dim) {
  ET_CHECK_MSG(
      dim >= -1 && dim < tensor.dim(),
      "Starting dimension %" PRId64
      " should be in the range [-1, tensor.dim() = %zd).",
      dim,
      ssize_t(tensor.dim()));
  size_t dims = 1;
  for (size_t i = dim + 1; i < tensor.dim(); ++i) {
    dims *= static_cast<size_t>(tensor.size(i));
  }
  return dims;
}

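/*
 * Example: for a tensor of sizes {2, 3, 4} (hypothetical values):
 *
 *   getLeadingDims(t, 2)  == 2 * 3 == 6   // product of sizes[0:2)
 *   getTrailingDims(t, 0) == 3 * 4 == 12  // product of sizes[1:]
 *
 * A common use is flattening a tensor into (leading, trailing) blocks around
 * one dimension.
 */
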
/**
 * Given an N-dimensional tensor coordinate, returns a linear index that can
 * be used to access the corresponding element in the tensor's data buffer.
 *
 * @param[in] tensor The tensor that will be indexed.
 * @param[in] coordinate An n-dimensional array representing the coordinate
 * to index. It is assumed that the array has kTensorDimensionLimit elements.
 * @returns The linear index of the element at the specified coordinate in
 * the tensor.
 */
inline size_t coordinateToIndex(
    const executorch::aten::Tensor& tensor,
    const size_t* const coordinate) {
  size_t index = 0;
  for (int d = 0; d < tensor.dim(); ++d) {
    index += coordinate[d] * getTrailingDims(tensor, d);
  }
  return index;
}

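/*
 * Example: for a tensor of sizes {2, 3, 4}, the coordinate (1, 2, 3) maps to
 *
 *   1 * (3 * 4) + 2 * 4 + 3 == 23
 *
 * i.e. the last element of the 24-element row-major buffer:
 *
 *   size_t coord[kTensorDimensionLimit] = {1, 2, 3};
 *   size_t ix = coordinateToIndex(t, coord); // 23
 */
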
/**
 * Produce a memoized array for use with repeated calls to
 * coordinateToIndexWithTrailingDimsMemo, which will be faster than
 * repeated calls to coordinateToIndex.
 */
inline void memoizeTrailingDims(
    const executorch::aten::Tensor& tensor,
    size_t trailing_dims_memo[kTensorDimensionLimit]) {
  const auto tensorDim = tensor.dim();
  size_t dims = 1;
  for (int ii = tensorDim - 1; ii >= 0; --ii) {
    trailing_dims_memo[ii] = dims;
    dims *= static_cast<size_t>(tensor.size(ii));
  }
}

/**
 * Like coordinateToIndex, but faster for repeated calls with the same
 * tensor. trailing_dims_memo must be produced by a call to
 * memoizeTrailingDims.
 */
inline size_t coordinateToIndexWithTrailingDimsMemo(
    const executorch::aten::Tensor& tensor,
    const size_t* const coordinate,
    const size_t trailing_dims_memo[kTensorDimensionLimit]) {
  size_t index = 0;
  for (int d = 0; d < tensor.dim(); ++d) {
    index += coordinate[d] * trailing_dims_memo[d];
  }
  return index;
}

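/*
 * Example: the memoized variant amortizes the per-dimension size products
 * across many lookups (a sketch; `coord` filling is elided):
 *
 *   size_t memo[kTensorDimensionLimit];
 *   memoizeTrailingDims(t, memo);
 *   for (size_t i = 0; i < t.numel(); ++i) {
 *     // ... fill `coord` for iteration i ...
 *     size_t ix = coordinateToIndexWithTrailingDimsMemo(t, coord, memo);
 *   }
 */
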
/**
 * Given a linear index, returns the N-dimensional tensor coordinate. This is
 * the inverse operation of coordinateToIndex.
 *
 * @param[in] tensor The tensor that will be indexed.
 * @param[in] index The linear index of an element in the tensor.
 * @param[out] coordinate An n-dimensional array that receives the
 * coordinate. It is assumed that the array has kTensorDimensionLimit
 * elements.
 * @returns void
 */
inline void indexToCoordinate(
    const executorch::aten::Tensor& tensor,
    size_t index,
    size_t* coordinate) {
  ET_CHECK(index < tensor.numel());
  for (auto i = 0; i < tensor.dim(); ++i) {
    auto dim = tensor.dim() - 1 - i;
    size_t dim_size = tensor.size(dim);
    coordinate[dim] = index % dim_size;
    index /= dim_size;
  }
}

/**
 * Extracts an integer value from a scalar Tensor.
 *
 * @param[in] tensor The source of the value to extract.
 * @param[out] out_val The extracted value, on success.
 * @returns `true` if a value was extracted, and sets `*out_val` to that
 * value. `false` if a value could not be extracted: either it was not an
 * integer Scalar Tensor, or the value of that Scalar Tensor could not be
 * represented by INT_T.
 */
template <
    typename INT_T,
    typename std::enable_if<
        std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
        bool>::type = true>
bool extract_scalar_tensor(executorch::aten::Tensor tensor, INT_T* out_val) {
  if (tensor.numel() != 1) {
    return false;
  }
#define CASE_INT_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE)                     \
  case executorch::aten::ScalarType::TENSOR_DTYPE: {                   \
    const TENSOR_CTYPE val = tensor.const_data_ptr<TENSOR_CTYPE>()[0]; \
    if (val < std::numeric_limits<INT_T>::lowest() ||                  \
        val > std::numeric_limits<INT_T>::max()) {                     \
      return false;                                                    \
    }                                                                  \
    *out_val = static_cast<INT_T>(val);                                \
    return true;                                                       \
  }

  switch (tensor.scalar_type()) {
    ET_FORALL_INT_TYPES(CASE_INT_DTYPE);
    default:
      return false;
  }
#undef CASE_INT_DTYPE
}

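/*
 * Example: extracting an int32_t, with range checking (a sketch; `t` is a
 * hypothetical scalar tensor):
 *
 *   int32_t v = 0;
 *   if (extract_scalar_tensor(t, &v)) {
 *     // t held a 1-element integer tensor whose value fits in int32_t.
 *   }
 *
 * A Long tensor holding 2^40 would return false here, since the value cannot
 * be represented by int32_t.
 */
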
/**
 * Extracts a floating point value from a scalar Tensor.
 *
 * @param[in] tensor The source of the value to extract.
 * @param[out] out_val The extracted value, on success.
 * @returns `true` if a value was extracted, and sets `*out_val` to that
 * value. `false` if a value could not be extracted: either it was not a
 * floating point Scalar Tensor, or the value of that Scalar Tensor could not
 * be represented by FLOAT_T.
 */
template <
    typename FLOAT_T,
    typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
        type = true>
bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) {
  if (tensor.numel() != 1) {
    return false;
  }
#define CASE_REAL_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE)                    \
  case executorch::aten::ScalarType::TENSOR_DTYPE: {                   \
    /* ET_FORALL_REAL_TYPES guarantees TENSOR_CTYPE is a real type. */ \
    double val =                                                       \
        static_cast<double>(tensor.const_data_ptr<TENSOR_CTYPE>()[0]); \
    if (std::isfinite(val) &&                                          \
        (val < std::numeric_limits<FLOAT_T>::lowest() ||               \
         val > std::numeric_limits<FLOAT_T>::max())) {                 \
      return false;                                                    \
    }                                                                  \
    *out_val = static_cast<FLOAT_T>(val);                              \
    return true;                                                       \
  }

  switch (tensor.scalar_type()) {
    ET_FORALL_REAL_TYPES(CASE_REAL_DTYPE);
    default:
      return false;
  }
#undef CASE_REAL_DTYPE
}

/**
 * Extracts a boolean value from a scalar Tensor.
 *
 * @param[in] tensor The source of the value to extract.
 * @param[out] out_val The extracted value, on success.
 * @returns `true` if a value was extracted, and sets `*out_val` to that
 * value. `false` if a value could not be extracted, i.e. the tensor was not
 * a boolean Scalar Tensor.
 */
template <
    typename BOOL_T,
    typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
        true>
bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) {
  if (tensor.scalar_type() != executorch::aten::ScalarType::Bool) {
    return false;
  }
  if (tensor.numel() != 1) {
    return false;
  }

  bool val = tensor.const_data_ptr<bool>()[0];

  *out_val = static_cast<BOOL_T>(val);

  return true;
}

/// These APIs should not be used outside of Executor.cpp.
namespace internal {
/**
 * Share t_src's data_ptr with t_dst.
 */
ET_NODISCARD Error share_tensor_data(
    const executorch::aten::Tensor& t_dst,
    const executorch::aten::Tensor& t_src);

/**
 * Copy t_src's data_ptr to t_dst.
 */
ET_NODISCARD Error copy_tensor_data(
    const executorch::aten::Tensor& t_dst,
    const executorch::aten::Tensor& t_src);

/**
 * Set the data_ptr of t to buffer.
 */
ET_NODISCARD Error set_tensor_data(
    const executorch::aten::Tensor& t,
    void* buffer,
    size_t buffer_size);

/**
 * Reset tensor's data_ptr; clear all the storage for at::Tensor.
 */
void reset_data_ptr(const executorch::aten::Tensor& tensor);

/**
 * Resize tensor impl.
 */
ET_NODISCARD Error resize_tensor_impl(
    executorch::aten::TensorImpl* impl,
    executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes);

} // namespace internal

/**
 * Resize a tensor to new_sizes; the rank must stay the same. Currently does
 * not expand the tensor if the new size exceeds the current capacity.
 * Currently fails an ET_CHECK if the tensor cannot be resized.
 *
 * WARNING: Placeholder API until discussion around runtime context is
 * settled; will likely move to be a class method on a TensorResizer object
 * passed in through runtimeContext.
 */
ET_NODISCARD inline Error resize_tensor(
    executorch::aten::Tensor t,
    executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes) {
  return internal::resize_tensor_impl(t.unsafeGetTensorImpl(), new_sizes);
}

1173 /**
1174  * Resize a tensor to new_sizes; the rank must stay the same. Currently does
1175  * not expand the tensor if the new size exceeds the current capacity.
1176  * Currently fails an ET_CHECK if the tensor cannot be resized.
1177  *
1178  * WARNING: Placeholder API until discussion around runtime context is
1179  * settled; it will likely move to be a class method on a TensorResizer
1180  * object passed in through runtimeContext.
1181  */
1182 template <
1183     typename T,
1184     typename std::enable_if<
1185         !std::is_same<executorch::aten::SizesType, T>::value,
1186         int>::type = 0>
1187 ET_NODISCARD inline Error resize_tensor(
1188     executorch::aten::Tensor t,
1189     executorch::aten::ArrayRef<T> new_sizes) {
1190   // Need to cast the input array to an array of Tensor::SizesType
1191   std::array<executorch::aten::SizesType, kTensorDimensionLimit>
1192       new_sizes_casted{};
1193   size_t new_sizes_ndim = new_sizes.size();
1194   for (size_t i = 0; i < new_sizes_ndim; ++i) {
1195     new_sizes_casted[i] =
1196         static_cast<executorch::aten::SizesType>(new_sizes[i]);
1197   }
1198 
1199   return internal::resize_tensor_impl(
1200       t.unsafeGetTensorImpl(), {new_sizes_casted.data(), new_sizes_ndim});
1201 }
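
// Sketch: this overload accepts sizes of a non-SizesType integral type (e.g.
// int64_t, as produced by ATen-style APIs) and casts them element-by-element.
// Assumes SizesType differs from int64_t in this build; `t` is hypothetical:
//
//   int64_t sizes64[] = {4, 8};
//   executorch::runtime::Error err = executorch::runtime::resize_tensor(
//       t, executorch::aten::ArrayRef<int64_t>(sizes64, 2));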
1202 
1203 /// DEPRECATED: Use `resize_tensor()` instead, which can fail non-fatally.
1204 ET_DEPRECATED inline void resize(
1205     executorch::aten::Tensor t,
1206     executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes) {
1207   Error err = resize_tensor(t, new_sizes);
1208   ET_CHECK_MSG(
1209       err == Error::Ok, "Could not resize Tensor; see logs for details");
1210 }
1211 /**
1212  * Get the dim_order of a Tensor and write it to out_dim_order.
1213  * @param tensor The tensor whose dim order to retrieve.
1214  * @param out_dim_order Pointer to an array of DimOrderType into which the
1215  * dim order is written.
1216  * @param out_dim_order_size Size of the DimOrderType array.
1217  */
1218 ET_NODISCARD Error get_dim_order(
1219     const executorch::aten::Tensor& tensor,
1220     executorch::aten::DimOrderType* out_dim_order,
1221     size_t out_dim_order_size);
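
// Usage sketch (hypothetical caller), providing one DimOrderType slot per
// dimension:
//
//   executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
//   executorch::runtime::Error err =
//       executorch::runtime::get_dim_order(t, dim_order, t.dim());
//   // On success, dim_order[0..t.dim()) lists dimensions from outermost to
//   // innermost in memory, e.g. {0, 1, 2, 3} for the default order.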
1222 
1223 /**
1224  * Checks whether a tensor has a valid dim order. If the dim order cannot
1225  * be determined, this function returns false.
1226  */
1227 bool tensor_has_valid_dim_order(executorch::aten::Tensor t);
1228 
1229 /**
1230  * Checks whether a tensor has either the default or the channels-last dim
1231  * order. If the dim order cannot be determined, this function returns
1232  * false.
1233  */
1234 bool tensor_is_default_or_channels_last_dim_order(executorch::aten::Tensor t);
1235 
1236 /**
1237  * Checks whether a tensor has the default dimension order.
1238  * Logs an error message if the tensor does not meet the expected criteria.
1239  *
1240  * @param t The tensor to check the dimension order of.
1241  * @return True if the tensor has the default dimension order, false otherwise.
1242  */
1243 bool tensor_is_default_dim_order(executorch::aten::Tensor t);
1244 
1245 /**
1246  * Checks whether a tensor has the channels last dimension order.
1247  * Logs an error message if the tensor does not meet the expected criteria.
1248  *
1249  * @param t The tensor to check the dimension order of.
1250  * @return True if the tensor has the channels last dimension order, false
1251  * otherwise.
1252  */
1253 bool tensor_is_channels_last_dim_order(executorch::aten::Tensor t);
1254 
1255 /**
1256  * Checks whether all tensors in the given list have the same dim_order.
1257  *
1258  * Note that this function only checks dim order, not other properties such
1259  * as actual data, sizes, etc.
1260  *
1261  */
1262 bool tensors_have_same_dim_order(
1263     const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list);
1264 
1265 /**
1266  * Checks whether two tensors have the same dim_order.
1267  *
1268  * Note that this function only checks dim order, not other properties such
1269  * as actual data, sizes, etc.
1270  */
1271 
1272 inline bool tensors_have_same_dim_order(
1273     const executorch::aten::Tensor& a,
1274     const executorch::aten::Tensor& b) {
1275   executorch::aten::Tensor tensor_list[2] = {a, b};
1276   return tensors_have_same_dim_order(tensor_list);
1277 }
1278 
1279 /**
1280  * Checks whether three tensors have the same dim_order.
1281  *
1282  * Note that this function only checks dim order, not other properties such
1283  * as actual data, sizes, etc.
1284  *
1285  */
1286 
1287 inline bool tensors_have_same_dim_order(
1288     const executorch::aten::Tensor& a,
1289     const executorch::aten::Tensor& b,
1290     const executorch::aten::Tensor& c) {
1291   executorch::aten::Tensor tensor_list[3] = {a, b, c};
1292   return tensors_have_same_dim_order(tensor_list);
1293 }
1294 
1295 /**
1296  * Checks whether four tensors have the same dim_order.
1297  *
1298  * Note that this function only checks dim order, not other properties such
1299  * as actual data, sizes, etc.
1300  *
1301  */
1302 
1303 inline bool tensors_have_same_dim_order(
1304     const executorch::aten::Tensor& a,
1305     const executorch::aten::Tensor& b,
1306     const executorch::aten::Tensor& c,
1307     const executorch::aten::Tensor& d) {
1308   executorch::aten::Tensor tensor_list[4] = {a, b, c, d};
1309   return tensors_have_same_dim_order(tensor_list);
1310 }
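
// Sketch: a typical kernel-side guard (tensor names are hypothetical):
//
//   if (!executorch::runtime::tensors_have_same_dim_order(a, b, out)) {
//     // Inputs mix dim orders (e.g. contiguous vs. channels-last), so the
//     // kernel cannot index them uniformly; report an error instead.
//   }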
1311 
1312 /**
1313  * Given an n-dimensional coordinate array and an array of tensor strides,
1314  * calculates the linear index that can be used to retrieve the value at the
1315  * given coordinates.
1316  * @param coordinate Pointer to the array of coordinates.
1317  * @param strides Pointer to the array of strides.
1318  * @param ndim Number of dimensions in the tensor.
1319  */
1320 inline size_t calculate_linear_index(
1321     const executorch::aten::SizesType* coordinate,
1322     const executorch::aten::StridesType* strides,
1323     const size_t ndim) {
1324   size_t index = 0;
1325   for (size_t i = 0; i < ndim; i++) {
1326     index += coordinate[i] * strides[i];
1327   }
1328   return index;
1329 }
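
// Worked sketch: a contiguous 2x3 tensor has strides {3, 1}, so coordinate
// (1, 2) maps to linear index 1 * 3 + 2 * 1 = 5.
//
//   executorch::aten::SizesType coord[] = {1, 2};
//   executorch::aten::StridesType strides[] = {3, 1};
//   size_t idx =
//       executorch::runtime::calculate_linear_index(coord, strides, 2);
//   // idx == 5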
1330 
1331 } // namespace runtime
1332 } // namespace executorch
1333 
1334 namespace torch {
1335 namespace executor {
1336 // TODO(T197294990): Remove these deprecated aliases once all users have moved
1337 // to the new `::executorch` namespaces.
1338 using ::executorch::runtime::calculate_linear_index;
1339 using ::executorch::runtime::coordinateToIndex;
1340 using ::executorch::runtime::dim_is_valid;
1341 using ::executorch::runtime::extract_scalar_tensor;
1342 using ::executorch::runtime::get_dim_order;
1343 using ::executorch::runtime::getLeadingDims;
1344 using ::executorch::runtime::getTrailingDims;
1345 using ::executorch::runtime::indexToCoordinate;
1346 using ::executorch::runtime::kTensorDimensionLimit;
1347 using ::executorch::runtime::nonempty_size;
1348 using ::executorch::runtime::nonzero_dim;
1349 using ::executorch::runtime::resize;
1350 using ::executorch::runtime::resize_tensor;
1351 using ::executorch::runtime::tensor_can_cast_to;
1352 using ::executorch::runtime::tensor_dim_has_index;
1353 using ::executorch::runtime::tensor_has_dim;
1354 using ::executorch::runtime::tensor_has_expected_size;
1355 using ::executorch::runtime::tensor_has_non_empty_dim;
1356 using ::executorch::runtime::tensor_has_rank_greater_or_equal_to;
1357 using ::executorch::runtime::tensor_has_rank_smaller_or_equal_to;
1358 using ::executorch::runtime::tensor_has_valid_dim_order;
1359 using ::executorch::runtime::tensor_is_bits_type;
1360 using ::executorch::runtime::tensor_is_bool_type;
1361 using ::executorch::runtime::tensor_is_complex_type;
1362 using ::executorch::runtime::tensor_is_contiguous;
1363 using ::executorch::runtime::tensor_is_default_dim_order;
1364 using ::executorch::runtime::tensor_is_default_or_channels_last_dim_order;
1365 using ::executorch::runtime::tensor_is_floating_type;
1366 using ::executorch::runtime::tensor_is_integral_type;
1367 using ::executorch::runtime::tensor_is_rank;
1368 using ::executorch::runtime::tensor_is_real_type;
1369 using ::executorch::runtime::tensor_is_realh_type;
1370 using ::executorch::runtime::tensor_is_realhb_type;
1371 using ::executorch::runtime::tensor_is_scalar;
1372 using ::executorch::runtime::tensors_have_same_dim_order;
1373 using ::executorch::runtime::tensors_have_same_dtype;
1374 using ::executorch::runtime::tensors_have_same_rank;
1375 using ::executorch::runtime::tensors_have_same_shape;
1376 using ::executorch::runtime::tensors_have_same_shape_and_dtype;
1377 using ::executorch::runtime::tensors_have_same_size_at_dims;
1378 using ::executorch::runtime::tensors_have_same_strides;
1379 namespace internal {
1380 using ::executorch::runtime::internal::copy_tensor_data;
1381 using ::executorch::runtime::internal::reset_data_ptr;
1382 using ::executorch::runtime::internal::resize_tensor_impl;
1383 using ::executorch::runtime::internal::set_tensor_data;
1384 using ::executorch::runtime::internal::share_tensor_data;
1385 } // namespace internal
1386 } // namespace executor
1387 } // namespace torch
1388