/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

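// Verify the argument shapes for 8-bit weight-quantized linear: mat1 is
// [..., M, K], qmat2_data is the transposed weight [N, K], and scales has one
// element per output channel, i.e. [N].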
void check_q_8w_linear_args(
    const ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef qmat2_data,
    const ValueRef scales,
    const ValueRef out) {
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> qmat2_sizes = graph.sizes_of(qmat2_data);
  std::vector<int64_t> scales_sizes = graph.sizes_of(scales);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(qmat2_sizes.size() == 2);
  VK_CHECK_COND(scales_sizes.size() == 1);

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(
      utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes));
  VK_CHECK_COND(
      utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes));
}

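// Propagate the output shape on resize: with mat1 = [..., M, K] and
// qmat2 = [N, K], the output is [..., M, N].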
void resize_q_8w_linear_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]);

  const int out_rows = utils::val_at(-2, mat1->sizes());
  const int out_cols = utils::val_at(-2, qmat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_rows;
    new_out_sizes.at(1) = out_cols;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_rows;
    new_out_sizes.at(2) = out_cols;
  }

  out->virtual_resize(new_out_sizes);
}

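// Adds a node computing out = mat1 x qmat2^T with 8-bit quantized weights.
// For texture storage, the input and output are repacked to width-packed
// layout if needed; the weights and scales are prepacked, a dispatch node is
// recorded, and the result is viewed back into the original layout of out.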
void add_q_8w_linear_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef q_mat2_data,
    const ValueRef scales_data,
    const ValueRef out) {
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  ValueRef mat1_W_packed = mat1;
  ValueRef out_W_packed = out;
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
    // Ensure mat1 is width packed
    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
    // Ensure out is packed correctly
    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
  }
  ValueRef q_mat2 = prepack_standard(
      graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
  ValueRef scales = prepack_standard(
      graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);

  std::string kernel_name = "q_8w_linear";
  kernel_name.reserve(kShaderNameReserve);
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

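  // Buffer storage needs full size/stride/numel metadata to index into the
  // tensors; texture storage only needs the output's logical limits and the
  // input sizes.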
  vkapi::ParamsBindList ubos({});
  if (graph.is_buffer_storage(out_W_packed)) {
    ubos.append(
        {graph.sizes_ubo(out_W_packed),
         graph.strides_ubo(out_W_packed),
         graph.numel_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed),
         graph.strides_ubo(mat1_W_packed),
         graph.strides_ubo(q_mat2),
         graph.strides_ubo(scales)});
  } else {
    ubos.append(
        {graph.logical_limits_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed)});
  }

  // Use a 1-dimensional global work group, with one invocation per output
  // element
  const utils::uvec3 wg_size = {
      static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      wg_size,
      graph.create_local_wg_size(wg_size),
      // Inputs and Outputs
      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      resize_q_8w_linear_node));
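  // If a width-packed copy of out was used, view the result back into the
  // original packing of out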
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(out) != WHCN::kWidthDim) {
    viewFn(graph, {out_W_packed, graph.add_none(), out});
  }
}

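// Tiled variant of add_q_8w_linear_node. It selects a shader that computes
// several output rows per invocation, choosing the tile height and work group
// shape based on the input dimensions.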
void add_q_8w_linear_optimized_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef q_mat2_data,
    const ValueRef scales_data,
    const ValueRef out) {
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  ValueRef mat1_W_packed = mat1;
  ValueRef out_W_packed = out;
  if (!graph.is_buffer_storage(out) &&
      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
    // Ensure mat1 is width packed
    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
    // Ensure out is packed correctly
    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
  }

  utils::StorageType stype = graph.storage_type_of(out);
  ValueRef q_mat2 =
      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
  ValueRef scales =
      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);

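  // Select the shader variant: a "batch_" prefix for 3D inputs, and a tile
  // height of 2 rows when M < 8, otherwise 4 rows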
  std::string kernel_name = "q_8w_linear_optimized";
  kernel_name.reserve(kShaderNameReserve);
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  const int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

  vkapi::ParamsBindList ubos({});

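  // For buffer storage, use the graph's default work group sizing. For
  // texture storage, shrink the y extent by the tile height, since each
  // invocation covers multiple output rows, and use a fixed 16x3x1 local size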
  utils::uvec3 global_size;
  utils::uvec3 local_size;
  if (graph.is_buffer_storage(out)) {
    ubos.append(
        {graph.sizes_ubo(out_W_packed),
         graph.strides_ubo(out_W_packed),
         graph.numel_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed),
         graph.strides_ubo(mat1_W_packed),
         graph.strides_ubo(q_mat2),
         graph.strides_ubo(scales)});
    global_size = graph.create_global_wg_size(out_W_packed);
    local_size = graph.create_local_wg_size(out_W_packed);
  } else {
    global_size = graph.logical_limits_of(out_W_packed);
    ubos.append(
        {graph.logical_limits_ubo(out_W_packed),
         graph.sizes_ubo(mat1_W_packed)});
    if (mat1_sizes.at(mat1_dims - 2) < 8) {
      global_size = utils::divup_vec(global_size, {1, 2, 1});
    } else {
      global_size = utils::divup_vec(global_size, {1, 4, 1});
    }
    local_size = {16, 3, 1};
  }

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      resize_q_8w_linear_node));

  if (!graph.is_buffer_storage(out)) {
    viewFn(graph, {out_W_packed, graph.add_none(), out});
  }
}

void weight_int8pack_mm(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
  return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
}

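// Verify the argument shapes and device capabilities for 4-bit weight-
// quantized linear. The weights are packed two 4-bit values per byte, so the
// innermost dimension of mat2_data is K / 2. Requires 16-bit shader types and
// 8-bit buffer support.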
void check_q_4w_linear_args(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef group_size,
    const ValueRef scales_and_zeros,
    const ValueRef out) {
  VK_CHECK_COND(graph.int16_shader_types_enabled());
  VK_CHECK_COND(graph.int8_buffers_enabled());

  VK_CHECK_COND(graph.val_is_tensor(mat1));
  VK_CHECK_COND(graph.val_is_tref(mat2_data));
  VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));

  VK_CHECK_COND(graph.dim_of(mat1) <= 3);
  VK_CHECK_COND(graph.dim_of(mat2_data) == 2);
  VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3);

  VK_CHECK_COND(graph.size_at<int>(-3, mat1) == 1);
  const int K = graph.size_at<int>(-1, mat1);
  VK_CHECK_COND(graph.size_at<int>(-1, mat2_data) * 2 == K);

  const int group_size_val = graph.extract_scalar<int>(group_size);
  VK_CHECK_COND(K % group_size_val == 0);

  VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);

  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
  VK_CHECK_COND(graph.has_standard_axis_map(out));
}

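// Identical shape propagation to resize_q_8w_linear_node: with
// mat1 = [..., M, K] and packed mat2 = [N, K / 2], the output is [..., M, N].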
void resize_q_4w_linear_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);

  const int out_rows = utils::val_at(-2, mat1->sizes());
  const int out_cols = utils::val_at(-2, mat2->sizes());

  std::vector<int64_t> new_out_sizes(3);
  if (mat1->sizes().size() == 2) {
    new_out_sizes.resize(2);
    new_out_sizes.at(0) = out_rows;
    new_out_sizes.at(1) = out_cols;
  } else {
    new_out_sizes.at(0) = mat1->sizes().at(0);
    new_out_sizes.at(1) = out_rows;
    new_out_sizes.at(2) = out_cols;
  }

  out->virtual_resize(new_out_sizes);
}

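// Adds a node computing linear with 4-bit group-quantized weights: the packed
// weight data is prepacked into a buffer, the scales and zero points are
// prepacked as a standard tensor, and the group size is passed to the shader
// as a specialization constant.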
void add_q_4w_linear_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef group_size,
    const ValueRef scales_and_zeros_data,
    const ValueRef out) {
  check_q_4w_linear_args(
      graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);

  utils::StorageType storage_type = graph.storage_type_of(out);

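  // The packed 4-bit weight data is copied directly into a buffer, with no
  // layout transformation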
  ValueRef mat2 = prepack_direct_copy_buffer(graph, mat2_data);

  ValueRef scales_and_zeros = prepack_standard(
      graph, scales_and_zeros_data, storage_type, utils::kWidthPacked);

  std::string kernel_name = "q_4w_linear";
  add_storage_type_suffix(kernel_name, storage_type);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

  vkapi::ParamsBindList ubos({});
  ubos.append(graph.logical_limits_ubo(out));
  ubos.append(graph.sizes_ubo(mat1));
  ubos.append(graph.strides_ubo(mat2));
  ubos.append(graph.strides_ubo(scales_and_zeros));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      ubos,
      // Specialization Constants
      {SV(group_size_val)},
      // Resizing Logic
      resize_q_4w_linear_node,
      {}));
}

void linear_weight_int4(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  return add_q_4w_linear_node(
      graph,
      args[0], // mat1
      args[1], // mat2
      args[2], // group_size
      args[3], // scales_and_zeros
      // args[4] is inner_k_tiles, which is only used by the AOT custom op to
      // call _convert_weight_to_int4pack; it is not needed here, so it is
      // skipped.
      args[5] // out
  );
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm);
  VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4);
}

} // namespace vkcompute