1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/api/api.h> 12 13 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> 14 15 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h> 16 17 namespace vkcompute { 18 19 struct Kernel1dParams final { 20 int kernel_size; 21 int stride; 22 int padding; 23 int dilation; 24 int in_group_size; 25 int out_group_size; 26 }; 27 28 struct Kernel2dParams final { 29 utils::ivec2 kernel_size; 30 utils::ivec2 stride; 31 utils::ivec2 padding; 32 utils::ivec2 dilation; 33 }; 34 35 Kernel2dParams create_kernel2d_params( 36 ComputeGraph& graph, 37 const ValueRef weight, 38 const bool kernel_size_only, 39 const ValueRef stride, 40 const ValueRef padding, 41 const ValueRef dilation); 42 43 Kernel2dParams create_kernel2d_params( 44 ComputeGraph& graph, 45 const ValueRef kernel_size, 46 const ValueRef stride, 47 const ValueRef padding); 48 49 int64_t calc_out_size( 50 const int64_t in_size, 51 const int64_t kernel_size, 52 const int64_t stride, 53 const int64_t padding, 54 const int64_t dilation, 55 const bool ceil_mode); 56 57 std::vector<int64_t> calc_out_sizes_hw( 58 ComputeGraph& graph, 59 const std::vector<int64_t>& in_sizes, 60 const ValueRef weight, 61 const bool kernel_size_only, 62 const std::vector<ValueRef>& args, 63 const bool transposed = false); 64 65 } // namespace vkcompute 66