// aten/src/ATen/native/quantized/QTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>

#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <tuple>
#include <utility>
#include <vector>

namespace at::native {

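// Dynamically quantizes a float tensor: the quantization parameters are
// derived from the runtime min/max of `self`, and the tensor is then
// quantized to `dtype` via quantize_per_tensor. ScalarType::Half is handled
// as a plain fp16 conversion. `reduce_range` narrows the quantized value
// range (used to avoid overflow in some FBGEMM kernels) and is ignored when
// the QNNPACK engine is active.
//
// Illustrative usage sketch (not part of the original file):
//   at::Tensor x = at::rand({2, 3});
//   at::Tensor qx = at::native::quantize_per_tensor_dynamic(
//       x, at::ScalarType::QUInt8, /*reduce_range=*/false);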
Tensor quantize_per_tensor_dynamic(
    const Tensor& self,
    ScalarType dtype,
    bool reduce_range) {
  TORCH_CHECK(
      (dtype == ScalarType::QInt8 || dtype == ScalarType::QUInt8 || dtype == ScalarType::Half),
      "dtype ", dtype, " not supported");
  auto input_contig = self.contiguous();
  if (dtype == ScalarType::Half) {
    return input_contig.to(ScalarType::Half);
  }
  float x_min = input_contig.min().item<float>();
  float x_max = input_contig.max().item<float>();

  if (reduce_range && at::globalContext().qEngine() == at::QEngine::QNNPACK) {
    reduce_range = false;
  }

  int qmin = 0;
  int qmax = 0;

  if (dtype == ScalarType::QInt8) {
    qmin = -128;
    qmax = 127;
  } else {
    // For now, this branch executes only for dtype == ScalarType::QUInt8;
    // additional cases will be added when quantization support for other
    // dtypes becomes available.
    qmin = 0;
    qmax = 255;
  }

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/qmin,
      /*qmax=*/qmax,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  return at::native::quantize_per_tensor(self, q_params.scale, q_params.zero_point, dtype);
}

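// Quantizes a float tensor with fixed per-tensor affine parameters:
// q = round(x / scale) + zero_point, clamped to the value range of `dtype`.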
Tensor quantize_per_tensor(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    ScalarType dtype) {
  auto quantizer = make_per_tensor_affine_quantizer(scale, zero_point, dtype);
  return quantizer->quantize(self);
}

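// Same as quantize_per_tensor, but takes the scale and zero point as
// single-element tensors (e.g. the outputs of a choose_qparams op) instead
// of scalars.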
Tensor quantize_per_tensor_tensor_qparams(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    ScalarType dtype) {
  auto quantizer = make_per_tensor_affine_quantizer(scale.item().toDouble(), zero_point.item().toLong(), dtype);
  return quantizer->quantize(self);
}

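// Quantizes a list of tensors, using scales[i] and zero_points[i] for the
// i-th input, and returns the quantized tensors in the same order.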
std::vector<Tensor> quantize_per_tensor_list_cpu(
    TensorList tensors,
    const Tensor& scales,
    const Tensor& zero_points,
    ScalarType dtype) {
  std::vector<Tensor> quantized_tensors;
  for (const auto i : c10::irange(tensors.size())) {
    quantized_tensors.push_back(at::quantize_per_tensor(
        tensors[i],
        scales[i].item<double>(),
        zero_points[i].item<int64_t>(),
        dtype));
  }
  return quantized_tensors;
}

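// Quantizes a float tensor per channel along `axis`: channel c uses
// scales[c] and zero_points[c]. This is the scheme typically used for
// weights (e.g. one scale per output channel).
//
// Illustrative usage sketch (not part of the original file):
//   at::Tensor w = at::rand({8, 4});
//   at::Tensor qw = at::native::quantize_per_channel(
//       w,
//       at::full({8}, 0.1, at::kDouble),
//       at::zeros({8}, at::kLong),
//       /*axis=*/0,
//       at::ScalarType::QInt8);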
Tensor quantize_per_channel(
    const Tensor& self,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis,
    ScalarType dtype) {
  auto quantizer = make_per_channel_affine_quantizer(scales, zero_points, axis, dtype);
  return quantizer->quantize(self);
}

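// dequantize() for non-quantized CPU/CUDA tensors: simply converts to float.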
Tensor dequantize_cpu_or_cuda(const Tensor& self) {
  return self.to(at::kFloat);
}

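// dequantize() for quantized tensors: delegates to the tensor's quantizer.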
Tensor dequantize_quantized(const Tensor& self) {
  return get_qtensorimpl(self)->quantizer()->dequantize(self);
}

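// Dequantizes every tensor in a list of quantized CPU tensors.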
std::vector<Tensor> dequantize_tensors_quantized_cpu(TensorList tensors) {
  std::vector<Tensor> dequantized_tensors;
  for (const auto& tensor : tensors) {
    dequantized_tensors.push_back(tensor.dequantize());
  }
  return dequantized_tensors;
}

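// Per-tensor affine accessors: the two functions below return the scale and
// zero point of a tensor quantized with the kPerTensorAffine scheme.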
double q_scale_quant(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
  return static_cast<PerTensorAffineQuantizer*>(quantizer.get())->scale();
}

int64_t q_zero_point_quant(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
  return static_cast<PerTensorAffineQuantizer*>(quantizer.get())->zero_point();
}

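// Per-channel affine accessors: the three functions below return the scales,
// zero points, and quantization axis of a per-channel quantized tensor.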
Tensor q_per_channel_scales(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  return static_cast<PerChannelAffineQuantizer*>(quantizer.get())->scales();
}

Tensor q_per_channel_zero_points(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  return static_cast<PerChannelAffineQuantizer*>(quantizer.get())->zero_points();
}

int64_t q_per_channel_axis(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  return static_cast<PerChannelAffineQuantizer*>(quantizer.get())->axis();
}

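// Builds a per-channel quantized tensor from `self`, which is expected to
// hold raw integer values of the underlying type (e.g. uint8/int8): a new
// quantized tensor is allocated with the given scales, zero points, and
// axis, and the bytes are copied over.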
Tensor make_per_channel_quantized_tensor_cpu(
    const Tensor& self,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  Tensor dst = at::_empty_per_channel_affine_quantized(
      self.sizes(),
      scales,
      zero_points,
      axis,
      self.options().dtype(toQIntType(self.scalar_type())));
  Tensor self_contig = self.contiguous();
  AT_DISPATCH_QINT_TYPES(
      dst.scalar_type(), "per_channel_affine_qtensor", [&]() {
        underlying_t* self_data = self_contig.data_ptr<underlying_t>();
        underlying_t* dst_data =
            reinterpret_cast<underlying_t*>(dst.data_ptr<scalar_t>());
        if (self.numel() > 0) {
          memcpy(dst_data, self_data, self.nbytes());
        }
      });
  return dst;
}

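// set_() for quantized tensors: rebinds `self` to `storage` with the given
// offset, sizes, and strides, keeping the quantized dtype.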
Tensor& set_storage_quantized_(
    Tensor& self,
    Storage storage,
    int64_t storage_offset,
    IntArrayRef sizes,
    IntArrayRef strides) {
  auto* self_ = self.unsafeGetTensorImpl();
  self_->set_storage_keep_dtype(std::move(storage));
  self_->set_storage_offset(storage_offset);
  self_->set_sizes_and_strides(sizes, strides);
  return self;
}

QScheme qscheme_quant(const Tensor& self) {
  auto quantizer = get_qtensorimpl(self)->quantizer();
  return quantizer->qscheme();
}

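// clone() for quantized tensors: allocates an empty quantized tensor with the
// same quantization parameters (per-tensor or per-channel) and the requested
// memory format, then copies the data.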
Tensor quantized_clone(
    const Tensor& self,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format =
      optional_memory_format.value_or(MemoryFormat::Contiguous);

  // TODO: To support all features of MemoryFormat::Preserve we need to add an
  // _empty_affine_quantized_strided function and use it similarly to
  // Tensor clone(const Tensor& src, std::optional<c10::MemoryFormat>
  // optional_memory_format): if (self.is_non_overlapping_and_dense()) ->
  // _empty_affine_quantized_strided
  if (memory_format == MemoryFormat::Preserve) {
    memory_format = self.suggest_memory_format();
  }

  Tensor dst;
  if (self.qscheme() == at::kPerTensorAffine) {
    dst = at::_empty_affine_quantized(
        self.sizes(),
        self.options().memory_format(memory_format),
        self.q_scale(),
        self.q_zero_point(),
        std::nullopt);
  } else if (self.qscheme() == at::kPerChannelAffine) {
    dst = at::_empty_per_channel_affine_quantized(
        self.sizes(),
        self.q_per_channel_scales(),
        self.q_per_channel_zero_points(),
        self.q_per_channel_axis(),
        self.options().memory_format(memory_format),
        std::nullopt);
  } else {
    TORCH_CHECK(
        false,
        "clone for quantized Tensor only works for "
        "PerTensorAffine and PerChannelAffine qscheme right now");
  }

  at::native::copy_(dst, self, false);

  return dst;
}

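// equal() for quantized CPU tensors: two tensors compare equal when their
// quantizers are equal, their sizes and dtypes match, and their raw quantized
// data is bitwise identical.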
bool equal_quantized_cpu(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.device().type() == kCPU && other.device().type() == kCPU,
      "quantized_equal is implemented only for the QuantizedCPU backend");
  if (!self.is_quantized() || !other.is_quantized()) {
    return false;
  }

  // Delegate to the virtual equalTo method. This ensures that different
  // concrete Quantizers can apply their own comparison logic.
  auto self_quantizer = get_qtensorimpl(self)->quantizer();
  auto other_quantizer = get_qtensorimpl(other)->quantizer();
  if (!self_quantizer->equalTo(other_quantizer)) {
    return false;
  }

  // Sizes and element types must be the same
  if (self.sizes() != other.sizes()) {
    return false;
  }
  if (self.scalar_type() != other.scalar_type()) {
    return false;
  }

  // Data must be the same
  auto self_contig = self.contiguous();
  auto other_contig = other.contiguous();

  void* self_data = self_contig.data_ptr();
  void* other_data = other_contig.data_ptr();
  auto data_size = self.numel() * self.element_size();
  // For QUInt4x2 and QUInt2x4, two elements are packed in one byte, so the
  // comparison covers ceil(numel / 2) bytes.
  if (self.scalar_type() == kQUInt4x2 || self.scalar_type() == kQUInt2x4) {
    TORCH_INTERNAL_ASSERT(self.element_size() == 1);
    data_size = (data_size >> 1) + (data_size & 1);
  }
  return 0 == memcmp(self_data, other_data, data_size);
}

/* Calculate the per-tensor quantization params (scale, zero_point) for an
   activation tensor from its runtime min/max, assuming an unsigned 8-bit
   quantized range [0, 255]. */
std::tuple<double, int64_t> _choose_qparams_per_tensor(
    const Tensor& self,
    bool reduce_range) {
  auto input_contig = self.contiguous();
  float x_min = input_contig.min().item<float>();
  float x_max = input_contig.max().item<float>();

  if (reduce_range && at::globalContext().qEngine() == at::QEngine::QNNPACK) {
    reduce_range = false;
  }

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  return std::make_tuple(q_params.scale, q_params.zero_point);
}

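// Fake-quantizes `input` to `bit_width` bits over the range [xmin, xmax] and
// returns the L2 norm of the reconstruction error,
//   sqrt(sum_i (input[i] - dequant(quant(input[i])))^2).
// Note that xmin and the scale are rounded through fp16 (at::Half) before use.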
static float calculate_quant_loss(
    const float* input,
    int numel,
    float xmin,
    float xmax,
    float* q_input,
    int bit_width) {
  xmin = static_cast<at::Half>(xmin);
  float data_range = xmax - xmin;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float qmax = (1 << bit_width) - 1;
  float scale = data_range == 0
      ? 1.0
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      : static_cast<float>(static_cast<at::Half>(data_range / qmax));
  float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;

  float norm = 0.0f;
  int i = 0;

  // TODO add FBGEMM kernel
  // #ifdef USE_FBGEMM
  // #endif

  // remainder loop
  for (; i < numel; i++) {
    q_input[i] = std::max(
        0.0f, std::min<float>(std::nearbyint((input[i] - xmin) * inverse_scale), qmax));
    q_input[i] = q_input[i] * scale + xmin;
    norm += (input[i] - q_input[i]) * (input[i] - q_input[i]);
  }
  return std::sqrt(norm);
}

/*
  Helper function to find the best min/max for a tensor for calculating qparams.
  It uses a greedy approach to nudge the min and max and computes the l2 norm,
  trying to minimize the quantization error `torch.norm(x - fake_quant(x, s, z))`.
  Returns the optimized xmax and xmin values of the tensor.
*/
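// Illustrative usage sketch (not part of the original file); the n_bins,
// ratio, and bit_width values below are just example settings:
//   at::Tensor w = at::rand({128});
//   auto [xmax, xmin] = at::native::choose_qparams_optimized(
//       w, w.numel(), /*n_bins=*/200, /*ratio=*/0.16, /*bit_width=*/4);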
std::tuple<Tensor, Tensor> choose_qparams_optimized(
    const at::Tensor& input_tensor,
    int64_t numel,
    const int64_t n_bins,
    const double ratio,
    int64_t bit_width) {
  TORCH_CHECK(
      numel >= 0 && numel <= input_tensor.numel(),
      "numel ", numel, " is out of bounds for input_tensor with numel ",
      input_tensor.numel());

  const float* input_row = input_tensor.const_data_ptr<float>();
  float xmin = *std::min_element(input_row, input_row + numel);
  float xmax = *std::max_element(input_row, input_row + numel);

  float stepsize = (xmax - xmin) / n_bins;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  int min_bins = n_bins * (1.0 - (float) ratio);
  Tensor input_tensor_contig = input_tensor.contiguous();
  const float* input = input_tensor_contig.const_data_ptr<float>();
  std::vector<float> q_input(numel);

  float loss =
      calculate_quant_loss(input, numel, xmin, xmax, q_input.data(), bit_width);
  float best_loss = loss;

  float cur_min = xmin;
  float cur_max = xmax;
  float cur_loss = loss;

  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float thr = min_bins * stepsize;
  while (cur_min + thr < cur_max) {
    // shrink the range from the left
    float loss1 = calculate_quant_loss(
        input, numel, cur_min + stepsize, cur_max, q_input.data(), bit_width);
    // shrink the range from the right
    float loss2 = calculate_quant_loss(
        input, numel, cur_min, cur_max - stepsize, q_input.data(), bit_width);
    if (cur_loss < loss1 && cur_loss < loss2 && cur_loss < best_loss) {
      // found a local optimum
      best_loss = cur_loss;
      xmin = cur_min;
      xmax = cur_max;
    }
    if (loss1 < loss2) {
      cur_min = cur_min + stepsize;
      cur_loss = loss1;
    } else {
      cur_max = cur_max - stepsize;
      cur_loss = loss2;
    }
  }

  at::Tensor xmax_tensor = at::empty({1});
  at::Tensor xmin_tensor = at::empty({1});
  xmax_tensor[0] = xmax;
  xmin_tensor[0] = xmin;
  return std::make_tuple(xmax_tensor, xmin_tensor);
}
} // namespace at::native