xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 
13 namespace vkcompute {
14 
15 //
16 // Tensor output size calculation functions
17 //
18 
19 std::vector<int64_t> calculate_broadcasted_output_size(
20     const api::vTensor& t1,
21     const api::vTensor& t2);
22 
23 //
24 // Tensor property checking functions
25 //
26 
27 bool check_ndim_is(const api::vTensor& t, size_t ndim);
28 
29 bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2);
30 
31 bool check_same_sizes_at(
32     const api::vTensor& t1,
33     int64_t d1,
34     const api::vTensor& t2,
35     int64_t d2);
36 
37 bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim);
38 
39 bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2);
40 
41 bool check_same_packed_dim(
42     const api::vTensor& t1,
43     const api::vTensor& t2,
44     const api::vTensor& t3);
45 
46 //
47 // Broadcast flag functions
48 //
49 
50 utils::ivec2 create_broadcast_params(
51     const api::vTensor& t1,
52     const api::vTensor& t2);
53 
54 //
55 // Work group size calculation functions
56 //
57 
58 utils::uvec3 adaptive_work_group_size(const utils::uvec3& global_work_group);
59 
//
// Tensor dim utilities
//

// Wrap a (possibly negative or out-of-range) NCHW dimension index into the
// range [0, ndim) when ndim > 0, e.g. normalize(-1, 4) == 3 and
// normalize(5, 4) == 1.
//
// The double modulo is required because C++ `%` truncates toward zero, so
// `nchw_dim % ndim` is negative whenever `nchw_dim` is negative.
//
// ndim == 0 (a 0-dim / scalar tensor) would otherwise be an integer division
// by zero, which is undefined behavior. Following the scalar case of PyTorch's
// c10::maybe_wrap_dim, which wraps against a single dimension, the result is
// 0 in that case. Out-of-range indices for scalars are not diagnosed here.
//
// T is restricted to signed integral types so negative indices can be passed.
template <
    typename T,
    typename std::enable_if<
        std::is_integral<T>::value && std::is_signed<T>::value,
        int>::type = 0>
T normalize(const T& nchw_dim, const int64_t ndim) {
  if (ndim == 0) {
    return 0;
  }
  // Evaluated in the common type of T and int64_t. For ndim > 0 the result is
  // in [0, ndim), so the narrowing back to T is made explicit.
  return static_cast<T>((nchw_dim % ndim + ndim) % ndim);
}
72 
// Convert an NCHW-ordered dimension index into the corresponding WHCN-ordered
// index, i.e. reverse the index within a tensor of `ndim` dimensions.
// For ndim == 4: 0 -> 3, 1 -> 2, 2 -> 1, 3 -> 0.
//
// Negative indices in [-ndim, -1] are first wrapped PyTorch-style
// (nchw_dim + ndim), so e.g. -1 (the innermost NCHW dim) maps to WHCN dim 0.
// Reversing an unwrapped negative index would yield a value >= ndim, which is
// never a valid dimension. Non-negative indices are reversed directly.
//
// T is restricted to signed integral types so negative indices can be passed.
template <
    typename T,
    typename std::enable_if<
        std::is_integral<T>::value && std::is_signed<T>::value,
        int>::type = 0>
T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) {
  const int64_t dim =
      nchw_dim < 0 ? nchw_dim + ndim : static_cast<int64_t>(nchw_dim);
  // The arithmetic happens in int64_t; narrow back to T explicitly.
  return static_cast<T>(ndim - 1 - dim);
}
81 
82 } // namespace vkcompute
83