xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fast_hadamard_transform.h"

#include <algorithm>
#include <cstdint>
#include <memory>
12 
13 namespace executorch {
namespace {
// Normalization step: divide by sqrt(1 << log2_vec_size). Similar
// to fast_sqrt above, if N is even, then the maximum-precision way
// to do this is right-shift by log2_vec_size / 2. If N is odd, we
// still do the right-shift, and then we have an extra division by
// sqrt(2) that we perform by making use of a sufficiently accurate
// rational approximation. Our initial idea was to divide by sqrt(2)
// by adjusting the quantization scale, but that would cause this
// function to tend to increase the magnitude of the elements of
// vec, which would result in clipping and therefore accuracy
// loss, especially compounded over 30+ transformer layers.
//
// tmp: 32-bit unnormalized FHT output, vec_size elements.
// out: destination for the normalized, saturated 16-bit result.
void quantized_normalize_after_fht(
    const int32_t* tmp,
    int16_t* out,
    int log2_vec_size,
    int vec_size) {
  const int log2_sqrt_vec_size = log2_vec_size / 2;
  // Symmetric s16 quantization range: [-32767, 32767] (note: excludes
  // -32768 so that negation cannot overflow).
  constexpr int32_t qmin = -(1 << 15) + 1;
  constexpr int32_t qmax = -qmin;
  if (log2_vec_size % 2 != 0) {
    // 408 / 577 - 1.0 / sqrt(2) ~= -1.062e-06, which should be close enough.
    // NOTE(review): tmp[ii] * 408 can overflow int32_t if |tmp[ii]| exceeds
    // ~2^22; callers keep intermediates at 16 + log2_vec_size bits, which is
    // safe for the vector sizes used here — confirm for larger transforms.
    static const int32_t inv_sqrt_2_numerator = 408;
    static const int32_t inv_sqrt_2_denominator = 577;
    for (int ii = 0; ii < vec_size; ++ii) {
      const auto val_over_sqrt_vec_size =
          (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >>
          log2_sqrt_vec_size;
      out[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax);
    }
  } else {
    for (int ii = 0; ii < vec_size; ++ii) {
      out[ii] = std::clamp(tmp[ii] >> log2_sqrt_vec_size, qmin, qmax);
    }
  }
}
} // namespace
50 
fast_hadamard_transform_symmetric_quantized_s16(int16_t * vec,int log2_vec_size)51 void fast_hadamard_transform_symmetric_quantized_s16(
52     int16_t* vec,
53     int log2_vec_size) {
54   if (log2_vec_size == 0) {
55     return;
56   }
57 
58   const int vec_size = 1 << log2_vec_size;
59   // We perform log2_vec_size rounds where each round's maximum output
60   // is at most double the maximum input, so we can at most multiply
61   // the maximum input by vec_size. Performing intermediate arithmetic
62   // in 32-bit precision should prevent overflow, since 16 +
63   // log2_vec_size should be much less than 32.
64   auto tmp = std::make_unique<int32_t[]>(vec_size);
65   std::copy(vec, vec + vec_size, tmp.get());
66 
67   // Per the function-level comment above, we can ignore the
68   // quantization scale, so we just delegate to the usual unnormalized
69   // implementation.
70   // NOTE: if we need this to be fast on CPU, we can use FFHT to
71   // generate fht_uint32 similar to fht_float.
72   internal::fast_hadamard_transform_unnormalized_simple_impl(
73       tmp.get(), log2_vec_size);
74 
75   quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size);
76 }
77 
fast_hadamard_transform_symmetric_quantized_s16_28N(int16_t * vec,int log2_vec_size)78 void fast_hadamard_transform_symmetric_quantized_s16_28N(
79     int16_t* vec,
80     int log2_vec_size) {
81   if (log2_vec_size == 0) {
82     return;
83   }
84   const int vec_size = (1 << log2_vec_size);
85 
86   auto tmp = std::make_unique<int32_t[]>(vec_size * 28);
87   std::copy(vec, vec + vec_size * 28, tmp.get());
88 
89   for (int ii = 0; ii < 28; ++ii) {
90     internal::fast_hadamard_transform_unnormalized_simple_impl(
91         &tmp[ii * vec_size], log2_vec_size);
92   }
93 
94   for (int ii = 0; ii < vec_size; ++ii) {
95     hadamard_mult_28_strided(&tmp[ii], vec_size);
96   }
97 
98   quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size * 28);
99 }
100 
101 } // namespace executorch
102