xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/fused_obs_fake_quant.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <cmath>
#include <limits>
#include <numeric>
#include <tuple>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
12 #include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
13 #include <ATen/ops/_fused_moving_avg_obs_fq_helper_native.h>
14 #include <ATen/ops/aminmax.h>
15 #include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
16 #include <ATen/ops/fused_moving_avg_obs_fake_quant_native.h>
17 #include <ATen/ops/ones.h>
18 #include <ATen/ops/ones_like.h>
19 #endif
20 
21 #ifdef USE_FBGEMM
22 #include <fbgemm/QuantUtils.h>
23 #endif
24 #include <ATen/native/quantized/cpu/QuantUtils.h>
25 
26 namespace {
calculate_moving_average(const at::Tensor & x,at::Tensor & running_min,at::Tensor & running_max,float averaging_const,bool per_row_fake_quant,int ch_axis)27 void calculate_moving_average(
28     const at::Tensor& x,
29     at::Tensor& running_min,
30     at::Tensor& running_max,
31     float averaging_const,
32     bool per_row_fake_quant,
33     int ch_axis) {
34   at::Tensor x_min, x_max;
35   if (per_row_fake_quant) {
36     TORCH_CHECK(
37         ch_axis == 0,
38         "Per-channel FakeQuant in fused_moving_avg_obs_fake_quant is only supported on axis == 0");
39     std::tie(x_min, x_max) = at::aminmax(x, 1);
40   } else {
41     std::tie(x_min, x_max) = at::aminmax(x);
42   }
43   const float* min_curr_val = x_min.const_data_ptr<float>();
44   const float* max_curr_val = x_max.const_data_ptr<float>();
45   // Moving Average Min/Max observer for input tensor
46   float* running_min_val = running_min.data_ptr<float>();
47   float* running_max_val = running_max.data_ptr<float>();
48   for (const auto i : c10::irange(x_min.numel())) {
49     running_min_val[i] = std::isinf(running_min_val[i]) ? min_curr_val[i]
50                                                         : running_min_val[i] +
51             averaging_const * (min_curr_val[i] - running_min_val[i]);
52     running_max_val[i] = std::isinf(running_max_val[i]) ? max_curr_val[i]
53                                                         : running_max_val[i] +
54             averaging_const * (max_curr_val[i] - running_max_val[i]);
55   }
56 
57   return;
58 }
59 
choose_qparams_fake_quant(const at::Tensor & x,const at::Tensor & inp_running_min,const at::Tensor & inp_running_max,at::Tensor & scale,at::Tensor & zero_point,bool per_row_fake_quant,bool symmetric_quant,int qmin,int qmax,int ch_axis)60 std::tuple<at::Tensor, at::Tensor> choose_qparams_fake_quant(
61     const at::Tensor& x,
62     const at::Tensor& inp_running_min,
63     const at::Tensor& inp_running_max,
64     at::Tensor& scale,
65     at::Tensor& zero_point,
66     bool per_row_fake_quant,
67     bool symmetric_quant,
68     int qmin,
69     int qmax,
70     int ch_axis) {
71   std::tuple<at::Tensor, at::Tensor> fake_quant_out;
72   at::Tensor x_min, x_max;
73   if (per_row_fake_quant) {
74     float* x_min_data = inp_running_min.data_ptr<float>();
75     float* x_max_data = inp_running_max.data_ptr<float>();
76     for (const auto i : c10::irange(inp_running_min.numel())) {
77 #ifdef USE_FBGEMM
78       auto x_qparams = fbgemm::ChooseQuantizationParams(
79           x_min_data[i],
80           x_max_data[i],
81           qmin,
82           qmax,
83           symmetric_quant, // preserve sparsity
84           false // force power of two
85       );
86       scale[i] = x_qparams.scale;
87       zero_point[i] = x_qparams.zero_point;
88 #else
89       auto x_qparams = quant_utils::ChooseQuantizationParams(
90           x_min_data[i],
91           x_max_data[i],
92           qmin,
93           qmax,
94           symmetric_quant, // preserve sparsity
95           false // force power of two
96       );
97       scale[i] = x_qparams.scale;
98       zero_point[i] = x_qparams.zero_point;
99 #endif
100     }
101     fake_quant_out = at::fake_quantize_per_channel_affine_cachemask(
102         x, scale, zero_point, ch_axis, qmin, qmax);
103   } else {
104 #ifdef USE_FBGEMM
105     fbgemm::TensorQuantizationParams x_qparams{};
106     // compute quantization parameters using min-max values
107     x_qparams = fbgemm::ChooseQuantizationParams(
108         inp_running_min.item().toFloat(),
109         inp_running_max.item().toFloat(),
110         qmin,
111         qmax,
112         symmetric_quant, // bool preserve_sparsity
113         false // force power of two
114     );
115 
116     scale[0] = x_qparams.scale;
117     zero_point[0] = x_qparams.zero_point;
118 #else
119     quant_utils::TensorQuantizationParams x_qparams{};
120     // compute quantization parameters using min-max values
121     x_qparams = quant_utils::ChooseQuantizationParams(
122         inp_running_min.item().toFloat(),
123         inp_running_max.item().toFloat(),
124         qmin,
125         qmax,
126         symmetric_quant, // bool preserve_sparsity
127         false // force power of two
128     );
129     scale[0] = x_qparams.scale;
130     zero_point[0] = x_qparams.zero_point;
131 #endif
132     auto fake_quant_enabled = at::ones(1, x.options().dtype(at::kLong));
133     fake_quant_out =
134         at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
135             x, scale, zero_point, fake_quant_enabled, qmin, qmax);
136   }
137   return fake_quant_out;
138 }
139 } // namespace
140 
141 namespace at {
142 namespace native {
143 
fused_moving_avg_obs_fake_quant_cpu(const at::Tensor & self,const at::Tensor & observer_on,const at::Tensor & fake_quant_on,at::Tensor & running_min,at::Tensor & running_max,at::Tensor & scale,at::Tensor & zero_point,const double averaging_const,const int64_t quant_min,const int64_t quant_max,const int64_t ch_axis,bool per_row_fake_quant,bool symmetric_quant)144 std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cpu(
145     const at::Tensor& self,
146     const at::Tensor& observer_on,
147     const at::Tensor& fake_quant_on,
148     at::Tensor& running_min,
149     at::Tensor& running_max,
150     at::Tensor& scale,
151     at::Tensor& zero_point,
152     const double averaging_const,
153     const int64_t quant_min,
154     const int64_t quant_max,
155     const int64_t ch_axis,
156     bool per_row_fake_quant,
157     bool symmetric_quant) {
158   TORCH_CHECK(ch_axis < self.dim(), "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()");
159   // Calculate min/max
160   auto observe = observer_on.item().toInt();
161   // Calculate the size of the dimension we need to quantize over,
162   // For per-channel quant we default to axis 0, since it is only for
163   // weight quantization currently.
164   if (per_row_fake_quant) {
165     at::Tensor y = self;
166     if (self.dim() != 2) {
167       auto res = DimVector(self.sizes());
168       std::iota(res.begin(), res.end(), 0);
169       res[ch_axis] = 0;
170       res[0] = ch_axis;
171 
172       y = self.permute(res);
173       y = y.flatten(1);
174     }
175     int64_t size = self.size(ch_axis);
176     if (running_min.numel() == 0) {
177       float inf = std::numeric_limits<float>::infinity();
178       running_min.resize_(size).fill_(inf);
179       running_max.resize_(size).fill_(-inf);
180       scale.resize_(size);
181       zero_point.resize_(size);
182     }
183     if (observe) {
184       calculate_moving_average(
185           y,
186           running_min,
187           running_max,
188           averaging_const,
189           per_row_fake_quant,
190           ch_axis);
191     }
192   } else {
193     if (observe) {
194       calculate_moving_average(
195           self,
196           running_min,
197           running_max,
198           averaging_const,
199           per_row_fake_quant,
200           ch_axis);
201     }
202   }
203   // Calculate qparams and fake_quantize
204   auto fake_quant = fake_quant_on.item().toInt();
205   if (fake_quant) {
206     return choose_qparams_fake_quant(
207         self,
208         running_min,
209         running_max,
210         scale,
211         zero_point,
212         per_row_fake_quant,
213         symmetric_quant,
214         quant_min,
215         quant_max,
216         ch_axis);
217   }
218   auto mask = at::ones_like(self, at::kBool, MemoryFormat::Preserve);
219   return std::make_tuple(self.clone(), mask);
220 }
221 
// Public entry point: dispatches to the device-specific
// _fused_moving_avg_obs_fq_helper and returns only the fake-quantized
// output, discarding the cachemask the helper also produces.
at::Tensor fused_moving_avg_obs_fake_quant(
    const at::Tensor& self,
    const at::Tensor& observer_on,
    const at::Tensor& fake_quant_on,
    at::Tensor& running_min,
    at::Tensor& running_max,
    at::Tensor& scale,
    at::Tensor& zero_point,
    const double averaging_const,
    const int64_t quant_min,
    const int64_t quant_max,
    const int64_t ch_axis,
    bool per_row_fake_quant,
    bool symmetric_quant) {
  // An empty tensor has nothing to observe or quantize; return a copy as-is.
  if (self.sym_numel() == 0) {
    return self.clone();
  }
  return std::get<0>(at::_fused_moving_avg_obs_fq_helper(
      self,
      observer_on,
      fake_quant_on,
      running_min,
      running_max,
      scale,
      zero_point,
      averaging_const,
      quant_min,
      quant_max,
      ch_axis,
      per_row_fake_quant,
      symmetric_quant));
}
255 } // namespace native
256 } // namespace at
257