// Copyright 2022 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include #include #include #include void xnn_math_u32_sqrt__scalar_hashemian( size_t n, const uint32_t* input, uint32_t* output) { assert(n % sizeof(uint32_t) == 0); for (; n != 0; n -= sizeof(uint32_t)) { const uint32_t vx = *input++; uint32_t vy = vx; if (vx != 0) { /* * Based on "Square Rooting Algorithms for Integer and Floating-Point Numbers" by Reza Hashemian * and StackOverflow answer https://stackoverflow.com/a/31149161 */ const uint32_t vn = math_clz_nonzero_u32(vx); const uint32_t vleft_shift = vn & 1; const uint32_t vm_minus_1 = 15 - (vn >> 1); const uint32_t vm_plus_1 = vm_minus_1 + 2; const uint32_t vexp2_m_minus_1 = UINT32_C(1) << vm_minus_1; const uint32_t vz = vexp2_m_minus_1 - (vx >> (vm_plus_1 - vleft_shift)); vy = vz; // Iterate until y[i] == y[i-1]. Alternatively, we can do 7 iterations: // for (uint32_t i = 0; i < 7; i++) { // vy = vz + ((vy * vy) >> vm_plus_1); // } uint32_t vy_prev; do { vy_prev = vy; vy = vz + ((vy * vy) >> vm_plus_1); } while (vy != vy_prev); // Reconstruct Y = 2**m - vy vy = (vexp2_m_minus_1 << 1) - vy; if XNN_UNPREDICTABLE(vleft_shift) { // Multiply by sqrt(0.5) by subtracting vy * (1 - sqrt(0.5)), 1 - sqrt(0.5) is represented // as a .16 fixed-point number to guarantee than the product doesn't overflow 32 bits. // Using 1 - sqrt(0.5) under these constraints is 1 bit more accurate than using sqrt(0.5) directly. vy -= (vy * UINT32_C(19195)) >> 16; } // When X has an even number of bits, Y can overestimate isqrt(X) by 1 due to truncations in fixed-point // arithmetics. When X has an odd number of bits, Y can overestimate isqrt(X) by an extra 1 (2 total) due to // truncation in the multiplication by sqrt(0.5). // We decrement Y once if X < Y * Y and decrement it once again if Y * Y - X > X - (Y - 1) * (Y - 1). uint32_t vsquared_y = vy * vy; if XNN_UNPREDICTABLE(vsquared_y > vx) { vsquared_y -= 2 * vy - 1; vy -= 1; } // Y is within a distance of 1 from properly rounded sqrt(X). // - Increment Y if (Y + 1) * (Y + 1) - X < X - Y * Y. // - Decrement Y if Y * Y - X > X - (Y - 1) * (Y - 1). // The increment + decrement are combined together to re-use the (Y * Y) value. if XNN_UNPREDICTABLE(vsquared_y < vx - vy) { vy += 1; } else if XNN_UNPREDICTABLE(vsquared_y - vy >= vx) { vy -= 1; } } *output++ = vy; } }