xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8lut32norm/scalar.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 
11 #include <fxdiv.h>
12 
13 #include <qnnpack/u8lut32norm.h>
14 
/*
 * Accumulates the lookup-table entries selected by a run of byte indices.
 *
 * n - number of input bytes to process; must be non-zero (asserted).
 * x - input bytes, each used as an index into t.
 * t - lookup table of 32-bit values, indexed by an input byte.
 *
 * Returns t[x[0]] + t[x[1]] + ... + t[x[n - 1]], computed in uint32_t
 * arithmetic (i.e. modulo 2^32).
 */
static inline uint32_t compute_sum(
    size_t n,
    const uint8_t* x,
    const uint32_t* t) {
  assert(n != 0);

  uint32_t total = 0;
  size_t i = 0;
  /* n is non-zero, so the body always runs at least once. */
  do {
    total += t[x[i]];
  } while (++i != n);
  return total;
}
28 
pytorch_u8lut32norm_ukernel__scalar(size_t n,const uint8_t * x,const uint32_t * t,uint8_t * y)29 void pytorch_u8lut32norm_ukernel__scalar(
30     size_t n,
31     const uint8_t* x,
32     const uint32_t* t,
33     uint8_t* y) {
34   assert(n != 0);
35 
36   const uint32_t vsum = compute_sum(n, x, t);
37   assert(vsum != 0);
38 
39   struct fxdiv_divisor_uint32_t vsum_divisor = fxdiv_init_uint32_t(vsum);
40   const uint32_t vrounding = (vsum >> 1);
41   do {
42     const size_t vx = *x++;
43     const uint32_t vt = t[vx];
44     const uint32_t vq =
45         fxdiv_quotient_uint32_t((vt << 8) + vrounding, vsum_divisor);
46     const uint8_t vy = vq > 255 ? UINT8_C(255) : (uint8_t)vq;
47     *y++ = vy;
48   } while (--n != 0);
49 }
50