1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10
11 #include <ATen/ATen.h>
12
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16
17 #include <cassert>
18
19 //
20 // Reference Implementations
21 //
22
linear_weight_int4_reference_impl(const at::Tensor & x,const at::Tensor & weights_4x2,const int64_t groupsize,const at::Tensor & scales_and_zeros,const int64_t inner_k_tiles)23 at::Tensor linear_weight_int4_reference_impl(
24 const at::Tensor& x,
25 const at::Tensor& weights_4x2,
26 const int64_t groupsize,
27 const at::Tensor& scales_and_zeros,
28 const int64_t inner_k_tiles) {
29 const std::vector<int64_t> original_x_size(x.sizes().vec());
30 const size_t ndim = original_x_size.size();
31 const int64_t out_features = weights_4x2.size(0);
32 const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
33 const at::Tensor packed_weights =
34 at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
35 at::Tensor out = at::_weight_int4pack_mm(
36 x_flattened, packed_weights, groupsize, scales_and_zeros);
37 std::vector<int64_t> out_shape(
38 original_x_size.begin(), original_x_size.end());
39 out_shape.at(ndim - 1) = out_features;
40 return out.reshape(out_shape);
41 }
42
dequantize_and_linear(const at::Tensor & x,const at::Tensor & weights_4x2,const int64_t groupsize,const at::Tensor & scales_and_zeros,const int64_t inner_k_tiles)43 at::Tensor dequantize_and_linear(
44 const at::Tensor& x,
45 const at::Tensor& weights_4x2,
46 const int64_t groupsize,
47 const at::Tensor& scales_and_zeros,
48 const int64_t inner_k_tiles) {
49 std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
50 weights_shape[1] *= 2;
51
52 at::Tensor weights_dequantized =
53 at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));
54
55 const int64_t N = weights_dequantized.size(0);
56 const int64_t K = weights_dequantized.size(1);
57
58 const int k_groups = K / groupsize;
59 for (int n = 0; n < N; n++) {
60 for (int k = 0; k < K; k += 2) {
61 const int group_idx = k / groupsize;
62 // const int scale_idx = k_groups * n + group_idx;
63 const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
64 const uint8_t second_val = packed_val & 0x0F;
65 const uint8_t first_val = (packed_val & 0xF0) >> 4;
66
67 const float scale = scales_and_zeros[group_idx][n][0].item().to<float>();
68 const float zero = scales_and_zeros[group_idx][n][1].item().to<float>();
69
70 weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero;
71 weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero;
72 }
73 }
74
75 return at::linear(x, weights_dequantized);
76 }
77
78 //
79 // Test functions
80 //
81
test_reference_linear_int4(const int B,const int M,const int K,const int N,const int group_size=32,const int inner_k_tiles=8)82 void test_reference_linear_int4(
83 const int B,
84 const int M,
85 const int K,
86 const int N,
87 const int group_size = 32,
88 const int inner_k_tiles = 8) {
89 assert(K % group_size == 0);
90
91 at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
92 at::Tensor weights_4x2 =
93 at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
94
95 const int k_groups = K / group_size;
96 at::Tensor scales_and_zeros =
97 at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
98
99 at::Tensor out = linear_weight_int4_reference_impl(
100 x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
101
102 at::Tensor out_ref = dequantize_and_linear(
103 x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
104
105 ASSERT_TRUE(at::allclose(out, out_ref));
106 }
107
from_at_scalartype(c10::ScalarType at_scalartype)108 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
109 using namespace vkcompute;
110 switch (at_scalartype) {
111 case c10::kFloat:
112 return vkapi::kFloat;
113 case c10::kHalf:
114 return vkapi::kHalf;
115 case c10::kInt:
116 return vkapi::kInt;
117 case c10::kLong:
118 return vkapi::kInt;
119 case c10::kChar:
120 return vkapi::kChar;
121 case c10::kByte:
122 return vkapi::kByte;
123 default:
124 VK_THROW("Unsupported at::ScalarType!");
125 }
126 }
127
test_vulkan_linear_int4(const int B,const int M,const int K,const int N,const int group_size=32,const int inner_k_tiles=8)128 void test_vulkan_linear_int4(
129 const int B,
130 const int M,
131 const int K,
132 const int N,
133 const int group_size = 32,
134 const int inner_k_tiles = 8) {
135 assert(K % group_size == 0);
136
137 at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
138 at::Tensor weights_4x2 =
139 at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
140
141 const int k_groups = K / group_size;
142 at::Tensor scales_and_zeros =
143 at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
144
145 at::Tensor out_ref = dequantize_and_linear(
146 x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
147
148 // Build Vulkan graph
149 using namespace vkcompute;
150
151 GraphConfig config;
152 config.set_storage_type_override(utils::kTexture3D);
153 ComputeGraph graph(config);
154
155 #define MAKE_TENSORREF_FOR(x) \
156 ValueRef r_##x = graph.add_tensorref( \
157 x.sizes().vec(), \
158 from_at_scalartype(x.scalar_type()), \
159 x.const_data_ptr());
160
161 MAKE_TENSORREF_FOR(weights_4x2);
162 MAKE_TENSORREF_FOR(scales_and_zeros);
163
164 #define MAKE_INPUT_FOR(x) \
165 IOValueRef r_##x = graph.add_input_tensor( \
166 x.sizes().vec(), from_at_scalartype(x.scalar_type()));
167
168 MAKE_INPUT_FOR(x);
169
170 const ValueRef r_out = graph.add_tensor(
171 out_ref.sizes().vec(), from_at_scalartype(out_ref.scalar_type()));
172
173 VK_GET_OP_FN("et_vk.linear_weight_int4.default")
174 (graph,
175 {r_x.value,
176 r_weights_4x2,
177 graph.add_scalar<int64_t>(group_size),
178 r_scales_and_zeros,
179 kDummyValueRef,
180 r_out});
181
182 ValueRef staging_out = graph.set_output_tensor(r_out);
183
184 graph.prepare();
185 graph.encode_prepack();
186 graph.prepack();
187 graph.encode_execute();
188
189 //
190 // Run model
191 //
192
193 graph.propagate_resize();
194 graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel());
195
196 graph.execute();
197
198 at::Tensor vk_out = at::empty_like(out_ref);
199 graph.copy_from_staging(
200 staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
201
202 ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4));
203 }
204
TEST(VulkanInt4LinearTest,test_reference_impl)205 TEST(VulkanInt4LinearTest, test_reference_impl) {
206 test_reference_linear_int4(
207 /*B = */ 1,
208 /*M = */ 4,
209 /*K = */ 128,
210 /*N = */ 32);
211 }
212
TEST(VulkanInt4LinearTest,test_vulkan_impl)213 TEST(VulkanInt4LinearTest, test_vulkan_impl) {
214 if (!vkcompute::api::context()
215 ->adapter_ptr()
216 ->has_full_int8_buffers_support()) {
217 GTEST_SKIP();
218 }
219 test_vulkan_linear_int4(
220 /*B = */ 1,
221 /*M = */ 4,
222 /*K = */ 128,
223 /*N = */ 32);
224 }
225