1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12
13 namespace vkcompute {
14
15 //
16 // Tensor output size calculation functions
17 //
18
/**
 * Computes the sizes of the output tensor produced by combining `t1` and `t2`
 * with broadcasting.
 *
 * NOTE(review): the definition is not in this header. Presumably this follows
 * PyTorch/NumPy broadcasting rules (sizes aligned from the trailing dimension;
 * size-1 dims stretch), and the result is in NCHW order. Confirm against the
 * implementation, including how incompatible sizes are reported.
 *
 * @param t1 First input tensor.
 * @param t2 Second input tensor.
 * @return Broadcasted output sizes, one entry per output dimension.
 */
std::vector<int64_t> calculate_broadcasted_output_size(
    const api::vTensor& t1,
    const api::vTensor& t2);
22
23 //
24 // Tensor property checking functions
25 //
26
// NOTE(review): the definitions of the checks below are not in this header.
// The descriptions are inferred from the function names and signatures and
// should be confirmed against the implementation.

/// Returns whether `t` has exactly `ndim` dimensions (presumably).
bool check_ndim_is(const api::vTensor& t, size_t ndim);

/// Returns whether `t1` and `t2` have the same number of dimensions
/// (presumably).
bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2);

/// Returns whether the size of `t1` along dim `d1` equals the size of `t2`
/// along dim `d2` (presumably).
/// NOTE(review): whether `d1`/`d2` may be negative, and whether they are
/// NCHW- or WHCN-ordered indices, is not visible here. Confirm.
bool check_same_sizes_at(
    const api::vTensor& t1,
    int64_t d1,
    const api::vTensor& t2,
    int64_t d2);

/// Returns whether the packed dimension of `t` equals `packed_dim`
/// (presumably).
/// NOTE(review): confirm which dim-numbering convention `packed_dim` uses.
bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim);

/// Returns whether `t1` and `t2` share the same packed dimension (presumably).
bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2);

/// Three-tensor overload: returns whether `t1`, `t2` and `t3` all share the
/// same packed dimension (presumably).
bool check_same_packed_dim(
    const api::vTensor& t1,
    const api::vTensor& t2,
    const api::vTensor& t3);
45
46 //
47 // Broadcast flag functions
48 //
49
/**
 * Builds broadcast flags for a binary operation between `t1` and `t2`,
 * packed into a two-component integer vector.
 *
 * NOTE(review): the definition is not in this header. The meaning of each
 * `ivec2` component (e.g. which dimension it describes, and whether it refers
 * to `t1` or `t2`) must be confirmed against the implementation and the
 * consuming shader.
 *
 * @param t1 First input tensor.
 * @param t2 Second input tensor.
 * @return Two integer broadcast flags.
 */
utils::ivec2 create_broadcast_params(
    const api::vTensor& t1,
    const api::vTensor& t2);
53
54 //
55 // Work group size calculation functions
56 //
57
/**
 * Derives a work group size from the given global work group extents.
 *
 * NOTE(review): the definition is not in this header. Presumably the result is
 * used as the local work group size for a compute dispatch and is adapted to
 * the shape of `global_work_group`. Confirm the exact heuristic against the
 * implementation.
 *
 * @param global_work_group Global work group extents (x, y, z).
 * @return Work group size (x, y, z).
 */
utils::uvec3 adaptive_work_group_size(const utils::uvec3& global_work_group);
59
60 //
61 // Tensor dim utilities
62 //
63
/**
 * Wraps a (possibly negative or oversized) dimension index into [0, ndim)
 * for ndim > 0. For example, normalize(-1, 4) == 3 and normalize(5, 4) == 1.
 *
 * Built-in `%` truncates toward zero, so the first remainder can be negative.
 * Adding `ndim` and taking the remainder a second time folds it back into
 * the non-negative range.
 *
 * NOTE(review): `ndim == 0` divides by zero; callers must pass ndim > 0.
 *
 * @param nchw_dim Dimension index to wrap.
 * @param ndim Number of dimensions of the tensor.
 * @return The equivalent index in [0, ndim).
 */
template <
    typename T,
    typename std::enable_if<
        std::is_integral<T>::value && std::is_signed<T>::value,
        int>::type = 0>
T normalize(const T& nchw_dim, const int64_t ndim) {
  // Same sign as nchw_dim, magnitude < ndim.
  const auto truncated_rem = nchw_dim % ndim;
  return (truncated_rem + ndim) % ndim;
}
72
/**
 * Converts a dimension index given in NCHW order into the equivalent index in
 * WHCN order, i.e. reverses the index within a tensor of `ndim` dimensions.
 * For ndim == 4, NCHW dims 0/1/2/3 map to WHCN dims 3/2/1/0.
 *
 * Negative indices use the Python/PyTorch convention (-1 is the last NCHW
 * dim) and are wrapped by adding `ndim` before the reversal, so
 * nchw_dim_to_whcn_dim(-1, 4) == 0. Previously a negative index produced an
 * out-of-range result (e.g. -1 -> ndim).
 *
 * Non-negative indices are used unchanged. No modulo is taken, so there is no
 * division by `ndim` (ndim == 0 is safe), and indices outside [-ndim, ndim)
 * still yield out-of-range results rather than aliasing a valid dimension.
 *
 * @param nchw_dim Dimension index in NCHW order; may be negative.
 * @param ndim Number of dimensions of the tensor.
 * @return The corresponding dimension index in WHCN order.
 */
template <
    typename T,
    typename std::enable_if<
        std::is_integral<T>::value && std::is_signed<T>::value,
        int>::type = 0>
T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) {
  // Wrap negative (from-the-end) indices first. Non-negative indices are
  // passed through, which preserves the original mapping for them.
  const int64_t dim = nchw_dim < 0 ? static_cast<int64_t>(nchw_dim) + ndim
                                   : static_cast<int64_t>(nchw_dim);
  return static_cast<T>(ndim - 1 - dim);
}
81
82 } // namespace vkcompute
83