xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16 
17 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18 
19 namespace vkcompute {
20 
check_and_prepack_arg(ComputeGraph & graph,ValueRef arg_ref,const utils::StorageType stype,int64_t num_channels,const std::string & debug_name)21 ValueRef check_and_prepack_arg(
22     ComputeGraph& graph,
23     ValueRef arg_ref,
24     const utils::StorageType stype,
25     int64_t num_channels,
26     const std::string& debug_name) {
27   VK_CHECK_COND(
28       graph.val_is_tref(arg_ref),
29       "native_batch_norm requires ",
30       debug_name,
31       " to be a constant tensorref");
32   VK_CHECK_COND(graph.get_tref(arg_ref)->sizes[0] == num_channels);
33 
34   // batch_norm's param are broadcasted on the channel dimension.
35   // In this implementation, we pack the weights along the x dimension, and
36   // in the shader, we lookup using the along the x.
37   return prepack_standard(graph, arg_ref, stype, utils::kWidthPacked);
38 }
39 
add_native_batch_norm_node(ComputeGraph & graph,ValueRef in_ref,ValueRef weight_ref,ValueRef bias_ref,ValueRef mean_ref,ValueRef var_ref,ValueRef eps_ref,ValueRef out_tuple_ref)40 void add_native_batch_norm_node(
41     ComputeGraph& graph,
42     ValueRef in_ref,
43     ValueRef weight_ref,
44     ValueRef bias_ref,
45     ValueRef mean_ref,
46     ValueRef var_ref,
47     ValueRef eps_ref,
48     ValueRef out_tuple_ref) {
49   std::vector<int64_t> in_sizes = graph.get_tensor(in_ref)->sizes();
50   std::vector<int64_t> out_sizes = graph.get_tensor(in_ref)->sizes();
51 
52   VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor");
53   VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor");
54 
55   // Only the first element of the return value is propagated. The remaining 2
56   // elements are zero-size dummy tensor.
57   ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
58 
59   utils::StorageType stype = graph.storage_type_of(out_ref);
60 
61   int64_t num_channels = dim_at<kChannel4D>(in_sizes);
62 
63   ValueRef arg_weight =
64       check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight");
65   ValueRef arg_bias =
66       check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias");
67   ValueRef arg_mean =
68       check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean");
69   ValueRef arg_var =
70       check_and_prepack_arg(graph, var_ref, stype, num_channels, "var");
71   float epsilon = graph.extract_scalar<float>(eps_ref);
72 
73   vTensorPtr t_in = graph.get_tensor(in_ref);
74 
75   VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref");
76   vTensorPtr t_out = graph.get_tensor(out_ref);
77 
78   VK_CHECK_COND(
79       dim_at<kChannel4D>(t_out->sizes()) == num_channels,
80       "out channel must match in channel");
81 
82   std::string kernel_name = "batchnorm";
83   add_dtype_suffix(kernel_name, *t_out);
84 
85   int32_t num_texel_per_batch =
86       utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));
87 
88   graph.execute_nodes().emplace_back(new DispatchNode(
89       graph,
90       VK_KERNEL_FROM_STR(kernel_name),
91       graph.create_global_wg_size(out_ref),
92       graph.create_local_wg_size(out_ref),
93       {{out_ref, vkapi::MemoryAccessType::WRITE},
94        {{in_ref, arg_weight, arg_bias, arg_mean, arg_var},
95         vkapi::MemoryAccessType::READ}},
96       {t_out->logical_limits_ubo(),
97        graph.create_params_buffer(epsilon),
98        graph.create_params_buffer(num_texel_per_batch)}));
99 }
100 
native_batch_norm(ComputeGraph & graph,const std::vector<ValueRef> & args)101 void native_batch_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
102   // args[5] is momentum. It is not used in the calculation.
103   return add_native_batch_norm_node(
104       graph, args[0], args[1], args[2], args[3], args[4], args[6], args[7]);
105 }
106 
// Register the inference-only ("no training") batch norm variant with the
// Vulkan operator registry, dispatching to native_batch_norm above.
REGISTER_OPERATORS {
  VK_REGISTER_OP(
      aten._native_batch_norm_legit_no_training.default, native_batch_norm);
}
111 
112 } // namespace vkcompute
113