1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h> 7*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h> 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h> 10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math-stubs.h> 11*4bdc9457SAndroid Build Coastguard Worker 12*4bdc9457SAndroid Build Coastguard Worker xnn_math_u32_sqrt__scalar_hashemian(size_t n,const uint32_t * input,uint32_t * output)13*4bdc9457SAndroid Build Coastguard Workervoid xnn_math_u32_sqrt__scalar_hashemian( 14*4bdc9457SAndroid Build Coastguard Worker size_t n, 15*4bdc9457SAndroid Build Coastguard Worker const uint32_t* input, 16*4bdc9457SAndroid Build Coastguard Worker uint32_t* output) 17*4bdc9457SAndroid Build Coastguard Worker { 18*4bdc9457SAndroid Build Coastguard Worker assert(n % sizeof(uint32_t) == 0); 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker for (; n != 0; n -= sizeof(uint32_t)) { 21*4bdc9457SAndroid Build Coastguard Worker const uint32_t vx = *input++; 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker uint32_t vy = vx; 24*4bdc9457SAndroid Build Coastguard Worker if (vx != 0) { 25*4bdc9457SAndroid Build Coastguard Worker /* 26*4bdc9457SAndroid Build Coastguard Worker * Based on "Square Rooting Algorithms for Integer and Floating-Point Numbers" by Reza Hashemian 27*4bdc9457SAndroid Build Coastguard Worker * and StackOverflow answer https://stackoverflow.com/a/31149161 28*4bdc9457SAndroid Build Coastguard Worker */ 29*4bdc9457SAndroid Build Coastguard Worker 30*4bdc9457SAndroid Build Coastguard Worker const uint32_t vn = math_clz_nonzero_u32(vx); 31*4bdc9457SAndroid Build Coastguard Worker const uint32_t vleft_shift = vn & 1; 32*4bdc9457SAndroid Build Coastguard Worker const uint32_t vm_minus_1 = 15 - (vn >> 1); 33*4bdc9457SAndroid Build Coastguard Worker const uint32_t vm_plus_1 = vm_minus_1 + 2; 34*4bdc9457SAndroid Build Coastguard Worker const uint32_t vexp2_m_minus_1 = UINT32_C(1) << vm_minus_1; 35*4bdc9457SAndroid Build Coastguard Worker const uint32_t vz = vexp2_m_minus_1 - (vx >> (vm_plus_1 - vleft_shift)); 36*4bdc9457SAndroid Build Coastguard Worker 37*4bdc9457SAndroid Build Coastguard Worker vy = vz; 38*4bdc9457SAndroid Build Coastguard Worker // Iterate until y[i] == y[i-1]. Alternatively, we can do 7 iterations: 39*4bdc9457SAndroid Build Coastguard Worker // for (uint32_t i = 0; i < 7; i++) { 40*4bdc9457SAndroid Build Coastguard Worker // vy = vz + ((vy * vy) >> vm_plus_1); 41*4bdc9457SAndroid Build Coastguard Worker // } 42*4bdc9457SAndroid Build Coastguard Worker uint32_t vy_prev; 43*4bdc9457SAndroid Build Coastguard Worker do { 44*4bdc9457SAndroid Build Coastguard Worker vy_prev = vy; 45*4bdc9457SAndroid Build Coastguard Worker vy = vz + ((vy * vy) >> vm_plus_1); 46*4bdc9457SAndroid Build Coastguard Worker } while (vy != vy_prev); 47*4bdc9457SAndroid Build Coastguard Worker 48*4bdc9457SAndroid Build Coastguard Worker // Reconstruct Y = 2**m - vy 49*4bdc9457SAndroid Build Coastguard Worker vy = (vexp2_m_minus_1 << 1) - vy; 50*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(vleft_shift) { 51*4bdc9457SAndroid Build Coastguard Worker // Multiply by sqrt(0.5) by subtracting vy * (1 - sqrt(0.5)), 1 - sqrt(0.5) is represented 52*4bdc9457SAndroid Build Coastguard Worker // as a .16 fixed-point number to guarantee than the product doesn't overflow 32 bits. 53*4bdc9457SAndroid Build Coastguard Worker // Using 1 - sqrt(0.5) under these constraints is 1 bit more accurate than using sqrt(0.5) directly. 54*4bdc9457SAndroid Build Coastguard Worker vy -= (vy * UINT32_C(19195)) >> 16; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker 57*4bdc9457SAndroid Build Coastguard Worker // When X has an even number of bits, Y can overestimate isqrt(X) by 1 due to truncations in fixed-point 58*4bdc9457SAndroid Build Coastguard Worker // arithmetics. When X has an odd number of bits, Y can overestimate isqrt(X) by an extra 1 (2 total) due to 59*4bdc9457SAndroid Build Coastguard Worker // truncation in the multiplication by sqrt(0.5). 60*4bdc9457SAndroid Build Coastguard Worker // We decrement Y once if X < Y * Y and decrement it once again if Y * Y - X > X - (Y - 1) * (Y - 1). 61*4bdc9457SAndroid Build Coastguard Worker uint32_t vsquared_y = vy * vy; 62*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(vsquared_y > vx) { 63*4bdc9457SAndroid Build Coastguard Worker vsquared_y -= 2 * vy - 1; 64*4bdc9457SAndroid Build Coastguard Worker vy -= 1; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker 67*4bdc9457SAndroid Build Coastguard Worker // Y is within a distance of 1 from properly rounded sqrt(X). 68*4bdc9457SAndroid Build Coastguard Worker // - Increment Y if (Y + 1) * (Y + 1) - X < X - Y * Y. 69*4bdc9457SAndroid Build Coastguard Worker // - Decrement Y if Y * Y - X > X - (Y - 1) * (Y - 1). 70*4bdc9457SAndroid Build Coastguard Worker // The increment + decrement are combined together to re-use the (Y * Y) value. 71*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(vsquared_y < vx - vy) { 72*4bdc9457SAndroid Build Coastguard Worker vy += 1; 73*4bdc9457SAndroid Build Coastguard Worker } else if XNN_UNPREDICTABLE(vsquared_y - vy >= vx) { 74*4bdc9457SAndroid Build Coastguard Worker vy -= 1; 75*4bdc9457SAndroid Build Coastguard Worker } 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker 78*4bdc9457SAndroid Build Coastguard Worker *output++ = vy; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker } 81