xref: /aosp_15_r20/external/XNNPACK/src/math/sqrt-u32-scalar-hashemian.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math-stubs.h>
11*4bdc9457SAndroid Build Coastguard Worker 
12*4bdc9457SAndroid Build Coastguard Worker 
xnn_math_u32_sqrt__scalar_hashemian(size_t n,const uint32_t * input,uint32_t * output)13*4bdc9457SAndroid Build Coastguard Worker void xnn_math_u32_sqrt__scalar_hashemian(
14*4bdc9457SAndroid Build Coastguard Worker     size_t n,
15*4bdc9457SAndroid Build Coastguard Worker     const uint32_t* input,
16*4bdc9457SAndroid Build Coastguard Worker     uint32_t* output)
17*4bdc9457SAndroid Build Coastguard Worker {
18*4bdc9457SAndroid Build Coastguard Worker   assert(n % sizeof(uint32_t) == 0);
19*4bdc9457SAndroid Build Coastguard Worker 
20*4bdc9457SAndroid Build Coastguard Worker   for (; n != 0; n -= sizeof(uint32_t)) {
21*4bdc9457SAndroid Build Coastguard Worker     const uint32_t vx = *input++;
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker     uint32_t vy = vx;
24*4bdc9457SAndroid Build Coastguard Worker     if (vx != 0) {
25*4bdc9457SAndroid Build Coastguard Worker       /*
26*4bdc9457SAndroid Build Coastguard Worker        * Based on "Square Rooting Algorithms for Integer and Floating-Point Numbers" by Reza Hashemian
27*4bdc9457SAndroid Build Coastguard Worker        * and StackOverflow answer https://stackoverflow.com/a/31149161
28*4bdc9457SAndroid Build Coastguard Worker       */
29*4bdc9457SAndroid Build Coastguard Worker 
30*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vn = math_clz_nonzero_u32(vx);
31*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vleft_shift = vn & 1;
32*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vm_minus_1 = 15 - (vn >> 1);
33*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vm_plus_1 = vm_minus_1 + 2;
34*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vexp2_m_minus_1 = UINT32_C(1) << vm_minus_1;
35*4bdc9457SAndroid Build Coastguard Worker       const uint32_t vz = vexp2_m_minus_1 - (vx >> (vm_plus_1 - vleft_shift));
36*4bdc9457SAndroid Build Coastguard Worker 
37*4bdc9457SAndroid Build Coastguard Worker       vy = vz;
38*4bdc9457SAndroid Build Coastguard Worker       // Iterate until y[i] == y[i-1]. Alternatively, we can do 7 iterations:
39*4bdc9457SAndroid Build Coastguard Worker       //   for (uint32_t i = 0; i < 7; i++) {
40*4bdc9457SAndroid Build Coastguard Worker       //     vy = vz + ((vy * vy) >> vm_plus_1);
41*4bdc9457SAndroid Build Coastguard Worker       //   }
42*4bdc9457SAndroid Build Coastguard Worker       uint32_t vy_prev;
43*4bdc9457SAndroid Build Coastguard Worker       do {
44*4bdc9457SAndroid Build Coastguard Worker         vy_prev = vy;
45*4bdc9457SAndroid Build Coastguard Worker         vy = vz + ((vy * vy) >> vm_plus_1);
46*4bdc9457SAndroid Build Coastguard Worker       } while (vy != vy_prev);
47*4bdc9457SAndroid Build Coastguard Worker 
48*4bdc9457SAndroid Build Coastguard Worker       // Reconstruct Y = 2**m - vy
49*4bdc9457SAndroid Build Coastguard Worker       vy = (vexp2_m_minus_1 << 1) - vy;
50*4bdc9457SAndroid Build Coastguard Worker       if XNN_UNPREDICTABLE(vleft_shift) {
51*4bdc9457SAndroid Build Coastguard Worker         // Multiply by sqrt(0.5) by subtracting vy * (1 - sqrt(0.5)), 1 - sqrt(0.5) is represented
52*4bdc9457SAndroid Build Coastguard Worker         // as a .16 fixed-point number to guarantee than the product doesn't overflow 32 bits.
53*4bdc9457SAndroid Build Coastguard Worker         // Using 1 - sqrt(0.5) under these constraints is 1 bit more accurate than using sqrt(0.5) directly.
54*4bdc9457SAndroid Build Coastguard Worker         vy -= (vy * UINT32_C(19195)) >> 16;
55*4bdc9457SAndroid Build Coastguard Worker       }
56*4bdc9457SAndroid Build Coastguard Worker 
57*4bdc9457SAndroid Build Coastguard Worker       // When X has an even number of bits, Y can overestimate isqrt(X) by 1 due to truncations in fixed-point
58*4bdc9457SAndroid Build Coastguard Worker       // arithmetics. When X has an odd number of bits, Y can overestimate isqrt(X) by an extra 1 (2 total) due to
59*4bdc9457SAndroid Build Coastguard Worker       // truncation in the multiplication by sqrt(0.5).
60*4bdc9457SAndroid Build Coastguard Worker       // We decrement Y once if X < Y * Y and decrement it once again if Y * Y - X > X - (Y - 1) * (Y - 1).
61*4bdc9457SAndroid Build Coastguard Worker       uint32_t vsquared_y = vy * vy;
62*4bdc9457SAndroid Build Coastguard Worker       if XNN_UNPREDICTABLE(vsquared_y > vx) {
63*4bdc9457SAndroid Build Coastguard Worker         vsquared_y -= 2 * vy - 1;
64*4bdc9457SAndroid Build Coastguard Worker         vy -= 1;
65*4bdc9457SAndroid Build Coastguard Worker       }
66*4bdc9457SAndroid Build Coastguard Worker 
67*4bdc9457SAndroid Build Coastguard Worker       // Y is within a distance of 1 from properly rounded sqrt(X).
68*4bdc9457SAndroid Build Coastguard Worker       // - Increment Y if (Y + 1) * (Y + 1) - X < X - Y * Y.
69*4bdc9457SAndroid Build Coastguard Worker       // - Decrement Y if Y * Y - X > X - (Y - 1) * (Y - 1).
70*4bdc9457SAndroid Build Coastguard Worker       // The increment + decrement are combined together to re-use the (Y * Y) value.
71*4bdc9457SAndroid Build Coastguard Worker       if XNN_UNPREDICTABLE(vsquared_y < vx - vy) {
72*4bdc9457SAndroid Build Coastguard Worker         vy += 1;
73*4bdc9457SAndroid Build Coastguard Worker       } else if XNN_UNPREDICTABLE(vsquared_y - vy >= vx) {
74*4bdc9457SAndroid Build Coastguard Worker         vy -= 1;
75*4bdc9457SAndroid Build Coastguard Worker       }
76*4bdc9457SAndroid Build Coastguard Worker     }
77*4bdc9457SAndroid Build Coastguard Worker 
78*4bdc9457SAndroid Build Coastguard Worker     *output++ = vy;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker }
81