xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

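// Validate the inputs to the 8-bit weight-quantized linear op: mat1 must be
// 2D or 3D, the quantized weight [N, K] must be 2D, and the per-output-channel
// scales must be 1D with N entries. The reduction (K) dims of mat1 and the
// weight must match, and mat1 and out must share the same packed dim.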
void check_q_8w_linear_args(
    const ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef qmat2_data,
    const ValueRef scales,
    const ValueRef out) {
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> qmat2_sizes = graph.sizes_of(qmat2_data);
  std::vector<int64_t> scales_sizes = graph.sizes_of(scales);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(qmat2_sizes.size() == 2);
  VK_CHECK_COND(scales_sizes.size() == 1);

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(
      utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes));
  VK_CHECK_COND(
      utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes));
}

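// Propagate input shapes to the output tensor so that dynamic shapes are
// supported: for mat1 [B, M, K] and weight [N, K], the output is resized to
// [B, M, N] (or [M, N] in the 2D case).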
void resize_q_8w_linear_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]);

  const int out_cols = utils::val_at(-2, mat1->sizes());
  const int out_rows = utils::val_at(-2, qmat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_cols;
    new_out_sizes.at(1) = out_rows;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_cols;
    new_out_sizes.at(2) = out_rows;
  }

  out->virtual_resize(new_out_sizes);
}

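// Add a graph node that dispatches the reference q_8w_linear shader. For
// texture storage, mat1 and out are transitioned through width-packed
// temporaries via aten.view_copy so that the shader only has to handle
// width-packed layouts.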
void add_q_8w_linear_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef q_mat2_data,
    const ValueRef scales_data,
    const ValueRef out) {
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  ValueRef mat1_W_packed = mat1;
  ValueRef out_W_packed = out;
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
    // Ensure mat1 is width packed
    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
    // Ensure out is packed correctly
    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
  }
  ValueRef q_mat2 = prepack_standard(
      graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
  ValueRef scales = prepack_standard(
      graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);

  std::string kernel_name = "q_8w_linear";
  kernel_name.reserve(kShaderNameReserve);
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

  vkapi::ParamsBindList ubos({});
  if (graph.is_buffer_storage(out_W_packed)) {
    ubos.append(
        {graph.sizes_ubo(out_W_packed),
         graph.strides_ubo(out_W_packed),
         graph.numel_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed),
         graph.strides_ubo(mat1),
         graph.strides_ubo(q_mat2),
         graph.strides_ubo(scales)});
  } else {
    ubos.append(
        {graph.logical_limits_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed)});
  }

  // Use a 1D global work group: one invocation per output element.
  const utils::uvec3 wg_size = {
      static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      wg_size,
      graph.create_local_wg_size(wg_size),
      // Inputs and Outputs
      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      resize_q_8w_linear_node));
  // If out was repacked, view_copy the result back into the original layout.
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(out) != WHCN::kWidthDim) {
    viewFn(graph, {out_W_packed, graph.add_none(), out});
  }
}

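// Tiled variant of add_q_8w_linear_node. The kernel name is specialized for
// batched (3D) inputs and for a tile height of 2 or 4 output rows depending
// on M, and for texture storage the Y extent of the global work group is
// divided by the tile height to match.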
void add_q_8w_linear_optimized_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef q_mat2_data,
    const ValueRef scales_data,
    const ValueRef out) {
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  ValueRef mat1_W_packed = mat1;
  ValueRef out_W_packed = out;
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
    // Ensure mat1 is width packed
    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
    // Ensure out is packed correctly
    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
  }

  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef q_mat2 =
      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
  ValueRef scales =
      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);

  std::string kernel_name = "q_8w_linear_optimized";
  kernel_name.reserve(kShaderNameReserve);
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  const int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  // Select the tile height based on the number of output rows (M).
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

  vkapi::ParamsBindList ubos({});

  utils::uvec3 global_size;
  utils::uvec3 local_size;
  if (graph.is_buffer_storage(out)) {
    ubos.append(
        {graph.sizes_ubo(out_W_packed),
         graph.strides_ubo(out_W_packed),
         graph.numel_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed),
         graph.strides_ubo(mat1_W_packed),
         graph.strides_ubo(q_mat2),
         graph.strides_ubo(scales)});
    global_size = graph.create_global_wg_size(out_W_packed);
    local_size = graph.create_local_wg_size(out_W_packed);
  } else {
    global_size = graph.logical_limits_of(out_W_packed);
    ubos.append(
        {graph.logical_limits_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed)});
    // Each invocation computes a tile of output rows, so divide the Y extent
    // of the global work group by the tile height.
    if (mat1_sizes.at(mat1_dims - 2) < 8) {
      global_size = utils::divup_vec(global_size, {1, 2, 1});
    } else {
      global_size = utils::divup_vec(global_size, {1, 4, 1});
    }
    local_size = {16, 3, 1};
  }

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      resize_q_8w_linear_node));

  // If out was repacked, view_copy the result back into the original layout.
  if (!graph.is_buffer_storage(out)) {
    viewFn(graph, {out_W_packed, graph.add_none(), out});
  }
}

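// Registered entry point for aten._weight_int8pack_mm.default;
// args = {mat1, quantized weight, scales, out}.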
void weight_int8pack_mm(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
  return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
}

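// Validate the inputs to the 4-bit weight-quantized linear op. The weight is
// packed two 4-bit values per byte, so its innermost dim must equal K / 2,
// and K must be divisible by the group size. Requires int16 shader types and
// int8 buffer support.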
void check_q_4w_linear_args(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef group_size,
    const ValueRef scales_and_zeros,
    const ValueRef out) {
  VK_CHECK_COND(graph.int16_shader_types_enabled());
  VK_CHECK_COND(graph.int8_buffers_enabled());

  VK_CHECK_COND(graph.val_is_tensor(mat1));
  VK_CHECK_COND(graph.val_is_tref(mat2_data));
  VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));

  VK_CHECK_COND(graph.dim_of(mat1) <= 3);
  VK_CHECK_COND(graph.dim_of(mat2_data) == 2);
  VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3);

  VK_CHECK_COND(graph.size_at<int>(-3, mat1) == 1);
  const int K = graph.size_at<int>(-1, mat1);
  VK_CHECK_COND(graph.size_at<int>(-1, mat2_data) * 2 == K);

  const int group_size_val = graph.extract_scalar<int>(group_size);
  VK_CHECK_COND(K % group_size_val == 0);

  VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);

  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
  VK_CHECK_COND(graph.has_standard_axis_map(out));
}

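// Same dynamic-shape resize logic as resize_q_8w_linear_node, applied to the
// 4-bit linear op.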
void resize_q_4w_linear_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);

  const int out_cols = utils::val_at(-2, mat1->sizes());
  const int out_rows = utils::val_at(-2, mat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_cols;
    new_out_sizes.at(1) = out_rows;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_cols;
    new_out_sizes.at(2) = out_rows;
  }

  out->virtual_resize(new_out_sizes);
}

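// Add a graph node that dispatches the q_4w_linear shader. The packed weight
// is uploaded to a buffer with a direct copy, scales and zero points are
// prepacked into the output's storage type, and the group size is baked in
// as a specialization constant.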
void add_q_4w_linear_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef group_size,
    const ValueRef scales_and_zeros_data,
    const ValueRef out) {
  check_q_4w_linear_args(
      graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);

  utils::StorageType storage_type = graph.storage_type_of(out);

  ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data);

  ValueRef scales_and_zeros = prepack_standard(
      graph, scales_and_zeros_data, storage_type, utils::kWidthPacked);

  std::string kernel_name = "q_4w_linear";
  add_storage_type_suffix(kernel_name, storage_type);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

  vkapi::ParamsBindList ubos({});
  ubos.append(graph.logical_limits_ubo(out));
  ubos.append(graph.sizes_ubo(mat1));
  ubos.append(graph.strides_ubo(mat2));
  ubos.append(graph.strides_ubo(scales_and_zeros));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {SV(group_size_val)},
      // Resizing Logic
      resize_q_4w_linear_node,
      {}));
}

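// Registered entry point for et_vk.linear_weight_int4.default;
// args = {mat1, mat2, group_size, scales_and_zeros, inner_k_tiles, out}.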
void linear_weight_int4(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  return add_q_4w_linear_node(
      graph,
      args[0], // mat1
      args[1], // mat2
      args[2], // group_size
      args[3], // scales_and_zeros
      // args[4] holds inner_k_tiles, which is only used to call
      // _convert_weight_to_int4pack in the AOT custom op, so it is skipped
      // here.
      args[5] // out
  );
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm);
  VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4);
}

} // namespace vkcompute