1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16
17 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18
19 namespace vkcompute {
20
check_and_prepack_arg(ComputeGraph & graph,ValueRef arg_ref,const utils::StorageType stype,int64_t num_channels,const std::string & debug_name)21 ValueRef check_and_prepack_arg(
22 ComputeGraph& graph,
23 ValueRef arg_ref,
24 const utils::StorageType stype,
25 int64_t num_channels,
26 const std::string& debug_name) {
27 VK_CHECK_COND(
28 graph.val_is_tref(arg_ref),
29 "native_batch_norm requires ",
30 debug_name,
31 " to be a constant tensorref");
32 VK_CHECK_COND(graph.get_tref(arg_ref)->sizes[0] == num_channels);
33
34 // batch_norm's param are broadcasted on the channel dimension.
35 // In this implementation, we pack the weights along the x dimension, and
36 // in the shader, we lookup using the along the x.
37 return prepack_standard(graph, arg_ref, stype, utils::kWidthPacked);
38 }
39
add_native_batch_norm_node(ComputeGraph & graph,ValueRef in_ref,ValueRef weight_ref,ValueRef bias_ref,ValueRef mean_ref,ValueRef var_ref,ValueRef eps_ref,ValueRef out_tuple_ref)40 void add_native_batch_norm_node(
41 ComputeGraph& graph,
42 ValueRef in_ref,
43 ValueRef weight_ref,
44 ValueRef bias_ref,
45 ValueRef mean_ref,
46 ValueRef var_ref,
47 ValueRef eps_ref,
48 ValueRef out_tuple_ref) {
49 std::vector<int64_t> in_sizes = graph.get_tensor(in_ref)->sizes();
50 std::vector<int64_t> out_sizes = graph.get_tensor(in_ref)->sizes();
51
52 VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor");
53 VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor");
54
55 // Only the first element of the return value is propagated. The remaining 2
56 // elements are zero-size dummy tensor.
57 ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
58
59 utils::StorageType stype = graph.storage_type_of(out_ref);
60
61 int64_t num_channels = dim_at<kChannel4D>(in_sizes);
62
63 ValueRef arg_weight =
64 check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight");
65 ValueRef arg_bias =
66 check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias");
67 ValueRef arg_mean =
68 check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean");
69 ValueRef arg_var =
70 check_and_prepack_arg(graph, var_ref, stype, num_channels, "var");
71 float epsilon = graph.extract_scalar<float>(eps_ref);
72
73 vTensorPtr t_in = graph.get_tensor(in_ref);
74
75 VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref");
76 vTensorPtr t_out = graph.get_tensor(out_ref);
77
78 VK_CHECK_COND(
79 dim_at<kChannel4D>(t_out->sizes()) == num_channels,
80 "out channel must match in channel");
81
82 std::string kernel_name = "batchnorm";
83 add_dtype_suffix(kernel_name, *t_out);
84
85 int32_t num_texel_per_batch =
86 utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));
87
88 graph.execute_nodes().emplace_back(new DispatchNode(
89 graph,
90 VK_KERNEL_FROM_STR(kernel_name),
91 graph.create_global_wg_size(out_ref),
92 graph.create_local_wg_size(out_ref),
93 {{out_ref, vkapi::MemoryAccessType::WRITE},
94 {{in_ref, arg_weight, arg_bias, arg_mean, arg_var},
95 vkapi::MemoryAccessType::READ}},
96 {t_out->logical_limits_ubo(),
97 graph.create_params_buffer(epsilon),
98 graph.create_params_buffer(num_texel_per_batch)}));
99 }
100
native_batch_norm(ComputeGraph & graph,const std::vector<ValueRef> & args)101 void native_batch_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
102 // args[5] is momentum. It is not used in the calculation.
103 return add_native_batch_norm_node(
104 graph, args[0], args[1], args[2], args[3], args[4], args[6], args[7]);
105 }
106
// Registers the inference-only (no-training) batch norm variant with the
// Vulkan operator registry.
REGISTER_OPERATORS {
  VK_REGISTER_OP(
      aten._native_batch_norm_legit_no_training.default, native_batch_norm);
}
111
112 } // namespace vkcompute
113