1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <algorithm>
12 #include <array> // std::array
13 #include <cinttypes> // PRId64
14 #include <cmath>
15 #include <cstddef> // size_t
16 #include <limits>
17
18 #include <executorch/runtime/core/array_ref.h>
19 #include <executorch/runtime/core/error.h>
20 #include <executorch/runtime/core/exec_aten/exec_aten.h>
21 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
22 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
23 #include <executorch/runtime/platform/assert.h>
24 #include <executorch/runtime/platform/compiler.h>
25
/// Common prefix shared by the tensor-mismatch assertion and log messages.
#define ET_TENSOR_CHECK_PREFIX__ "Tensors do not match"

/// Smaller of two values (std::min); both arguments must have the same type.
#define ET_MIN2(a, b) ((std::min)((a), (b)))

/// Smallest of three values (std::min); all arguments must have the same type.
#define ET_MIN3(a, b, c) ((std::min)((a), (std::min)((b), (c))))
30
/**
 * Converts a possibly negative index into a non-negative one by adding
 * UPPER_BOUND when IX is negative; non-negative values pass through
 * unchanged. No range check is performed.
 *
 * The expansion and both arguments are parenthesized so the macro can be used
 * inside larger expressions (e.g. `ET_NORMALIZE_IX(i, n) * stride`) without
 * operator-precedence surprises. Note that IX is evaluated more than once.
 */
#define ET_NORMALIZE_IX(IX, UPPER_BOUND) \
  ((IX) < 0 ? (IX) + (UPPER_BOUND) : (IX))
32
/**
 * Asserts (via ET_CHECK_MSG) that index IX lies in the half-open range
 * [-UPPER_BOUND, UPPER_BOUND).
 *
 * Both arguments are parenthesized so compound expressions can be passed
 * safely, and the values handed to the format string are cast to the exact
 * types the `PRId64` / `%zd` specifiers expect.
 */
#define ET_CHECK_VALID_IX(IX, UPPER_BOUND) \
  ET_CHECK_MSG( \
      (IX) >= -static_cast<int64_t>(UPPER_BOUND) && \
          (IX) < static_cast<int64_t>(UPPER_BOUND), \
      "index %" PRId64 " must be within range [-%zd, %zd)", \
      static_cast<int64_t>(IX), \
      static_cast<ssize_t>(UPPER_BOUND), \
      static_cast<ssize_t>(UPPER_BOUND))
41
/**
 * Asserts (via ET_CHECK_MSG) that dimension DIM lies in the half-open range
 * [-UPPER_BOUND, UPPER_BOUND).
 *
 * Both arguments are parenthesized so compound expressions can be passed
 * safely, and the values handed to the format string are cast to the exact
 * types the `PRId64` / `%zd` specifiers expect.
 */
#define ET_CHECK_VALID_DIM(DIM, UPPER_BOUND) \
  ET_CHECK_MSG( \
      (DIM) >= -static_cast<int64_t>(UPPER_BOUND) && \
          (DIM) < static_cast<int64_t>(UPPER_BOUND), \
      "dim %" PRId64 " must be within range [-%zd, %zd)", \
      static_cast<int64_t>(DIM), \
      static_cast<ssize_t>(UPPER_BOUND), \
      static_cast<ssize_t>(UPPER_BOUND))
50
/**
 * Asserts (via ET_CHECK_MSG) that tensor T has a non-zero size along DIM,
 * where DIM may be negative and is normalized with ET_NORMALIZE_IX.
 *
 * NOTE: this macro expands to a declaration followed by a statement. It
 * introduces a variable named `udim` into the enclosing scope and already
 * ends with a semicolon, so it can be used at most once per scope and must
 * not be used as the sole body of an unbraced `if`/`for`.
 */
#define ET_CHECK_NON_ZERO_DIM_SIZE(DIM, T) \
  const size_t udim = ET_NORMALIZE_IX((DIM), (T).dim()); \
  ET_CHECK_MSG( \
      (T).size(udim) != 0, "Expected dim %zu to have non-zero size.", udim);
55
/**
 * Asserts that two tensors have the same shape.
 *
 * The element counts must match, and the ranks must match too unless both
 * tensors hold exactly one element (single-element tensors are accepted
 * regardless of their number of dimensions). Sizes are then compared one
 * dimension at a time over the leading dimensions both tensors have, so the
 * loop never indexes past the lower-rank tensor.
 */
#define ET_CHECK_SAME_SHAPE2(a__, b__) \
  ({ \
    const size_t a_numel__ = (a__).numel(); \
    const size_t b_numel__ = (b__).numel(); \
    const size_t a_dim__ = (a__).dim(); \
    const size_t b_dim__ = (b__).dim(); \
    ET_CHECK_MSG( \
        a_numel__ == b_numel__ && \
            ((a_numel__ == 1 && b_numel__ == 1) || (a_dim__ == b_dim__)), \
        ET_TENSOR_CHECK_PREFIX__ ": numel={%zu, %zu}, dim={%zu, %zu}", \
        a_numel__, \
        b_numel__, \
        a_dim__, \
        b_dim__); \
    /* Only walk the dimensions that exist in both tensors. */ \
    const size_t shared_dims__ = ET_MIN2(a_dim__, b_dim__); \
    for (size_t dim__ = 0; dim__ < shared_dims__; ++dim__) { \
      const size_t a_size__ = (a__).size(dim__); \
      const size_t b_size__ = (b__).size(dim__); \
      ET_CHECK_MSG( \
          a_size__ == b_size__, \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}", \
          dim__, \
          a_size__, \
          b_size__); \
    } \
  })
88
/**
 * Asserts that three tensors have the same shape.
 *
 * Three-tensor counterpart of ET_CHECK_SAME_SHAPE2: element counts must all
 * match, ranks must all match unless every tensor holds exactly one element,
 * and sizes are compared over the leading dimensions common to all three.
 */
#define ET_CHECK_SAME_SHAPE3(a__, b__, c__) \
  ({ \
    const size_t a_numel__ = (a__).numel(); \
    const size_t b_numel__ = (b__).numel(); \
    const size_t c_numel__ = (c__).numel(); \
    const size_t a_dim__ = (a__).dim(); \
    const size_t b_dim__ = (b__).dim(); \
    const size_t c_dim__ = (c__).dim(); \
    ET_CHECK_MSG( \
        a_numel__ == b_numel__ && b_numel__ == c_numel__ && \
            ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) || \
             a_dim__ == b_dim__ && b_dim__ == c_dim__), \
        ET_TENSOR_CHECK_PREFIX__ \
        ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}", \
        a_numel__, \
        b_numel__, \
        c_numel__, \
        a_dim__, \
        b_dim__, \
        c_dim__); \
    /* Only walk the dimensions that exist in all three tensors. */ \
    const size_t shared_dims__ = ET_MIN3(a_dim__, b_dim__, c_dim__); \
    for (size_t dim__ = 0; dim__ < shared_dims__; ++dim__) { \
      const size_t a_size__ = (a__).size(dim__); \
      const size_t b_size__ = (b__).size(dim__); \
      const size_t c_size__ = (c__).size(dim__); \
      ET_CHECK_MSG( \
          a_size__ == b_size__ && b_size__ == c_size__, \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}", \
          dim__, \
          a_size__, \
          b_size__, \
          c_size__); \
    } \
  })
123
/**
 * Asserts that two tensors have the same dtype.
 *
 * On mismatch the failure message reports both scalar types as their numeric
 * enum values.
 */
#define ET_CHECK_SAME_DTYPE2(a__, b__) \
  ({ \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    ET_CHECK_MSG( \
        a_type__ == b_type__, \
        ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 "}", \
        static_cast<int8_t>(a_type__), \
        static_cast<int8_t>(b_type__)); \
  })
135
/**
 * Asserts that three tensors have the same dtype.
 *
 * On mismatch the failure message reports all three scalar types as their
 * numeric enum values.
 */
#define ET_CHECK_SAME_DTYPE3(a__, b__, c__) \
  ({ \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type(); \
    ET_CHECK_MSG( \
        a_type__ == b_type__ && b_type__ == c_type__, \
        ET_TENSOR_CHECK_PREFIX__ ": dtype={%" PRId8 ", %" PRId8 ", %" PRId8 \
                                 "}", \
        static_cast<int8_t>(a_type__), \
        static_cast<int8_t>(b_type__), \
        static_cast<int8_t>(c_type__)); \
  })
149
/**
 * Asserts that two tensors have the same shape and dtype.
 *
 * The element-count, rank and dtype comparisons are folded into a single
 * ET_CHECK_MSG call, which is intended to produce less code/data than
 * invoking the SHAPE and DTYPE macros separately. Per-dimension sizes are
 * then compared over the leading dimensions both tensors have.
 */
#define ET_CHECK_SAME_SHAPE_AND_DTYPE2(a__, b__) \
  ({ \
    const size_t a_numel__ = (a__).numel(); \
    const size_t b_numel__ = (b__).numel(); \
    const size_t a_dim__ = (a__).dim(); \
    const size_t b_dim__ = (b__).dim(); \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    \
    ET_CHECK_MSG( \
        a_numel__ == b_numel__ && \
            ((a_numel__ == 1 && b_numel__ == 1) || a_dim__ == b_dim__) && \
            a_type__ == b_type__, \
        ET_TENSOR_CHECK_PREFIX__ \
        ": numel={%zu, %zu}, dim={%zu, %zu}, dtype={%" PRId8 ", %" PRId8 "}", \
        a_numel__, \
        b_numel__, \
        a_dim__, \
        b_dim__, \
        static_cast<int8_t>(a_type__), \
        static_cast<int8_t>(b_type__)); \
    /* Only walk the dimensions that exist in both tensors. */ \
    const size_t shared_dims__ = ET_MIN2(a_dim__, b_dim__); \
    for (size_t dim__ = 0; dim__ < shared_dims__; ++dim__) { \
      const size_t a_size__ = (a__).size(dim__); \
      const size_t b_size__ = (b__).size(dim__); \
      ET_CHECK_MSG( \
          a_size__ == b_size__, \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu}", \
          dim__, \
          a_size__, \
          b_size__); \
    } \
  })
188
/**
 * Asserts that three tensors have the same shape and dtype.
 *
 * Three-tensor counterpart of ET_CHECK_SAME_SHAPE_AND_DTYPE2: element counts,
 * ranks (unless every tensor holds exactly one element) and dtypes are
 * checked in a single ET_CHECK_MSG call, then per-dimension sizes are
 * compared over the leading dimensions common to all three tensors.
 */
#define ET_CHECK_SAME_SHAPE_AND_DTYPE3(a__, b__, c__) \
  ({ \
    const size_t a_numel__ = (a__).numel(); \
    const size_t b_numel__ = (b__).numel(); \
    const size_t c_numel__ = (c__).numel(); \
    const size_t a_dim__ = (a__).dim(); \
    const size_t b_dim__ = (b__).dim(); \
    const size_t c_dim__ = (c__).dim(); \
    const ::executorch::aten::ScalarType a_type__ = (a__).scalar_type(); \
    const ::executorch::aten::ScalarType b_type__ = (b__).scalar_type(); \
    const ::executorch::aten::ScalarType c_type__ = (c__).scalar_type(); \
    \
    ET_CHECK_MSG( \
        a_numel__ == b_numel__ && b_numel__ == c_numel__ && \
            ((a_numel__ == 1 && b_numel__ == 1 && c_numel__ == 1) || \
             (a_dim__ == b_dim__ && b_dim__ == c_dim__)) && \
            a_type__ == b_type__ && b_type__ == c_type__, \
        ET_TENSOR_CHECK_PREFIX__ \
        ": numel={%zu, %zu, %zu}, dim={%zu, %zu, %zu}, " \
        "dtype={%" PRId8 ", %" PRId8 ", %" PRId8 "}", \
        a_numel__, \
        b_numel__, \
        c_numel__, \
        a_dim__, \
        b_dim__, \
        c_dim__, \
        static_cast<int8_t>(a_type__), \
        static_cast<int8_t>(b_type__), \
        static_cast<int8_t>(c_type__)); \
    /* Only walk the dimensions that exist in all three tensors. */ \
    const size_t shared_dims__ = ET_MIN3(a_dim__, b_dim__, c_dim__); \
    for (size_t dim__ = 0; dim__ < shared_dims__; ++dim__) { \
      const size_t a_size__ = (a__).size(dim__); \
      const size_t b_size__ = (b__).size(dim__); \
      const size_t c_size__ = (c__).size(dim__); \
      ET_CHECK_MSG( \
          a_size__ == b_size__ && b_size__ == c_size__, \
          ET_TENSOR_CHECK_PREFIX__ " at size(%zu): {%zu, %zu, %zu}", \
          dim__, \
          a_size__, \
          b_size__, \
          c_size__); \
    } \
  })
232
/**
 * Asserts that the input tensor is contiguous: the stride of the last
 * dimension is 1 and every other stride equals the next stride multiplied by
 * the next size.
 *
 * A tensor with no dimensions has no strides to inspect and is accepted as
 * contiguous (the same convention tensor_is_contiguous() uses); without this
 * guard the macro would index `strides[size_t(-1)]`.
 *
 * The argument is parenthesized and evaluated through local views whose
 * names carry a `__` suffix so they cannot collide with the caller's
 * variables. Stride values are cast to int32_t to match the `%d` specifiers.
 */
#define ET_CHECK_CONTIGUOUS(a__) \
  ({ \
    const auto& strides__ = (a__).strides(); \
    const auto& sizes__ = (a__).sizes(); \
    if (strides__.size() > 0) { \
      ET_CHECK_MSG( \
          strides__[strides__.size() - 1] == 1, \
          "The stride of the last dimension shall be 1 for contiguous tensor, " \
          "not %d", \
          static_cast<int32_t>(strides__[strides__.size() - 1])); \
      /* Walk from the innermost dimension outwards. */ \
      for (size_t i__ = strides__.size() - 1; i__ > 0; i__--) { \
        ET_CHECK_MSG( \
            strides__[i__ - 1] == strides__[i__] * sizes__[i__], \
            "The stride of the %zu-th dimension shall equal to " \
            "strides[%zu] * sizes[%zu], now is %d and %d", \
            i__ - 1, \
            i__, \
            i__, \
            static_cast<int32_t>(strides__[i__ - 1]), \
            static_cast<int32_t>(strides__[i__] * sizes__[i__])); \
      } \
    } \
  })
259
/**
 * Asserts that the two input tensors have the same number of dimensions and
 * identical strides.
 *
 * This macro makes no check or promise about the contiguity of either
 * tensor.
 *
 * Arguments are parenthesized, the dimension counts are read once and
 * converted to size_t so they match the `%zu` specifiers and the loop index
 * type, and stride values are cast to int32_t for the `%d` specifiers.
 */
#define ET_CHECK_SAME_STRIDES2(a__, b__) \
  ({ \
    const size_t a_dim__ = static_cast<size_t>((a__).dim()); \
    const size_t b_dim__ = static_cast<size_t>((b__).dim()); \
    ET_CHECK_MSG( \
        a_dim__ == b_dim__, \
        "Two tensors shall have same number of strides, but not %zu and %zu.", \
        a_dim__, \
        b_dim__); \
    const auto& a_strides__ = (a__).strides(); \
    const auto& b_strides__ = (b__).strides(); \
    for (size_t i__ = 0; i__ < a_dim__; i__++) { \
      ET_CHECK_MSG( \
          a_strides__[i__] == b_strides__[i__], \
          "a.strides()[%zu] shall equal to b.strides()[%zu], " \
          "but now is %d and %d.", \
          i__, \
          i__, \
          static_cast<int32_t>(a_strides__[i__]), \
          static_cast<int32_t>(b_strides__[i__])); \
    } \
  })
287
/**
 * Asserts that the three input tensors have the same number of dimensions
 * and identical strides.
 *
 * This macro makes no check or promise about the contiguity of any of the
 * tensors.
 *
 * Arguments are parenthesized, the dimension counts are read once and
 * converted to size_t so they match the `%zu` specifiers and the loop index
 * type, and stride values are cast to int32_t for the `%d` specifiers.
 */
#define ET_CHECK_SAME_STRIDES3(a__, b__, c__) \
  ({ \
    const size_t a_dim__ = static_cast<size_t>((a__).dim()); \
    const size_t b_dim__ = static_cast<size_t>((b__).dim()); \
    const size_t c_dim__ = static_cast<size_t>((c__).dim()); \
    ET_CHECK_MSG( \
        a_dim__ == b_dim__ && b_dim__ == c_dim__, \
        "Three tensors shall have same number of strides, " \
        "but not %zu, %zu and %zu.", \
        a_dim__, \
        b_dim__, \
        c_dim__); \
    const auto& a_strides__ = (a__).strides(); \
    const auto& b_strides__ = (b__).strides(); \
    const auto& c_strides__ = (c__).strides(); \
    for (size_t i__ = 0; i__ < a_dim__; i__++) { \
      ET_CHECK_MSG( \
          a_strides__[i__] == b_strides__[i__] && \
              b_strides__[i__] == c_strides__[i__], \
          "a_strides[%zu], b_strides[%zu] and c_strides[%zu] " \
          "shall share same value, but now is %d, %d and %d", \
          i__, \
          i__, \
          i__, \
          static_cast<int32_t>(a_strides__[i__]), \
          static_cast<int32_t>(b_strides__[i__]), \
          static_cast<int32_t>(c_strides__[i__])); \
    } \
  })
321
/**
 * Asserts that the tensor's dim order is accepted by either
 * is_contiguous_dim_order() or is_channels_last_dim_order().
 *
 * The argument is parenthesized and its dim order is fetched once into a
 * local view, so an argument expression is evaluated a single time instead
 * of once per `.data()` / `.size()` access.
 */
#define ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(t__) \
  ({ \
    const auto& dim_order__ = (t__).dim_order(); \
    ET_CHECK_MSG( \
        is_contiguous_dim_order(dim_order__.data(), dim_order__.size()) || \
            is_channels_last_dim_order( \
                dim_order__.data(), dim_order__.size()), \
        "Tensor must have default or channels last dim order"); \
  })
331
/**
 * Convenience macro for boolean-returning utility functions that validate
 * input tensor(s). Evaluates `cond`; when it is false, logs the text of the
 * failed condition at Error level and makes the enclosing function return
 * false. When `cond` is true, execution simply continues.
 *
 * @param[in] cond the condition to check
 */
#define ET_LOG_AND_RETURN_IF_FALSE(cond) \
  do { \
    if (!(cond)) { \
      ET_LOG(Error, "Check failed (%s): ", #cond); \
      return false; \
    } \
  } while (0)
346
/**
 * Convenience macro for boolean-returning utility functions that validate
 * input tensor(s). Evaluates `cond`; when it is false, logs the text of the
 * failed condition followed by the formatted `message` at Error level and
 * makes the enclosing function return false. The format arguments are only
 * evaluated on failure.
 *
 * @param[in] cond the condition to check
 * @param[in] message printf-style format string (a string literal) appended
 *     to the logged condition; any further arguments are its format values
 */
#define ET_LOG_MSG_AND_RETURN_IF_FALSE(cond, message, ...) \
  do { \
    if (!(cond)) { \
      ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \
      return false; \
    } \
  } while (0)
362
/**
 * Kernel-side check: when `cond` is false, logs the text of the failed
 * condition, records `error` on the runtime context via `context.fail(...)`
 * and returns `retval` from the enclosing kernel.
 *
 * @param[in] context the runtime context
 * @param[in] cond the condition to check
 * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`)
 * @param[in] retval return value of the kernel to allow for early exit
 */
#define ET_KERNEL_CHECK(context, cond, error, retval) \
  do { \
    if (!(cond)) { \
      ET_LOG(Error, "Check failed (%s): ", #cond); \
      context.fail(torch::executor::Error::error); \
      return retval; \
    } \
  } while (0)
380
/**
 * Kernel-side check with a custom message: when `cond` is false, logs the
 * text of the failed condition followed by the formatted `message`, records
 * `error` on the runtime context via `context.fail(...)` and returns
 * `retval` from the enclosing kernel.
 *
 * @param[in] context the runtime context
 * @param[in] cond the condition to check
 * @param[in] error torch::executor::Error enum value (e.g `InvalidArgument`)
 * @param[in] retval return value of the kernel to allow for early exit
 * @param[in] message printf-style format string (a string literal); any
 *     further arguments are its format values
 */
#define ET_KERNEL_CHECK_MSG(context, cond, error, retval, message, ...) \
  do { \
    if (!(cond)) { \
      ET_LOG(Error, "Check failed (%s): " message, #cond, ##__VA_ARGS__); \
      context.fail(torch::executor::Error::error); \
      return retval; \
    } \
  } while (0)
398
/**
 * Convenience macro that extracts the value of a scalar tensor into
 * `out_val` by calling extract_scalar_tensor(), asserting (via ET_CHECK_MSG)
 * that the extraction succeeded.
 *
 * NOTE: the expansion already ends with a semicolon.
 */
#define ET_EXTRACT_SCALAR_TENSOR(scalar_tensor, out_val) \
  ET_CHECK_MSG( \
      extract_scalar_tensor(scalar_tensor, &out_val), \
      #scalar_tensor " could not be extracted: wrong type or out of range");
406
407 namespace executorch {
408 namespace runtime {
409
410 //
411 // Utility functions for checking tensor attributes
412 //
413 //
414
415 /*
416 * Returns true if the given dimension value is between -upper_bound and
417 * upper_bound - 1, inclusive.
418 */
dim_is_valid(int64_t dim,int64_t upper_bound)419 inline bool dim_is_valid(int64_t dim, int64_t upper_bound) {
420 ET_LOG_MSG_AND_RETURN_IF_FALSE(
421 dim >= -upper_bound && dim < upper_bound,
422 "Dimension %" PRId64
423 " is out of range. Dimension should be between %" PRId64 " and %" PRId64
424 ", inclusive.",
425 dim,
426 -upper_bound,
427 upper_bound - 1);
428
429 return true;
430 }
431
432 /*
433 * Returns the tensor's number of dimensions, except when the tensor is zero
434 * dimensional. In this case, it returns 1. This is used to properly handle
435 * the zero dimensional tensors in some kernels, that treat them as 1D tensors
436 * with a single element.
437 */
nonzero_dim(const executorch::aten::Tensor & tensor)438 inline ssize_t nonzero_dim(const executorch::aten::Tensor& tensor) {
439 return tensor.dim() == 0 ? 1 : tensor.dim();
440 }
441
442 /*
443 * Returns the size along a dimension dim, except when the tensor is zero
444 * dimensional. In this case, it returns 1. This is used to properly handle
445 * the zero dimensional tensors in some kernels, that treat them as 1D tensors
446 * with a single element.
447 */
nonempty_size(const executorch::aten::Tensor & tensor,ssize_t dim)448 inline ssize_t nonempty_size(
449 const executorch::aten::Tensor& tensor,
450 ssize_t dim) {
451 return tensor.dim() == 0 ? 1 : tensor.size(dim);
452 }
453
tensor_can_cast_to(executorch::aten::Tensor a,executorch::aten::ScalarType dtype)454 inline bool tensor_can_cast_to(
455 executorch::aten::Tensor a,
456 executorch::aten::ScalarType dtype) {
457 ET_LOG_MSG_AND_RETURN_IF_FALSE(
458 torch::executor::canCast(a.scalar_type(), dtype),
459 "Tensor of dtype %s cannot cast to dtype %s",
460 torch::executor::toString(a.scalar_type()),
461 torch::executor::toString(dtype));
462
463 return true;
464 }
465
tensor_is_bool_type(executorch::aten::Tensor t)466 inline bool tensor_is_bool_type(executorch::aten::Tensor t) {
467 ET_LOG_MSG_AND_RETURN_IF_FALSE(
468 t.scalar_type() == executorch::aten::ScalarType::Bool,
469 "Expected to find bool type, but tensor has type %s",
470 torch::executor::toString(t.scalar_type()));
471
472 return true;
473 }
474
tensor_is_type(executorch::aten::Tensor t,executorch::aten::ScalarType dtype)475 inline bool tensor_is_type(
476 executorch::aten::Tensor t,
477 executorch::aten::ScalarType dtype) {
478 ET_LOG_MSG_AND_RETURN_IF_FALSE(
479 t.scalar_type() == dtype,
480 "Expected to find %s type, but tensor has type %s",
481 torch::executor::toString(dtype),
482 torch::executor::toString(t.scalar_type()));
483
484 return true;
485 }
486
487 inline bool tensor_is_integral_type(
488 executorch::aten::Tensor t,
489 bool includeBool = false) {
490 ET_LOG_MSG_AND_RETURN_IF_FALSE(
491 torch::executor::isIntegralType(t.scalar_type(), includeBool),
492 "Expected to find a integral type, but tensor has type %s",
493 torch::executor::toString(t.scalar_type()));
494
495 return true;
496 }
497
tensor_is_floating_type(executorch::aten::Tensor t)498 inline bool tensor_is_floating_type(executorch::aten::Tensor t) {
499 ET_LOG_MSG_AND_RETURN_IF_FALSE(
500 torch::executor::isFloatingType(t.scalar_type()),
501 "Expected to find a floating type, but tensor has type %s",
502 torch::executor::toString(t.scalar_type()));
503
504 return true;
505 }
506
tensor_is_real_type(executorch::aten::Tensor t)507 inline bool tensor_is_real_type(executorch::aten::Tensor t) {
508 ET_LOG_MSG_AND_RETURN_IF_FALSE(
509 torch::executor::isRealType(t.scalar_type()),
510 "Expected to find a real type, but tensor has type %s",
511 torch::executor::toString(t.scalar_type()));
512
513 return true;
514 }
515
tensor_is_realh_type(executorch::aten::Tensor t)516 inline bool tensor_is_realh_type(executorch::aten::Tensor t) {
517 ET_LOG_MSG_AND_RETURN_IF_FALSE(
518 torch::executor::isRealHType(t.scalar_type()),
519 "Expected to find a real type, but tensor has type %s",
520 torch::executor::toString(t.scalar_type()));
521
522 return true;
523 }
524
tensor_is_realhbf16_type(executorch::aten::Tensor t)525 inline bool tensor_is_realhbf16_type(executorch::aten::Tensor t) {
526 ET_LOG_MSG_AND_RETURN_IF_FALSE(
527 executorch::runtime::isRealHBF16Type(t.scalar_type()),
528 "Expected to find a real type, but tensor has type %s",
529 torch::executor::toString(t.scalar_type()));
530
531 return true;
532 }
533
tensor_is_realhb_type(executorch::aten::Tensor t)534 inline bool tensor_is_realhb_type(executorch::aten::Tensor t) {
535 ET_LOG_MSG_AND_RETURN_IF_FALSE(
536 torch::executor::isRealHBType(t.scalar_type()),
537 "Expected to find a real type, but tensor has type %s",
538 torch::executor::toString(t.scalar_type()));
539
540 return true;
541 }
542
tensor_is_realhbbf16_type(executorch::aten::Tensor t)543 inline bool tensor_is_realhbbf16_type(executorch::aten::Tensor t) {
544 ET_LOG_MSG_AND_RETURN_IF_FALSE(
545 executorch::runtime::isRealHBBF16Type(t.scalar_type()),
546 "Expected to find a real type, but tensor has type %s",
547 torch::executor::toString(t.scalar_type()));
548
549 return true;
550 }
551
tensor_is_complex_type(executorch::aten::Tensor t)552 inline bool tensor_is_complex_type(executorch::aten::Tensor t) {
553 ET_LOG_MSG_AND_RETURN_IF_FALSE(
554 torch::executor::isComplexType(t.scalar_type()),
555 "Expected to find a complex type, but tensor has type %s",
556 torch::executor::toString(t.scalar_type()));
557
558 return true;
559 }
560
tensor_is_bits_type(executorch::aten::Tensor t)561 inline bool tensor_is_bits_type(executorch::aten::Tensor t) {
562 ET_LOG_MSG_AND_RETURN_IF_FALSE(
563 torch::executor::isBitsType(t.scalar_type()),
564 "Expected to find a bits type, but tensor has type %s",
565 torch::executor::toString(t.scalar_type()));
566
567 return true;
568 }
569
tensors_have_same_dtype(executorch::aten::Tensor a,executorch::aten::Tensor b)570 inline bool tensors_have_same_dtype(
571 executorch::aten::Tensor a,
572 executorch::aten::Tensor b) {
573 ET_LOG_MSG_AND_RETURN_IF_FALSE(
574 a.scalar_type() == b.scalar_type(),
575 ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s}",
576 torch::executor::toString(a.scalar_type()),
577 torch::executor::toString(b.scalar_type()));
578 return true;
579 }
580
tensors_have_same_dtype(executorch::aten::Tensor a,executorch::aten::Tensor b,executorch::aten::Tensor c)581 inline bool tensors_have_same_dtype(
582 executorch::aten::Tensor a,
583 executorch::aten::Tensor b,
584 executorch::aten::Tensor c) {
585 ET_LOG_MSG_AND_RETURN_IF_FALSE(
586 a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(),
587 ET_TENSOR_CHECK_PREFIX__ ": dtype={%s, %s, %s}",
588 torch::executor::toString(a.scalar_type()),
589 torch::executor::toString(b.scalar_type()),
590 torch::executor::toString(c.scalar_type()));
591 return true;
592 }
593
tensor_is_rank(executorch::aten::Tensor t,size_t rank)594 inline bool tensor_is_rank(executorch::aten::Tensor t, size_t rank) {
595 ET_LOG_MSG_AND_RETURN_IF_FALSE(
596 t.dim() == rank,
597 "Expected tensor.dim() to be %zu, but got %zu",
598 static_cast<size_t>(rank),
599 static_cast<size_t>(t.dim()));
600
601 return true;
602 }
603
tensor_has_rank_greater_or_equal_to(executorch::aten::Tensor t,size_t rank)604 inline bool tensor_has_rank_greater_or_equal_to(
605 executorch::aten::Tensor t,
606 size_t rank) {
607 ET_LOG_MSG_AND_RETURN_IF_FALSE(
608 t.dim() >= rank,
609 "Expected tensor.dim() to be >= %zu, but got %zu",
610 static_cast<size_t>(rank),
611 static_cast<size_t>(t.dim()));
612
613 return true;
614 }
615
tensor_has_rank_smaller_or_equal_to(executorch::aten::Tensor t,size_t rank)616 inline bool tensor_has_rank_smaller_or_equal_to(
617 executorch::aten::Tensor t,
618 size_t rank) {
619 ET_LOG_MSG_AND_RETURN_IF_FALSE(
620 t.dim() <= rank,
621 "Expected tensor.dim() to be <= %zu, but got %zu",
622 static_cast<size_t>(rank),
623 static_cast<size_t>(t.dim()));
624
625 return true;
626 }
627
tensor_has_dim(executorch::aten::Tensor t,int64_t d)628 inline bool tensor_has_dim(executorch::aten::Tensor t, int64_t d) {
629 if (t.dim() == 0) {
630 ET_LOG_MSG_AND_RETURN_IF_FALSE(
631 d == 0 || d == -1,
632 "dim must be 0 or -1 for 0-dim tensor, got %" PRId64,
633 d);
634 } else {
635 ET_LOG_MSG_AND_RETURN_IF_FALSE(
636 d > 0 ? d < t.dim() : t.dim() + d >= 0,
637 "%zu-dim tensor does not have dim at index %zu",
638 static_cast<size_t>(t.dim()),
639 static_cast<size_t>(d));
640 }
641 return true;
642 }
643
tensor_has_non_empty_dim(executorch::aten::Tensor t,int64_t d)644 inline bool tensor_has_non_empty_dim(executorch::aten::Tensor t, int64_t d) {
645 const size_t udim = ET_NORMALIZE_IX(d, t.dim());
646 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(t, d));
647 ET_LOG_AND_RETURN_IF_FALSE(t.size(udim) != 0);
648 return true;
649 }
650
651 inline bool
tensor_dim_has_index(executorch::aten::Tensor t,int64_t d,int64_t ix)652 tensor_dim_has_index(executorch::aten::Tensor t, int64_t d, int64_t ix) {
653 // Indexing ops don't support zero-dim tensors
654 ET_CHECK(t.dim() != 0);
655 if (d < 0) {
656 d += t.dim();
657 }
658 // Dimension must have been already checked by tensor_has_dim
659 ET_CHECK(d >= 0 && d < t.dim());
660
661 ET_LOG_MSG_AND_RETURN_IF_FALSE(
662 ix >= -t.size(d) && ix < t.size(d),
663 "index %" PRId64 " out of range [-%zu,%zu) at dimension %" PRId64 ")",
664 ix,
665 static_cast<size_t>(t.size(d)),
666 static_cast<size_t>(t.size(d)),
667 d);
668 return true;
669 }
670
tensors_have_same_size_at_dims(executorch::aten::Tensor a,size_t dim_a,executorch::aten::Tensor b,size_t dim_b)671 inline bool tensors_have_same_size_at_dims(
672 executorch::aten::Tensor a,
673 size_t dim_a,
674 executorch::aten::Tensor b,
675 size_t dim_b) {
676 ET_LOG_MSG_AND_RETURN_IF_FALSE(
677 dim_a < a.dim(),
678 "Cannot retrieve dim %zu from tensor with dim %zu",
679 static_cast<size_t>(dim_a),
680 static_cast<size_t>(a.dim()));
681 ET_LOG_MSG_AND_RETURN_IF_FALSE(
682 dim_b < b.dim(),
683 "Cannot retrieve dim %zu from tensor with dim %zu",
684 static_cast<size_t>(dim_b),
685 static_cast<size_t>(b.dim()));
686 ET_LOG_MSG_AND_RETURN_IF_FALSE(
687 a.size(dim_a) == b.size(dim_b),
688 ET_TENSOR_CHECK_PREFIX__
689 ": a.size(%zu) = %zu does not match b.size(%zu) = %zu",
690 static_cast<size_t>(dim_a),
691 static_cast<size_t>(a.size(dim_a)),
692 static_cast<size_t>(dim_b),
693 static_cast<size_t>(b.size(dim_b)));
694
695 return true;
696 }
697
tensors_have_same_shape(executorch::aten::Tensor a,executorch::aten::Tensor b)698 inline bool tensors_have_same_shape(
699 executorch::aten::Tensor a,
700 executorch::aten::Tensor b) {
701 if (a.numel() == 1 && b.numel() == 1) {
702 // PyTorch operators treat all scalar tensors as the same shape even if
703 // they have different dims.
704 return true;
705 }
706 if (!(a.sizes() == b.sizes() && a.numel() == b.numel())) {
707 ET_LOG(
708 Error,
709 ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu), dim=(%zu, %zu)",
710 static_cast<size_t>(a.numel()),
711 static_cast<size_t>(b.numel()),
712 static_cast<size_t>(a.dim()),
713 static_cast<size_t>(b.dim()));
714 for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
715 ET_LOG(
716 Error,
717 " size(%zu): (%zu, %zu)",
718 static_cast<size_t>(d),
719 static_cast<size_t>(a.size(d)),
720 static_cast<size_t>(b.size(d)));
721 }
722
723 return false;
724 }
725
726 return true;
727 }
728
tensors_have_same_shape(executorch::aten::Tensor a,executorch::aten::Tensor b,executorch::aten::Tensor c)729 inline bool tensors_have_same_shape(
730 executorch::aten::Tensor a,
731 executorch::aten::Tensor b,
732 executorch::aten::Tensor c) {
733 if (a.numel() == 1 && b.numel() == 1 && c.numel() == 1) {
734 // PyTorch operators treat all scalar tensors as the same shape even if
735 // they have different dims.
736 return true;
737 }
738 bool cond1 = (a.sizes() == b.sizes()) && (a.numel() == b.numel());
739 bool cond2 = (b.sizes() == c.sizes()) && (b.numel() == c.numel());
740
741 if (!(cond1 && cond2)) {
742 ET_LOG(
743 Error,
744 ET_TENSOR_CHECK_PREFIX__ ": numel=(%zu, %zu, %zu), dim=(%zu, %zu, %zu)",
745 static_cast<size_t>(a.numel()),
746 static_cast<size_t>(b.numel()),
747 static_cast<size_t>(c.numel()),
748 static_cast<size_t>(a.dim()),
749 static_cast<size_t>(b.dim()),
750 static_cast<size_t>(c.dim()));
751 for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
752 ET_LOG(
753 Error,
754 " size(%zu): (%zu, %zu, %zu)",
755 static_cast<size_t>(d),
756 static_cast<size_t>(a.size(d)),
757 static_cast<size_t>(b.size(d)),
758 static_cast<size_t>(c.size(d)));
759 }
760
761 return false;
762 }
763
764 return true;
765 }
766
tensors_have_same_shape_and_dtype(executorch::aten::Tensor a,executorch::aten::Tensor b)767 inline bool tensors_have_same_shape_and_dtype(
768 executorch::aten::Tensor a,
769 executorch::aten::Tensor b) {
770 return tensors_have_same_shape(a, b) && tensors_have_same_dtype(a, b);
771 }
772
tensors_have_same_shape_and_dtype(executorch::aten::Tensor a,executorch::aten::Tensor b,executorch::aten::Tensor c)773 inline bool tensors_have_same_shape_and_dtype(
774 executorch::aten::Tensor a,
775 executorch::aten::Tensor b,
776 executorch::aten::Tensor c) {
777 return tensors_have_same_shape(a, b, c) && tensors_have_same_dtype(a, b, c);
778 }
779
tensor_has_expected_size(executorch::aten::Tensor a,executorch::aten::ArrayRef<executorch::aten::SizesType> expected_sizes)780 inline bool tensor_has_expected_size(
781 executorch::aten::Tensor a,
782 executorch::aten::ArrayRef<executorch::aten::SizesType> expected_sizes) {
783 if (!(a.sizes() == expected_sizes)) {
784 ET_LOG(
785 Error,
786 ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
787 static_cast<size_t>(a.dim()),
788 static_cast<size_t>(expected_sizes.size()));
789 size_t a_dim = static_cast<size_t>(a.dim());
790 size_t expected_dim = static_cast<size_t>(expected_sizes.size());
791 for (size_t d = 0; d < ET_MIN2(a_dim, expected_dim); ++d) {
792 ET_LOG(
793 Error,
794 " size(%zu): (%zu, %zu)",
795 static_cast<size_t>(d),
796 static_cast<size_t>(a.size(d)),
797 static_cast<size_t>(expected_sizes[d]));
798 }
799
800 return false;
801 }
802 return true;
803 }
804
tensors_have_same_strides(executorch::aten::Tensor a,executorch::aten::Tensor b)805 inline bool tensors_have_same_strides(
806 executorch::aten::Tensor a,
807 executorch::aten::Tensor b) {
808 if (a.strides() != b.strides()) {
809 ET_LOG(
810 Error,
811 ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu)",
812 static_cast<size_t>(a.dim()),
813 static_cast<size_t>(b.dim()));
814 for (size_t d = 0; d < ET_MIN2(a.dim(), b.dim()); ++d) {
815 ET_LOG(
816 Error,
817 " stride(%zu): (%zu, %zu)",
818 static_cast<size_t>(d),
819 static_cast<size_t>(a.strides()[d]),
820 static_cast<size_t>(b.strides()[d]));
821 }
822
823 return false;
824 }
825 return true;
826 }
827
tensors_have_same_strides(executorch::aten::Tensor a,executorch::aten::Tensor b,executorch::aten::Tensor c)828 inline bool tensors_have_same_strides(
829 executorch::aten::Tensor a,
830 executorch::aten::Tensor b,
831 executorch::aten::Tensor c) {
832 if (!(a.strides() == b.strides() && b.strides() == c.strides())) {
833 ET_LOG(
834 Error,
835 ET_TENSOR_CHECK_PREFIX__ ": dim=(%zu, %zu, %zu)",
836 static_cast<size_t>(a.dim()),
837 static_cast<size_t>(b.dim()),
838 static_cast<size_t>(c.dim()));
839 for (size_t d = 0; d < ET_MIN3(a.dim(), b.dim(), c.dim()); ++d) {
840 ET_LOG(
841 Error,
842 " stride(%zu): (%zu, %zu, %zu)",
843 static_cast<size_t>(d),
844 static_cast<size_t>(a.strides()[d]),
845 static_cast<size_t>(b.strides()[d]),
846 static_cast<size_t>(c.strides()[d]));
847 }
848
849 return false;
850 }
851 return true;
852 }
853
tensor_is_contiguous(executorch::aten::Tensor t)854 inline bool tensor_is_contiguous(executorch::aten::Tensor t) {
855 const auto strides = t.strides();
856 const auto sizes = t.sizes();
857 // If tensor is 0-dim (i.e. a scalar tensor) it is contiguous
858 if (strides.size() == 0) {
859 return true;
860 }
861 ET_LOG_MSG_AND_RETURN_IF_FALSE(
862 strides[strides.size() - 1] == 1,
863 "Tensor is not contiguous; the stride of the last dimension must be 1, "
864 "but got %zu",
865 static_cast<size_t>(strides[strides.size() - 1]));
866 for (int i = strides.size() - 1; i > 0; --i) {
867 ET_LOG_MSG_AND_RETURN_IF_FALSE(
868 strides[i - 1] == strides[i] * sizes[i],
869 "Tensor is not contiguous; the stride of dim %zu should be equal to "
870 "strides[%zu] * sizes[%zu] = %zu, but found %zu",
871 static_cast<size_t>(i - 1),
872 static_cast<size_t>(i),
873 static_cast<size_t>(i),
874 static_cast<size_t>(strides[i] * sizes[i]),
875 static_cast<size_t>(strides[i - 1]));
876 }
877 return true;
878 }
879
tensors_have_same_rank(executorch::aten::Tensor a,executorch::aten::Tensor b)880 inline bool tensors_have_same_rank(
881 executorch::aten::Tensor a,
882 executorch::aten::Tensor b) {
883 ET_LOG_MSG_AND_RETURN_IF_FALSE(
884 a.dim() == b.dim(),
885 ET_TENSOR_CHECK_PREFIX__ ": rank={%zd, %zd}",
886 ssize_t(a.dim()),
887 ssize_t(b.dim()));
888 return true;
889 }
890
tensor_is_scalar(executorch::aten::Tensor t)891 inline bool tensor_is_scalar(executorch::aten::Tensor t) {
892 return t.dim() == 0 && t.numel() == 1;
893 }
894
/**
 * Upper bound on the number of tensor dimensions used for fixed-size scratch
 * storage.
 *
 * When an operator supports both broadcasting and dynamic shapes, the
 * expected output size may not match the existing size of any input or
 * output, so such operators need extra space to store the computed output
 * size. Dynamic allocation is troublesome in executorch, so a small static
 * limit is hard-coded instead, on the assumption that users don't create
 * high dimensional tensors.
 */
constexpr size_t kTensorDimensionLimit{16};
904
905 /// Returns the product of dim[0:dim), not including dim.
getLeadingDims(const executorch::aten::Tensor & tensor,int64_t dim)906 inline size_t getLeadingDims(
907 const executorch::aten::Tensor& tensor,
908 int64_t dim) {
909 ET_CHECK_MSG(
910 dim >= 0 && dim <= tensor.dim(),
911 "Ending dimension %" PRId64
912 " should be in the range [0, tensor.dim() %zd].",
913 dim,
914 ssize_t(tensor.dim()));
915 size_t dims = 1;
916 for (size_t i = 0; i < dim; ++i) {
917 dims *= static_cast<size_t>(tensor.size(i));
918 }
919 return dims;
920 }
921
922 /// Returns the product of dim[dim+1:].
getTrailingDims(const executorch::aten::Tensor & tensor,int64_t dim)923 inline size_t getTrailingDims(
924 const executorch::aten::Tensor& tensor,
925 int64_t dim) {
926 ET_CHECK_MSG(
927 dim >= -1 && dim < tensor.dim(),
928 "Starting dimension %" PRId64
929 " should be in the range [-1, tensor.dim() -1 %zd).",
930 dim,
931 ssize_t(tensor.dim()));
932 size_t dims = 1;
933 for (size_t i = dim + 1; i < tensor.dim(); ++i) {
934 dims *= static_cast<size_t>(tensor.size(i));
935 }
936 return dims;
937 }
938
939 /**
940 * Given a N-dimensional tensor coordinate, return a linear index that can be
941 * used to access the corresponding element in the tensor's data buffer.
942 *
943 * @param[in] tensor The tensor that will be indexed
944 * @param[in] coordinate A n-dimensional array representing the coordinate to
945 * index. It is assumed that the array has kTensorDimensionLimit elements.
946 * @param[out] index The linear index to element at the specified coordinate
947 * in the tensor.
948 */
coordinateToIndex(const executorch::aten::Tensor & tensor,const size_t * const coordinate)949 inline size_t coordinateToIndex(
950 const executorch::aten::Tensor& tensor,
951 const size_t* const coordinate) {
952 size_t index = 0;
953 for (int d = 0; d < tensor.dim(); ++d) {
954 index += coordinate[d] * getTrailingDims(tensor, d);
955 }
956 return index;
957 }
958
959 /**
960 * Produce a memoized array for use with repeated calls to
961 * coordinateToIndexWithTrailingDimsMemo, which will be faster than
962 * repeated calls to coordinateToIndex.
963 */
memoizeTrailingDims(const executorch::aten::Tensor & tensor,size_t trailing_dims_memo[kTensorDimensionLimit])964 inline void memoizeTrailingDims(
965 const executorch::aten::Tensor& tensor,
966 size_t trailing_dims_memo[kTensorDimensionLimit]) {
967 const auto tensorDim = tensor.dim();
968 size_t dims = 1;
969 for (int ii = tensorDim - 1; ii >= 0; --ii) {
970 trailing_dims_memo[ii] = dims;
971 dims *= static_cast<size_t>(tensor.size(ii));
972 }
973 }
974
975 /**
976 * Like coordinateToIndex, but faster for repeated calls with the same
977 * tensor. trailing_dims_memo must be produced by a call to
978 * memoizeTrailingDims.
979 */
coordinateToIndexWithTrailingDimsMemo(const executorch::aten::Tensor & tensor,const size_t * const coordinate,const size_t trailing_dims_memo[kTensorDimensionLimit])980 inline size_t coordinateToIndexWithTrailingDimsMemo(
981 const executorch::aten::Tensor& tensor,
982 const size_t* const coordinate,
983 const size_t trailing_dims_memo[kTensorDimensionLimit]) {
984 size_t index = 0;
985 for (int d = 0; d < tensor.dim(); ++d) {
986 index += coordinate[d] * trailing_dims_memo[d];
987 }
988 return index;
989 }
990
991 /**
992 * Given the linear index return the N-dimensional tensor coordinate. This is
993 * the inverse operation of coordinateToIndex.
994 *
995 * @param[in] tensor The tensor that will be indexed
996 * @param[in] index The linear index to element at the specified coordinate in
997 * the tensor.
998 * @param[out] coordinate A n-dimensional array representing the coordinate to
999 * index. It is assumed that the array has kTensorDimensionLimit elements.
1000 * @returns void
1001 */
indexToCoordinate(const executorch::aten::Tensor & tensor,size_t index,size_t * coordinate)1002 inline void indexToCoordinate(
1003 const executorch::aten::Tensor& tensor,
1004 size_t index,
1005 size_t* coordinate) {
1006 ET_CHECK(index < tensor.numel());
1007 for (auto i = 0; i < tensor.dim(); ++i) {
1008 auto dim = tensor.dim() - 1 - i;
1009 size_t dim_size = tensor.size(dim);
1010 coordinate[dim] = index % dim_size;
1011 index /= dim_size;
1012 }
1013 }
1014
1015 /**
1016 * Extracts an integer value from a scalar Tensor.
1017 *
1018 * @param[in] tensor The source of the value to extract.
1019 * @param[out] out_val The extracted value, on success.
1020 * @returns `true` if a value was extracted, and sets `*out_val` to that
1021 * value. `false` if a value could not be extracted: either it was not an
1022 * integer Scalar Tensor, or the value of that Scalar Tensor could not be
1023 * represented by INT_T.
1024 */
1025 template <
1026 typename INT_T,
1027 typename std::enable_if<
1028 std::is_integral<INT_T>::value && !std::is_same<INT_T, bool>::value,
1029 bool>::type = true>
extract_scalar_tensor(executorch::aten::Tensor tensor,INT_T * out_val)1030 bool extract_scalar_tensor(executorch::aten::Tensor tensor, INT_T* out_val) {
1031 if (tensor.numel() != 1) {
1032 return false;
1033 }
1034 #define CASE_INT_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE) \
1035 case executorch::aten::ScalarType::TENSOR_DTYPE: { \
1036 const TENSOR_CTYPE val = tensor.const_data_ptr<TENSOR_CTYPE>()[0]; \
1037 if (val < std::numeric_limits<INT_T>::lowest() || \
1038 val > std::numeric_limits<INT_T>::max()) { \
1039 return false; \
1040 } \
1041 *out_val = static_cast<INT_T>(val); \
1042 return true; \
1043 }
1044
1045 switch (tensor.scalar_type()) {
1046 ET_FORALL_INT_TYPES(CASE_INT_DTYPE);
1047 default:
1048 return false;
1049 }
1050 #undef CASE_INT_DTYPE
1051 }
1052
1053 /**
1054 * Extracts a floating point value from a scalar Tensor.
1055 *
1056 * @param[in] tensor The source of the value to extract.
1057 * @param[out] out_val The extracted value, on success.
1058 * @returns `true` if a value was extracted, and sets `*out_val` to that
1059 * value. `false` if a value could not be extracted: either it was not a
1060 * floating point Scalar Tensor, or the value of that Scalar Tensor could not
1061 * be represented by FLOAT_T.
1062 */
1063 template <
1064 typename FLOAT_T,
1065 typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
1066 type = true>
extract_scalar_tensor(executorch::aten::Tensor tensor,FLOAT_T * out_val)1067 bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) {
1068 if (tensor.numel() != 1) {
1069 return false;
1070 }
1071 #define CASE_REAL_DTYPE(TENSOR_CTYPE, TENSOR_DTYPE) \
1072 case executorch::aten::ScalarType::TENSOR_DTYPE: { \
1073 /* ET_FORALL_REAL_TYPES guarantees TENSOR_CTYPE is a real type. */ \
1074 double val = \
1075 static_cast<double>(tensor.const_data_ptr<TENSOR_CTYPE>()[0]); \
1076 if (std::isfinite(val) && \
1077 (val < std::numeric_limits<FLOAT_T>::lowest() || \
1078 val > std::numeric_limits<FLOAT_T>::max())) { \
1079 return false; \
1080 } \
1081 *out_val = static_cast<FLOAT_T>(val); \
1082 return true; \
1083 }
1084
1085 switch (tensor.scalar_type()) {
1086 ET_FORALL_REAL_TYPES(CASE_REAL_DTYPE);
1087 default:
1088 return false;
1089 }
1090 #undef CASE_REAL_DTYPE
1091 }
1092
1093 /**
1094 * Extracts a boolean value from a Scalar.
1095 *
1096 * @param[in] scalar The source of the value to extract.
1097 * @param[out] out_val The extracted value, on success.
1098 * @returns `true` if a value was extracted, and sets `*out_val` to that
1099 * value. `false` if a value could not be extracted, i.e. not a boolean
1100 */
1101 template <
1102 typename BOOL_T,
1103 typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
1104 true>
extract_scalar_tensor(executorch::aten::Tensor tensor,BOOL_T * out_val)1105 bool extract_scalar_tensor(executorch::aten::Tensor tensor, BOOL_T* out_val) {
1106 if (tensor.scalar_type() != executorch::aten::ScalarType::Bool) {
1107 return false;
1108 }
1109 if (tensor.numel() != 1) {
1110 return false;
1111 }
1112
1113 bool val = tensor.const_data_ptr<bool>()[0];
1114
1115 *out_val = static_cast<BOOL_T>(val);
1116
1117 return true;
1118 }
1119
/// These APIs should not be used outside of Executor.cpp.
///
/// Only declarations live here; the definitions are in another translation
/// unit. Every function that returns Error is marked ET_NODISCARD, so callers
/// must inspect the result.
namespace internal {
/**
 * Share t_src's data_ptr with t_dst.
 *
 * @param[in] t_dst The tensor that is made to share the data pointer.
 * @param[in] t_src The tensor whose data pointer is shared.
 * @returns An Error describing the outcome (presumably Error::Ok on success;
 *     confirm in the implementation).
 */
ET_NODISCARD Error share_tensor_data(
    const executorch::aten::Tensor& t_dst,
    const executorch::aten::Tensor& t_src);

/**
 * Copy t_src's data_ptr to t_dst.
 *
 * NOTE(review): it is not visible from this header whether the pointer or the
 * pointed-to bytes are copied; confirm in the implementation.
 *
 * @param[in] t_dst The destination tensor.
 * @param[in] t_src The source tensor.
 * @returns An Error describing the outcome.
 */
ET_NODISCARD Error copy_tensor_data(
    const executorch::aten::Tensor& t_dst,
    const executorch::aten::Tensor& t_src);

/**
 * Set the data_ptr of t to buffer.
 *
 * @param[in] t The tensor whose data pointer is set.
 * @param[in] buffer The memory the tensor should point at.
 * @param[in] buffer_size Size of `buffer` (presumably in bytes; confirm in
 *     the implementation).
 * @returns An Error describing the outcome.
 */
ET_NODISCARD Error set_tensor_data(
    const executorch::aten::Tensor& t,
    void* buffer,
    size_t buffer_size);

/**
 * Reset tensor's data_ptr, clear all the storage for at::Tensor.
 *
 * @param[in] tensor The tensor whose data pointer is reset.
 */
void reset_data_ptr(const executorch::aten::Tensor& tensor);

/**
 * Resize tensor impl.
 *
 * This is the function the public resize_tensor() overloads forward to.
 *
 * @param[in] impl The tensor implementation to resize.
 * @param[in] new_sizes The sizes to apply.
 * @returns An Error describing the outcome.
 */
ET_NODISCARD Error resize_tensor_impl(
    executorch::aten::TensorImpl* impl,
    executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes);

} // namespace internal
1157
1158 /**
1159 * Resize a tensor to new_sizes, rank must stay the same. Currently does not
1160 * expand the tensor if new size exceeds the current capacity. Currently
1161 * fails an ET_CHECK if the tensor cannot be resized.
1162 *
1163 * WARNING: Placeholder API until discussion around runtime context is
1164 * settled, will likely move to be a class method on a TensorResizer object
1165 * passed in through runtimeContext.
1166 */
resize_tensor(executorch::aten::Tensor t,executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes)1167 ET_NODISCARD inline Error resize_tensor(
1168 executorch::aten::Tensor t,
1169 executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes) {
1170 return internal::resize_tensor_impl(t.unsafeGetTensorImpl(), new_sizes);
1171 }
1172
1173 /**
1174 * Resize a tensor to new_sizes, rank must stay the same. Currently does not
1175 * expand the tensor if new size exceeds the current capacity. Currently
1176 * fails an ET_CHECK if the tensor cannot be resized.
1177 *
1178 * WARNING: Placeholder API until discussion around runtime context is
1179 * settled, will likely move to be a class method on a TensorResizer object
1180 * passed in through runtimeContext.
1181 */
1182 template <
1183 typename T,
1184 typename std::enable_if<
1185 !std::is_same<executorch::aten::SizesType, T>::value,
1186 int>::type = 0>
resize_tensor(executorch::aten::Tensor t,executorch::aten::ArrayRef<T> new_sizes)1187 ET_NODISCARD inline Error resize_tensor(
1188 executorch::aten::Tensor t,
1189 executorch::aten::ArrayRef<T> new_sizes) {
1190 // Need to cast the input array to an array of Tensor::SizesType
1191 std::array<executorch::aten::SizesType, kTensorDimensionLimit>
1192 new_sizes_casted{};
1193 size_t new_sizes_ndim = new_sizes.size();
1194 for (size_t i = 0; i < new_sizes_ndim; ++i) {
1195 new_sizes_casted[i] =
1196 static_cast<executorch::aten::SizesType>(new_sizes[i]);
1197 }
1198
1199 return internal::resize_tensor_impl(
1200 t.unsafeGetTensorImpl(), {new_sizes_casted.data(), new_sizes_ndim});
1201 }
1202
/// DEPRECATED: Use `resize_tensor()` instead, which can fail non-fatally.
///
/// Forwards to resize_tensor() and hands `err == Error::Ok` to ET_CHECK_MSG,
/// so a failed resize is dealt with by that check instead of being reported
/// to the caller (this function returns void).
ET_DEPRECATED inline void resize(
    executorch::aten::Tensor t,
    executorch::aten::ArrayRef<executorch::aten::SizesType> new_sizes) {
  Error err = resize_tensor(t, new_sizes);
  ET_CHECK_MSG(
      err == Error::Ok, "Could not resize Tensor; see logs for details");
}
/**
 * Get dim_order of a Tensor and write it to out_dim_order.
 *
 * @param tensor The tensor whose dim order is read.
 * @param out_dim_order Pointer to an array of DimOrderType that the dim order
 *     is written into.
 * @param out_dim_order_size Size of the DimOrderType array.
 * @returns An Error that must not be ignored (ET_NODISCARD). The failure
 *     conditions are decided by the implementation, which is not part of this
 *     header.
 */
ET_NODISCARD Error get_dim_order(
    const executorch::aten::Tensor& tensor,
    executorch::aten::DimOrderType* out_dim_order,
    size_t out_dim_order_size);
1222
/**
 * Checks whether a tensor has a valid dim order. If the dim order could not
 * be determined, then this function returns false by default.
 *
 * @param t The tensor to check.
 * @return true if the dim order is valid; false if it is invalid or could not
 *     be determined.
 */
bool tensor_has_valid_dim_order(executorch::aten::Tensor t);
1228
/**
 * Checks whether a tensor has either the default or the channels last dim
 * order. If the dim order could not be determined, then this function returns
 * false by default.
 *
 * @param t The tensor to check.
 * @return true if the tensor has one of the two dim orders; false otherwise,
 *     including when the dim order could not be determined.
 */
bool tensor_is_default_or_channels_last_dim_order(executorch::aten::Tensor t);
1235
/**
 * Checks whether a tensor has the default dimension order.
 * Logs an error message if the tensor does not meet the expected criteria.
 *
 * NOTE(review): "default" presumably means the dimensions are stored in their
 * natural order 0, 1, ..., N-1 — confirm against dim_order_util.h.
 *
 * @param t The tensor to check the dimension order of (taken by value).
 * @return True if the tensor has the default dimension order, false otherwise.
 */
bool tensor_is_default_dim_order(executorch::aten::Tensor t);
1244
/**
 * Checks whether a tensor has the channels last dimension order.
 * Logs an error message if the tensor does not meet the expected criteria.
 *
 * NOTE(review): the ranks for which a channels last order is recognised are
 * decided by the implementation — confirm against dim_order_util.h.
 *
 * @param t The tensor to check the dimension order of (taken by value).
 * @return True if the tensor has the channels last dimension order, false
 * otherwise.
 */
bool tensor_is_channels_last_dim_order(executorch::aten::Tensor t);
1254
/**
 * Checks whether all tensors in a list have the same dim_order.
 *
 * Note that this function only tests dim order, but not others like actual
 * data, sizes, etc.
 *
 * @param tensor_list The tensors to compare.
 * @return Presumably true when every tensor in the list shares one dim order
 *     and false otherwise — the definition is not part of this header;
 *     confirm there.
 */
bool tensors_have_same_dim_order(
    const executorch::aten::ArrayRef<executorch::aten::Tensor> tensor_list);
1264
1265 /**
1266 * Asserts that two tensors have the same dim_order
1267 *
1268 * Note that this macro only tests dim order, but not others like actual data,
1269 * sizes, etc.
1270 */
1271
tensors_have_same_dim_order(const executorch::aten::Tensor & a,const executorch::aten::Tensor & b)1272 inline bool tensors_have_same_dim_order(
1273 const executorch::aten::Tensor& a,
1274 const executorch::aten::Tensor& b) {
1275 executorch::aten::Tensor tensor_list[2] = {a, b};
1276 return tensors_have_same_dim_order(tensor_list);
1277 }
1278
1279 /**
1280 * Asserts that three tensors have the same dim_order
1281 *
1282 * Note that this macro only tests dim order, but not others like actual data,
1283 * sizes, etc.
1284 *
1285 */
1286
tensors_have_same_dim_order(const executorch::aten::Tensor & a,const executorch::aten::Tensor & b,const executorch::aten::Tensor & c)1287 inline bool tensors_have_same_dim_order(
1288 const executorch::aten::Tensor& a,
1289 const executorch::aten::Tensor& b,
1290 const executorch::aten::Tensor& c) {
1291 executorch::aten::Tensor tensor_list[3] = {a, b, c};
1292 return tensors_have_same_dim_order(tensor_list);
1293 }
1294
1295 /**
1296 * Asserts that four tensors have the same dim_order
1297 *
1298 * Note that this macro only tests dim order, but not others like actual data,
1299 * sizes, etc.
1300 *
1301 */
1302
tensors_have_same_dim_order(const executorch::aten::Tensor & a,const executorch::aten::Tensor & b,const executorch::aten::Tensor & c,const executorch::aten::Tensor & d)1303 inline bool tensors_have_same_dim_order(
1304 const executorch::aten::Tensor& a,
1305 const executorch::aten::Tensor& b,
1306 const executorch::aten::Tensor& c,
1307 const executorch::aten::Tensor& d) {
1308 executorch::aten::Tensor tensor_list[4] = {a, b, c, d};
1309 return tensors_have_same_dim_order(tensor_list);
1310 }
1311
1312 /**
1313 * Given an n-dimensional coordinate array and an array of tensor strides,
1314 * calculates the linear index that can be used to retrieve the value at the
1315 * given coordinates.
1316 * @param coordinate Pointer to the array of coordinates.
1317 * @param strides Pointer to the array of strides.
1318 * @param ndim Number of dimensions in the tensor.
1319 */
calculate_linear_index(const executorch::aten::SizesType * coordinate,const executorch::aten::StridesType * strides,const size_t ndim)1320 inline size_t calculate_linear_index(
1321 const executorch::aten::SizesType* coordinate,
1322 const executorch::aten::StridesType* strides,
1323 const size_t ndim) {
1324 size_t index = 0;
1325 for (size_t i = 0; i < ndim; i++) {
1326 index += coordinate[i] * strides[i];
1327 }
1328 return index;
1329 }
1330
1331 } // namespace runtime
1332 } // namespace executorch
1333
// Backward-compatibility aliases: re-exports the helpers from
// ::executorch::runtime (and ::executorch::runtime::internal) under the legacy
// ::torch::executor namespace so existing callers keep compiling.
//
// NOTE(review): tensor_is_channels_last_dim_order, memoizeTrailingDims and
// coordinateToIndexWithTrailingDimsMemo are declared in this header but have
// no alias below — confirm whether that is intentional.
namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::calculate_linear_index;
using ::executorch::runtime::coordinateToIndex;
using ::executorch::runtime::dim_is_valid;
using ::executorch::runtime::extract_scalar_tensor;
using ::executorch::runtime::get_dim_order;
using ::executorch::runtime::getLeadingDims;
using ::executorch::runtime::getTrailingDims;
using ::executorch::runtime::indexToCoordinate;
using ::executorch::runtime::kTensorDimensionLimit;
using ::executorch::runtime::nonempty_size;
using ::executorch::runtime::nonzero_dim;
using ::executorch::runtime::resize;
using ::executorch::runtime::resize_tensor;
using ::executorch::runtime::tensor_can_cast_to;
using ::executorch::runtime::tensor_dim_has_index;
using ::executorch::runtime::tensor_has_dim;
using ::executorch::runtime::tensor_has_expected_size;
using ::executorch::runtime::tensor_has_non_empty_dim;
using ::executorch::runtime::tensor_has_rank_greater_or_equal_to;
using ::executorch::runtime::tensor_has_rank_smaller_or_equal_to;
using ::executorch::runtime::tensor_has_valid_dim_order;
using ::executorch::runtime::tensor_is_bits_type;
using ::executorch::runtime::tensor_is_bool_type;
using ::executorch::runtime::tensor_is_complex_type;
using ::executorch::runtime::tensor_is_contiguous;
using ::executorch::runtime::tensor_is_default_dim_order;
using ::executorch::runtime::tensor_is_default_or_channels_last_dim_order;
using ::executorch::runtime::tensor_is_floating_type;
using ::executorch::runtime::tensor_is_integral_type;
using ::executorch::runtime::tensor_is_rank;
using ::executorch::runtime::tensor_is_real_type;
using ::executorch::runtime::tensor_is_realh_type;
using ::executorch::runtime::tensor_is_realhb_type;
using ::executorch::runtime::tensor_is_scalar;
using ::executorch::runtime::tensors_have_same_dim_order;
using ::executorch::runtime::tensors_have_same_dtype;
using ::executorch::runtime::tensors_have_same_rank;
using ::executorch::runtime::tensors_have_same_shape;
using ::executorch::runtime::tensors_have_same_shape_and_dtype;
using ::executorch::runtime::tensors_have_same_size_at_dims;
using ::executorch::runtime::tensors_have_same_strides;
// Aliases for the Executor.cpp-only helpers declared in
// ::executorch::runtime::internal.
namespace internal {
using ::executorch::runtime::internal::copy_tensor_data;
using ::executorch::runtime::internal::reset_data_ptr;
using ::executorch::runtime::internal::resize_tensor_impl;
using ::executorch::runtime::internal::set_tensor_data;
using ::executorch::runtime::internal::share_tensor_data;
} // namespace internal
} // namespace executor
} // namespace torch
1388