#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <cmath>
#include <limits>
#include <numeric>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
#include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
#include <ATen/ops/fused_moving_avg_obs_fake_quant_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/ones_like.h>
#endif

#ifdef USE_FBGEMM
#include <fbgemm/QuantUtils.h>
#endif
#include <ATen/native/quantized/cpu/QuantUtils.h>

namespace {
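// Updates the running min/max statistics of `x` in place using an
// exponential moving average. In per-row mode the reduction runs over dim 1,
// yielding one min/max pair per row (channel); otherwise a single pair is
// computed over the whole tensor.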
void calculate_moving_average(
    const at::Tensor& x,
    at::Tensor& running_min,
    at::Tensor& running_max,
    float averaging_const,
    bool per_row_fake_quant,
    int ch_axis) {
  at::Tensor x_min, x_max;
  if (per_row_fake_quant) {
    TORCH_CHECK(
        ch_axis == 0,
        "Per-channel FakeQuant in fused_moving_avg_obs_fake_quant is only supported on axis == 0");
    std::tie(x_min, x_max) = at::aminmax(x, 1);
  } else {
    std::tie(x_min, x_max) = at::aminmax(x);
  }
  const float* min_curr_val = x_min.const_data_ptr<float>();
  const float* max_curr_val = x_max.const_data_ptr<float>();
  // Moving Average Min/Max observer for input tensor
  float* running_min_val = running_min.data_ptr<float>();
  float* running_max_val = running_max.data_ptr<float>();
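  // The running stats are initialized to +/-infinity, so an infinite value
  // marks "no observations yet" and is replaced by the current min/max
  // outright; otherwise apply the exponential moving average
  //   running += averaging_const * (current - running)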
  for (const auto i : c10::irange(x_min.numel())) {
    running_min_val[i] = std::isinf(running_min_val[i])
        ? min_curr_val[i]
        : running_min_val[i] +
            averaging_const * (min_curr_val[i] - running_min_val[i]);
    running_max_val[i] = std::isinf(running_max_val[i])
        ? max_curr_val[i]
        : running_max_val[i] +
            averaging_const * (max_curr_val[i] - running_max_val[i]);
  }
}

std::tuple<at::Tensor, at::Tensor> choose_qparams_fake_quant(
    const at::Tensor& x,
    const at::Tensor& inp_running_min,
    const at::Tensor& inp_running_max,
    at::Tensor& scale,
    at::Tensor& zero_point,
    bool per_row_fake_quant,
    bool symmetric_quant,
    int qmin,
    int qmax,
    int ch_axis) {
  std::tuple<at::Tensor, at::Tensor> fake_quant_out;
  at::Tensor x_min, x_max;
  if (per_row_fake_quant) {
    float* x_min_data = inp_running_min.data_ptr<float>();
    float* x_max_data = inp_running_max.data_ptr<float>();
    for (const auto i : c10::irange(inp_running_min.numel())) {
#ifdef USE_FBGEMM
      auto x_qparams = fbgemm::ChooseQuantizationParams(
          x_min_data[i],
          x_max_data[i],
          qmin,
          qmax,
          symmetric_quant, // preserve sparsity
          false // force power of two
      );
      scale[i] = x_qparams.scale;
      zero_point[i] = x_qparams.zero_point;
#else
      auto x_qparams = quant_utils::ChooseQuantizationParams(
          x_min_data[i],
          x_max_data[i],
          qmin,
          qmax,
          symmetric_quant, // preserve sparsity
          false // force power of two
      );
      scale[i] = x_qparams.scale;
      zero_point[i] = x_qparams.zero_point;
#endif
    }
    fake_quant_out = at::fake_quantize_per_channel_affine_cachemask(
        x, scale, zero_point, ch_axis, qmin, qmax);
  } else {
#ifdef USE_FBGEMM
    fbgemm::TensorQuantizationParams x_qparams{};
    // compute quantization parameters using min-max values
    x_qparams = fbgemm::ChooseQuantizationParams(
        inp_running_min.item().toFloat(),
        inp_running_max.item().toFloat(),
        qmin,
        qmax,
        symmetric_quant, // bool preserve_sparsity
        false // force power of two
    );
    scale[0] = x_qparams.scale;
    zero_point[0] = x_qparams.zero_point;
#else
    quant_utils::TensorQuantizationParams x_qparams{};
    // compute quantization parameters using min-max values
    x_qparams = quant_utils::ChooseQuantizationParams(
        inp_running_min.item().toFloat(),
        inp_running_max.item().toFloat(),
        qmin,
        qmax,
        symmetric_quant, // bool preserve_sparsity
        false // force power of two
    );
    scale[0] = x_qparams.scale;
    zero_point[0] = x_qparams.zero_point;
#endif
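    // The per-tensor kernel takes its enable flag as a tensor; gating on
    // fake_quant_on already happened in the caller, so it is always 1 here.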
    auto fake_quant_enabled = at::ones(1, x.options().dtype(at::kLong));
    fake_quant_out =
        at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
            x, scale, zero_point, fake_quant_enabled, qmin, qmax);
  }
  return fake_quant_out;
}
} // namespace

namespace at {
namespace native {

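// Fused observer + fake-quantize kernel:
//   1. If the observer is enabled, fold the current min/max of `self` into
//      the running statistics via calculate_moving_average.
//   2. If fake-quant is enabled, derive scale/zero_point from those
//      statistics and fake-quantize `self`.
// Returns the (possibly fake-quantized) output and the cachemask.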
std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_cpu(
    const at::Tensor& self,
    const at::Tensor& observer_on,
    const at::Tensor& fake_quant_on,
    at::Tensor& running_min,
    at::Tensor& running_max,
    at::Tensor& scale,
    at::Tensor& zero_point,
    const double averaging_const,
    const int64_t quant_min,
    const int64_t quant_max,
    const int64_t ch_axis,
    bool per_row_fake_quant,
    bool symmetric_quant) {
  TORCH_CHECK(
      ch_axis < self.dim(),
      "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()");
  // Calculate min/max
  auto observe = observer_on.item().toInt();
  // Calculate the size of the dimension we need to quantize over.
  // Per-channel quant defaults to axis 0, since it is currently only used
  // for weight quantization.
  if (per_row_fake_quant) {
    at::Tensor y = self;
    if (self.dim() != 2) {
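      // res is sized from self.sizes() only to get one slot per dimension;
      // iota fills it with the identity permutation 0..dim-1, then entries
      // 0 and ch_axis are swapped so the permute moves the channel axis to
      // the front. The flatten below collapses the remaining dims, giving a
      // 2-D [channels, elements_per_channel] view for the row-wise reduction.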
      auto res = DimVector(self.sizes());
      std::iota(res.begin(), res.end(), 0);
      res[ch_axis] = 0;
      res[0] = ch_axis;

      y = self.permute(res);
      y = y.flatten(1);
    }
    int64_t size = self.size(ch_axis);
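    // Lazily size the observer state on the first call; the +/-infinity
    // sentinels mark the running stats as not yet observed (see
    // calculate_moving_average).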
    if (running_min.numel() == 0) {
      float inf = std::numeric_limits<float>::infinity();
      running_min.resize_(size).fill_(inf);
      running_max.resize_(size).fill_(-inf);
      scale.resize_(size);
      zero_point.resize_(size);
    }
    if (observe) {
      calculate_moving_average(
          y,
          running_min,
          running_max,
          averaging_const,
          per_row_fake_quant,
          ch_axis);
    }
  } else {
    if (observe) {
      calculate_moving_average(
          self,
          running_min,
          running_max,
          averaging_const,
          per_row_fake_quant,
          ch_axis);
    }
  }
  // Calculate qparams and fake_quantize
  auto fake_quant = fake_quant_on.item().toInt();
  if (fake_quant) {
    return choose_qparams_fake_quant(
        self,
        running_min,
        running_max,
        scale,
        zero_point,
        per_row_fake_quant,
        symmetric_quant,
        quant_min,
        quant_max,
        ch_axis);
  }
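  // Fake-quant disabled: return the input unchanged along with an all-true
  // mask, so the backward pass lets every gradient element through.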
  auto mask = at::ones_like(self, at::kBool, MemoryFormat::Preserve);
  return std::make_tuple(self.clone(), mask);
}

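// Public entry point: short-circuits empty inputs to a clone, otherwise
// delegates to the _fused_moving_avg_obs_fq_helper op (e.g. the CPU kernel
// above) and returns only the output tensor of its (output, mask) pair.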
at::Tensor fused_moving_avg_obs_fake_quant(
    const at::Tensor& self,
    const at::Tensor& observer_on,
    const at::Tensor& fake_quant_on,
    at::Tensor& running_min,
    at::Tensor& running_max,
    at::Tensor& scale,
    at::Tensor& zero_point,
    const double averaging_const,
    const int64_t quant_min,
    const int64_t quant_max,
    const int64_t ch_axis,
    bool per_row_fake_quant,
    bool symmetric_quant) {
  if (self.sym_numel() == 0) {
    return self.clone();
  }
  const auto res = at::_fused_moving_avg_obs_fq_helper(
      self,
      observer_on,
      fake_quant_on,
      running_min,
      running_max,
      scale,
      zero_point,
      averaging_const,
      quant_min,
      quant_max,
      ch_axis,
      per_row_fake_quant,
      symmetric_quant);
  return std::get<0>(res);
}
} // namespace native
} // namespace at