xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <gtest/gtest.h>
10 
11 #include <ATen/ATen.h>
12 
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16 
17 #include <cassert>
18 
19 //
20 // Reference Implementations
21 //
22 
// Computes the int4 weight-only linear through ATen's packed kernels.
//
// x is flattened to 2D (all leading dims collapsed, last dim kept), the raw
// weights_4x2 tensor is packed with at::_convert_weight_to_int4pack using
// inner_k_tiles, and the product is evaluated with at::_weight_int4pack_mm.
// The result is reshaped so that it has the same leading dims as x and
// weights_4x2.size(0) as its last dim.
at::Tensor linear_weight_int4_reference_impl(
    const at::Tensor& x,
    const at::Tensor& weights_4x2,
    const int64_t groupsize,
    const at::Tensor& scales_and_zeros,
    const int64_t inner_k_tiles) {
  // Start from the input shape; only the innermost extent differs in the
  // output.
  std::vector<int64_t> result_shape = x.sizes().vec();
  const int64_t in_features = result_shape.back();
  result_shape.back() = weights_4x2.size(0);

  const at::Tensor x_2d = x.reshape({-1, in_features});
  const at::Tensor weights_packed =
      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);

  const at::Tensor out_2d = at::_weight_int4pack_mm(
      x_2d, weights_packed, groupsize, scales_and_zeros);

  return out_2d.reshape(result_shape);
}
42 
// Reference path that dequantizes the packed 4-bit weights element by element
// and then applies at::linear.
//
// weights_4x2 holds two 4-bit values per byte, so the dequantized weight
// tensor has twice as many columns. For each byte the high nibble becomes the
// even column k and the low nibble becomes column k + 1. Each nibble (0..15)
// is shifted by -8, multiplied by the scale and offset by the zero value read
// from scales_and_zeros[group][n][0] and scales_and_zeros[group][n][1], where
// group = k / groupsize. Both nibbles of a byte use the group of the even
// column k.
//
// inner_k_tiles is not used by this path; it is accepted so the signature
// matches linear_weight_int4_reference_impl.
at::Tensor dequantize_and_linear(
    const at::Tensor& x,
    const at::Tensor& weights_4x2,
    const int64_t groupsize,
    const at::Tensor& scales_and_zeros,
    const int64_t inner_k_tiles) {
  (void)inner_k_tiles;

  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
  weights_shape[1] *= 2;

  at::Tensor weights_dequantized =
      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat));

  const int64_t N = weights_dequantized.size(0);
  const int64_t K = weights_dequantized.size(1);

  // Indices are int64_t to match the tensor extents they are compared with.
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t k = 0; k < K; k += 2) {
      const int64_t group_idx = k / groupsize;

      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
      const uint8_t second_val = packed_val & 0x0F;
      const uint8_t first_val = (packed_val & 0xF0) >> 4;

      const float scale = scales_and_zeros[group_idx][n][0].item().to<float>();
      const float zero = scales_and_zeros[group_idx][n][1].item().to<float>();

      // The 8.0 literal keeps the arithmetic in double before the value is
      // stored into the float tensor.
      weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero;
      weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero;
    }
  }

  return at::linear(x, weights_dequantized);
}
77 
78 //
79 // Test functions
80 //
81 
test_reference_linear_int4(const int B,const int M,const int K,const int N,const int group_size=32,const int inner_k_tiles=8)82 void test_reference_linear_int4(
83     const int B,
84     const int M,
85     const int K,
86     const int N,
87     const int group_size = 32,
88     const int inner_k_tiles = 8) {
89   assert(K % group_size == 0);
90 
91   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
92   at::Tensor weights_4x2 =
93       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
94 
95   const int k_groups = K / group_size;
96   at::Tensor scales_and_zeros =
97       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
98 
99   at::Tensor out = linear_weight_int4_reference_impl(
100       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
101 
102   at::Tensor out_ref = dequantize_and_linear(
103       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
104 
105   ASSERT_TRUE(at::allclose(out, out_ref));
106 }
107 
from_at_scalartype(c10::ScalarType at_scalartype)108 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
109   using namespace vkcompute;
110   switch (at_scalartype) {
111     case c10::kFloat:
112       return vkapi::kFloat;
113     case c10::kHalf:
114       return vkapi::kHalf;
115     case c10::kInt:
116       return vkapi::kInt;
117     case c10::kLong:
118       return vkapi::kInt;
119     case c10::kChar:
120       return vkapi::kChar;
121     case c10::kByte:
122       return vkapi::kByte;
123     default:
124       VK_THROW("Unsupported at::ScalarType!");
125   }
126 }
127 
test_vulkan_linear_int4(const int B,const int M,const int K,const int N,const int group_size=32,const int inner_k_tiles=8)128 void test_vulkan_linear_int4(
129     const int B,
130     const int M,
131     const int K,
132     const int N,
133     const int group_size = 32,
134     const int inner_k_tiles = 8) {
135   assert(K % group_size == 0);
136 
137   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
138   at::Tensor weights_4x2 =
139       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
140 
141   const int k_groups = K / group_size;
142   at::Tensor scales_and_zeros =
143       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
144 
145   at::Tensor out_ref = dequantize_and_linear(
146       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
147 
148   // Build Vulkan graph
149   using namespace vkcompute;
150 
151   GraphConfig config;
152   config.set_storage_type_override(utils::kTexture3D);
153   ComputeGraph graph(config);
154 
155 #define MAKE_TENSORREF_FOR(x)              \
156   ValueRef r_##x = graph.add_tensorref(    \
157       x.sizes().vec(),                     \
158       from_at_scalartype(x.scalar_type()), \
159       x.const_data_ptr());
160 
161   MAKE_TENSORREF_FOR(weights_4x2);
162   MAKE_TENSORREF_FOR(scales_and_zeros);
163 
164 #define MAKE_INPUT_FOR(x)                    \
165   IOValueRef r_##x = graph.add_input_tensor( \
166       x.sizes().vec(), from_at_scalartype(x.scalar_type()));
167 
168   MAKE_INPUT_FOR(x);
169 
170   const ValueRef r_out = graph.add_tensor(
171       out_ref.sizes().vec(), from_at_scalartype(out_ref.scalar_type()));
172 
173   VK_GET_OP_FN("et_vk.linear_weight_int4.default")
174   (graph,
175    {r_x.value,
176     r_weights_4x2,
177     graph.add_scalar<int64_t>(group_size),
178     r_scales_and_zeros,
179     kDummyValueRef,
180     r_out});
181 
182   ValueRef staging_out = graph.set_output_tensor(r_out);
183 
184   graph.prepare();
185   graph.encode_prepack();
186   graph.prepack();
187   graph.encode_execute();
188 
189   //
190   // Run model
191   //
192 
193   graph.propagate_resize();
194   graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel());
195 
196   graph.execute();
197 
198   at::Tensor vk_out = at::empty_like(out_ref);
199   graph.copy_from_staging(
200       staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
201 
202   ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4));
203 }
204 
// Cross-checks the two reference implementations on a single small problem
// size, using the default group size and inner_k_tiles.
TEST(VulkanInt4LinearTest, test_reference_impl) {
  constexpr int kB = 1;
  constexpr int kM = 4;
  constexpr int kK = 128;
  constexpr int kN = 32;

  test_reference_linear_int4(kB, kM, kK, kN);
}
212 
// Runs the Vulkan int4 linear test on a single small problem size. The test
// is skipped when the current adapter reports no full int8 buffer support.
TEST(VulkanInt4LinearTest, test_vulkan_impl) {
  const bool int8_buffers_supported = vkcompute::api::context()
                                          ->adapter_ptr()
                                          ->has_full_int8_buffers_support();
  if (!int8_buffers_supported) {
    GTEST_SKIP();
  }

  constexpr int kB = 1;
  constexpr int kM = 4;
  constexpr int kK = 128;
  constexpr int kN = 32;

  test_vulkan_linear_int4(kB, kM, kK, kN);
}
225