1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <fxdiv.h>

#include <qnnpack/u8lut32norm.h>
14
/*
 * Returns the sum of t[x[i]] over the first n bytes of x.
 *
 * n must be non-zero. Each byte of x is used as an index into t, so t must
 * provide an entry for every byte value that can appear in x.
 *
 * The sum is accumulated in 32 bits. Unsigned addition wraps silently, and
 * the caller divides by this sum, so a wrapped total would produce
 * meaningless results. The caller must therefore ensure the total fits in
 * uint32_t; debug builds verify this on every addition.
 */
static inline uint32_t compute_sum(
    size_t n,
    const uint8_t* x,
    const uint32_t* t) {
  assert(n != 0);

  uint32_t vsum = 0;
  do {
    const size_t vx = *x++;
    const uint32_t vt = t[vx];
    /* Equivalent to "vsum + vt does not exceed UINT32_MAX", written so the
     * check itself cannot overflow. */
    assert(vt <= UINT32_MAX - vsum);
    vsum += vt;
  } while (--n != 0);
  return vsum;
}
28
pytorch_u8lut32norm_ukernel__scalar(size_t n,const uint8_t * x,const uint32_t * t,uint8_t * y)29 void pytorch_u8lut32norm_ukernel__scalar(
30 size_t n,
31 const uint8_t* x,
32 const uint32_t* t,
33 uint8_t* y) {
34 assert(n != 0);
35
36 const uint32_t vsum = compute_sum(n, x, t);
37 assert(vsum != 0);
38
39 struct fxdiv_divisor_uint32_t vsum_divisor = fxdiv_init_uint32_t(vsum);
40 const uint32_t vrounding = (vsum >> 1);
41 do {
42 const size_t vx = *x++;
43 const uint32_t vt = t[vx];
44 const uint32_t vq =
45 fxdiv_quotient_uint32_t((vt << 8) + vrounding, vsum_divisor);
46 const uint8_t vy = vq > 255 ? UINT8_C(255) : (uint8_t)vq;
47 *y++ = vy;
48 } while (--n != 0);
49 }
50