xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/experimental/apot_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3This file contains utility functions to convert values
4using APoT nonuniform quantization methods.
5"""
6
7import math
8
9
def float_to_apot(x, levels, indices, alpha):
    r"""Convert a floating point input into an APoT number (a level index).

    ``x`` is first clipped to ``[-alpha, alpha]``. The function then returns the
    entry of ``indices`` that is paired with the quantization level nearest to
    the clipped value.

    Args:
        x: value to quantize.
        levels: iterable of quantization levels.
        indices: iterable of level indices, paired positionally with ``levels``
            (``zip`` stops at the shorter of the two, so extra entries are ignored).
        alpha: clipping threshold; inputs outside ``[-alpha, alpha]`` are clamped
            to the nearest bound before the search.

    Returns:
        The index paired with the nearest level. On ties the earliest level wins.
        If ``levels`` or ``indices`` is empty, ``0`` is returned.
    """
    # Clip values based on alpha. Clamp rather than return the bound itself, so
    # the result is always an entry of ``indices`` and never a raw float value.
    if x < -alpha:
        x = -alpha
    elif x > alpha:
        x = alpha

    min_delta = math.inf
    best_idx = 0

    for level, idx in zip(levels, indices):
        cur_delta = abs(level - x)
        # Strict ``<`` keeps the first level encountered on ties.
        if cur_delta < min_delta:
            min_delta = cur_delta
            best_idx = idx

    return best_idx
35
36
def quant_dequant_util(x, levels, indices):
    r"""Map a floating point input to the nearest quantization level.

    This is a quantize-then-dequantize round trip in one step: the result is the
    reduced-precision value from ``levels`` closest to ``x``.

    Args:
        x: value to round.
        levels: iterable of quantization levels.
        indices: iterable paired positionally with ``levels``. Its values are not
            used; it only bounds how many levels are scanned, because ``zip``
            stops at the shorter input.

    Returns:
        The nearest level. On ties the earliest level wins. If nothing is scanned
        (or no distance is smaller than ``math.inf``), ``0.0`` is returned.
    """
    smallest_gap = math.inf
    nearest = 0.0
    for candidate, _ in zip(levels, indices):
        gap = abs(candidate - x)
        if gap < smallest_gap:
            smallest_gap, nearest = gap, candidate
    return nearest
57
58
def apot_to_float(x_apot, levels, indices):
    r"""Convert an APoT number (level index) back into its floating point level.

    Args:
        x_apot: index to look up; must equal some entry of ``indices``.
        levels: indexable sequence of quantization levels.
        indices: iterable of level indices, paired positionally with ``levels``.

    Returns:
        ``levels[i]``, where ``i`` is the position of the first entry of
        ``indices`` equal to ``x_apot``.

    Raises:
        ValueError: if ``x_apot`` does not occur in ``indices``.
    """
    index_list = list(indices)
    return levels[index_list.index(x_apot)]
67