xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <cstdint>
#include <vector>

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
16 
17 namespace vkcompute {
18 
19 struct Kernel1dParams final {
20   int kernel_size;
21   int stride;
22   int padding;
23   int dilation;
24   int in_group_size;
25   int out_group_size;
26 };
27 
28 struct Kernel2dParams final {
29   utils::ivec2 kernel_size;
30   utils::ivec2 stride;
31   utils::ivec2 padding;
32   utils::ivec2 dilation;
33 };
34 
35 Kernel2dParams create_kernel2d_params(
36     ComputeGraph& graph,
37     const ValueRef weight,
38     const bool kernel_size_only,
39     const ValueRef stride,
40     const ValueRef padding,
41     const ValueRef dilation);
42 
43 Kernel2dParams create_kernel2d_params(
44     ComputeGraph& graph,
45     const ValueRef kernel_size,
46     const ValueRef stride,
47     const ValueRef padding);
48 
49 int64_t calc_out_size(
50     const int64_t in_size,
51     const int64_t kernel_size,
52     const int64_t stride,
53     const int64_t padding,
54     const int64_t dilation,
55     const bool ceil_mode);
56 
57 std::vector<int64_t> calc_out_sizes_hw(
58     ComputeGraph& graph,
59     const std::vector<int64_t>& in_sizes,
60     const ValueRef weight,
61     const bool kernel_size_only,
62     const std::vector<ValueRef>& args,
63     const bool transposed = false);
64 
65 } // namespace vkcompute
66