xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/quantized/Quantizer.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_saturate_weight_to_fp16.h>
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
#include <ATen/ops/dequantize.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>

int register_linear_params();

#ifdef USE_FBGEMM
namespace {
// Calculate the column offsets.
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by
// PackAWithQuantRowOffset is only the sum of the A rows.
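//
// A small worked example (illustrative only): with K = 3, a row of the
// transposed B equal to {2, -1, 4}, and a per-tensor zero point of 1,
//   col_offsets[i] = (2 + (-1) + 4) - 1 * 3 = 2.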
void calc_col_offsets_transpose(
    int K,
    int N,
    const int8_t* Bint8,
    int32_t* B_zero_point,
    int32_t* col_offsets,
    c10::QScheme qtype) {
  for (const auto i : c10::irange(N)) {
    int32_t sum = 0;
    for (const auto j : c10::irange(K)) {
      sum += Bint8[i * K + j];
    }
    if (qtype == c10::kPerTensorAffine) {
      col_offsets[i] = sum - B_zero_point[0] * K;
    } else {
      col_offsets[i] = sum - B_zero_point[i] * K;
    }
  }
}
} // namespace

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeight::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    std::optional<at::Tensor> bias) {
  TORCH_CHECK(
      weight.dim() == 2,
      "The weight tensor for quantized::linear_prepack (fbgemm) should"
      " be 2-dimensional.");

  auto N = weight.size(0);
  auto K = weight.size(1);

  // TODO: contiguous is called for further JIT optimizations.
  auto weight_contig = weight.contiguous();
  const auto qtype = weight.qscheme();
  std::vector<int32_t> weight_zero_points_int32(1, 0);
  if (qtype == c10::kPerTensorAffine) {
    weight_zero_points_int32[0] = {static_cast<int32_t>(weight.q_zero_point())};
  } else if (qtype == c10::kPerChannelAffine) {
    weight_zero_points_int32.resize(N, 0);
    for (const auto i : c10::irange(N)) {
      weight_zero_points_int32[i] =
          weight.q_per_channel_zero_points()[i].item<int32_t>();
    }
  }
  std::vector<float> weight_scales_float(1, 0.0);
  if (qtype == c10::kPerTensorAffine) {
    weight_scales_float[0] = {static_cast<float>(weight.q_scale())};
  } else if (qtype == c10::kPerChannelAffine) {
    weight_scales_float.resize(N, 0.0);
    for (const auto i : c10::irange(N)) {
      weight_scales_float[i] = weight.q_per_channel_scales()[i].item<float>();
    }
  }

  int8_t* weight_ptr_int8 =
      reinterpret_cast<int8_t*>(weight_contig.data_ptr<c10::qint8>());

  std::vector<int32_t> col_offsets(N);
  calc_col_offsets_transpose(
      /*K=*/static_cast<int>(K),
      /*N=*/static_cast<int>(N),
      /*Bint8=*/weight_ptr_int8,
      /*B_zero_point=*/weight_zero_points_int32.data(),
      /*col_offsets=*/col_offsets.data(),
      /*qtype=*/qtype);

  std::optional<at::Tensor> bias_contig;
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == N,
        "bias should have N elements: " + std::to_string(N));
    bias_contig = bias->contiguous();
  }
  auto ret_ptr = c10::make_intrusive<PackedLinearWeight>(
      std::make_unique<fbgemm::PackBMatrix<int8_t>>(
          /*trans=*/fbgemm::matrix_op_t::Transpose,
          /*nRow=*/K,
          /*nCol=*/N,
          /*smat=*/weight_ptr_int8,
          /*ld=*/K,
          /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
          /*groups=*/1),
      bias_contig,
      col_offsets,
      weight_scales_float,
      weight_zero_points_int32,
      qtype);
  return ret_ptr;
}
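
// Illustrative sketch (not part of this translation unit): the packed params
// produced above are normally obtained through the quantized::linear_prepack
// op registered at the bottom of this file. Tensor names below are
// assumptions.
//
//   auto qw = at::quantize_per_tensor(
//       w_fp32, w_scale, w_zp, c10::ScalarType::QInt8);
//   auto op = c10::Dispatcher::singleton()
//                 .findSchemaOrThrow("quantized::linear_prepack", "")
//                 .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
//                     at::Tensor, std::optional<at::Tensor>)>();
//   auto packed_params = op.call(qw, std::optional<at::Tensor>(bias_fp32));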
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsQnnp::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    std::optional<at::Tensor> bias_in) {
  TORCH_CHECK(
      weight.dim() == 2,
      "quantized::linear_prepack (qnnpack): Weight tensor rank should be == 2");

  int64_t rows_w = weight.size(0);
  at::Tensor bias_fp32;
  if (bias_in.has_value()) {
    bias_fp32 = bias_in.value();
  } else {
    bias_fp32 = at::zeros(rows_w, weight.options().dtype(at::kFloat));
  }
  TORCH_CHECK(
      !bias_fp32.defined() ||
          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == rows_w),
      "quantized::linear_prepack (qnnpack): Given weight of size ",
      weight.sizes(),
      ", expected bias to be 1-dimensional with ",
      rows_w,
      " elements",
      ", but got bias of size ",
      bias_fp32.sizes(),
      " instead");

  at::Tensor weight_contig = weight.contiguous();
  auto [w_zero_points, w_scales] =
      make_zero_points_and_scales_tensor(weight_contig);

  at::native::initQNNPACK();

  // We set the pre-packed linear weights to nullptr below as we call pre-pack
  // during the first invocation of operator run. Refer to Linear.cpp for more
  // details. TODO Update to actually call pre-pack here once bias is removed
  // from the pre-packing step.
  auto wt_ptr = c10::make_intrusive<PackedLinearWeightsQnnp>(
      nullptr,
      weight_contig, /* int8_t weight */
      bias_fp32.contiguous(), /* fp32 bias */
      std::nullopt, /* input_scale */
      w_scales,
      std::move(w_zero_points));
  return wt_ptr;
}
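
// Illustrative sketch: the QNNPACK path above is selected by
// QLinearPackWeightInt8::run (see below) when QNNPACK is the active qengine,
// e.g. (assuming a QNNPACK-enabled build; qw and bias_fp32 are assumptions):
//
//   at::globalContext().setQEngine(at::QEngine::QNNPACK);
//   auto packed = PackedLinearWeightsQnnp::prepack(qw, bias_fp32);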
#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::optional<at::Tensor> bias) {

  weight = at::_saturate_weight_to_fp16(weight);

  const int64_t K = weight.size(1);
  const int64_t N = weight.size(0);
  at::Tensor weight_contig = weight.contiguous();
  float* weight_contig_ptr = weight_contig.data_ptr<float>();

  // TODO(mingzhe09088):
  // Consider using a functor here in PackedGemmMatrixFP16
  // Comments from (XQ): Not entirely sure this make_unique is safe.
  // make_unique is created with regular "new", and freed through
  // TypeMetaData::deleteFn in this function. This is perfectly fine if the
  // tensors are created and freed within this translation unit. It might be
  // very problematic if that tensor flows across dll boundaries.
  auto ptr = c10::make_intrusive<PackedLinearWeightFp16>(
      std::make_unique<fbgemm::PackedGemmMatrixFP16>(
          fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr),
      bias);
  return ptr;
}
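
// Illustrative sketch: the fp16 packing above is exposed as
// quantized::linear_prepack_fp16 (registered at the bottom of this file),
// which takes an fp32 weight. Tensor names below are assumptions.
//
//   auto op = c10::Dispatcher::singleton()
//                 .findSchemaOrThrow("quantized::linear_prepack_fp16", "")
//                 .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
//                     at::Tensor, std::optional<at::Tensor>)>();
//   auto packed_fp16 = op.call(w_fp32, std::optional<at::Tensor>(bias_fp32));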
#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
    at::Tensor weight,
    std::optional<at::Tensor> bias) {
  TORCH_CHECK(
      weight.dim() == 2,
      "The weight tensor for quantized::linear_prepack (onednn) should"
      " be 2-dimensional.");
  // Weight
  std::vector<int64_t> dims = weight.sizes().vec();
  auto N = weight.size(0);
  std::vector<int32_t> wgt_zero_points;
  ideep::scale_t wgt_scales;
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    TORCH_CHECK(
        weight.q_zero_point() == 0,
        "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
        " whose zero point must be 0, but got ", weight.q_zero_point());
    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
    wgt_scales = ideep::scale_t(1, 1.0 / weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
  } else if (qtype == c10::kPerChannelAffine) {
    wgt_zero_points.resize(N);
    wgt_scales.resize(N);
    for (int i = 0; i < N; ++i) {
      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
      TORCH_CHECK(
          wgt_zero_points[i] == 0,
          "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
          " whose zero point must be 0, but got ", wgt_zero_points[i], ", at index ", i);
      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // Prepack weight
  auto weight_copy = weight.clone();
  ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr());
  wgt.transpose_(0, 1); // ONEDNN requires transposed weight
  auto src_dims = ideep::dims(); // Unknown when prepacking
  ideep::attr_t op_attr;
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), src_dims, dnnl::memory::data_type::s8,
                                                             dnnl::memory::data_type::u8, op_attr);
  ideep::tensor exp_wgt(w_desc);
  exp_wgt.feed_from(wgt);
  ideep::tensor* packed_weight_p = new ideep::tensor(std::move(exp_wgt));
  packed_weight_p->set_scale(wgt_scales);
  packed_weight_p->set_zero_point(wgt_zero_points);
  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
  // Bias
  std::optional<ideep::tensor> onednn_bias{std::nullopt};
  if (bias.has_value()) {
    auto& b = bias.value();
    auto bias_size = b.sizes().vec();
    bias_size.insert(bias_size.begin(), 1);
    TORCH_CHECK(
        bias_size[1] == weight_ptr->get_dim(1),
        "bias should have N elements: ",
        std::to_string(weight_ptr->get_dim(1)),
        ", but got ", bias_size[1]);
    auto bias_desc = ideep::tensor::desc(bias_size, dnnl::memory::data_type::f32);
    ideep::tensor packed_bias;
    packed_bias.init(bias_desc, b.data_ptr());
    onednn_bias = std::optional<ideep::tensor>(packed_bias);
  }
  auto ret_ptr = c10::make_intrusive<PackedLinearWeightsOnednn>(
      PackedLinearWeightsOnednn{
        std::move(weight_ptr),
        onednn_bias,
        weight,
        bias});
  return ret_ptr;
}
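
// Illustrative sketch: the ONEDNN path above requires symmetrically quantized
// (zero_point == 0) qint8 weights, e.g. (the scale value is an arbitrary
// assumption):
//
//   auto qw = at::quantize_per_tensor(
//       w_fp32, /*scale=*/0.02, /*zero_point=*/0, c10::ScalarType::QInt8);
//   auto packed = PackedLinearWeightsOnednn::prepack(qw, bias_fp32);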

inline at::Tensor pack_weight_to_onednn_tensor(
    const at::Tensor& weight,
    std::optional<torch::List<int64_t>>& input_shape) {
  std::vector<int64_t> w_dims = weight.sizes().vec();
  ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr());
  wei.transpose_(0, 1); // oneDNN requires transposed weight
  ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
  ideep::attr_t op_attr;
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  auto w_desc = ideep::matmul_forward::expected_weights_desc(
      wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr);
  ideep::tensor expected_weight(w_desc);
  expected_weight.feed_from(wei);
  auto packed_weight = at::native::new_with_itensor_mkldnn(
      std::move(expected_weight),
      c10::optTypeMetaToScalarType(weight.options().dtype_opt()),
      weight.options().device_opt());
  return packed_weight;
}
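
// Illustrative sketch: pack_weight_to_onednn_tensor is exposed as
// onednn::qlinear_prepack (registered at the bottom of this file). It takes a
// plain int8 weight (not a QTensor) and an optional input-shape hint; names
// below are assumptions.
//
//   std::optional<torch::List<int64_t>> x_shape =
//       torch::List<int64_t>({batch, in_features});
//   auto packed_w = pack_weight_to_onednn_tensor(w_int8, x_shape);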

#endif // #if AT_MKLDNN_ENABLED()

namespace at::native {

at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
  Tensor weight_contig = weight.contiguous();
  float* weight_contig_ptr = weight_contig.data_ptr<float>();
  quant_utils::HandleWeightsSaturation(weight.size(0) * weight.size(1), weight_contig_ptr);
  return weight;
}
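
// Conceptually, the saturation above clamps each fp32 weight value to the
// finite fp16 range before packing (a hedged sketch; the exact logic lives in
// quant_utils::HandleWeightsSaturation):
//
//   constexpr float kFp16Max = 65504.0f;
//   for (int64_t i = 0; i < weight.size(0) * weight.size(1); ++i) {
//     weight_contig_ptr[i] =
//         std::clamp(weight_contig_ptr[i], -kFp16Max, kFp16Max);
//   }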

template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
  return {std::forward<Inputs>(inputs)...};
}

template <class... Args>
inline std::vector<c10::IValue> callOpByHandle(
    const c10::OperatorHandle& op,
    Args... args) {
  auto stack = makeStack(std::forward<Args>(args)...);
  c10::Dispatcher::singleton().callBoxed(op, &stack);
  return stack;
}

template <class... Args>
inline std::vector<c10::IValue> callOpByName(
    const char* func_name,
    const char* overload_name,
    Args... args) {
  const std::optional<c10::OperatorHandle> op_handle =
      c10::Dispatcher::singleton().findSchema({func_name, overload_name});
  assert(op_handle.has_value());
  return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
}

at::Tensor wrapped_quantized_linear(
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op does four things:
  // 1. Use quantize_per_tensor to quantize the input
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // 3. Use quantized::linear to do the int8 quantized linear computation
  // 4. Use dequantize to dequantize the result of quantized::linear
  // This wrapper op exists to work around issues encountered with
  // torch.export.
#ifdef USE_FBGEMM
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);
  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);

  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();

  auto result = callOpByName(
      "quantized::linear", "", qx, packed_params, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
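
// Illustrative sketch: invoking the wrapper above through the dispatcher with
// the callOpByName helper defined earlier. Tensor arguments are assumptions;
// the scale and zero-point arguments are tensors, matching the signature
// above.
//
//   auto out = callOpByName(
//       "_quantized::wrapped_quantized_linear", "",
//       x_fp32, x_scale_t, x_zp_t, w_fp32, w_scale_t, w_zp_t, bias_fp32,
//       out_scale_t, out_zp_t, out_channel)[0].toTensor();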

at::Tensor wrapped_quantized_linear_meta(
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear_meta(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
#ifdef USE_FBGEMM
  const at::SymInt M = input.sym_size(0);
  const at::SymInt N = weight.sym_size(0);
  auto Y = at::empty_symint({M, N}, input.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias);

at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias) {
  // This op does two things:
  // 1. Use quantize_per_tensor to quantize the weight
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // This wrapper op exists so that the quantized weight can be saved as a
  // constant for AOTI.
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);

  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto unique_ptr_wrapper =
      std::make_unique<decltype(packed_params)>(std::move(packed_params));
  auto ret = cpp_custom_type_hack::create(
      std::move(unique_ptr_wrapper), weight.options());
  return ret;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op is similar to wrapped_quantized_linear, but it takes the prepacked weight
#ifdef USE_FBGEMM
  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();
  auto packed_weight_ptr =
      // @lint-ignore CLANGTIDY facebook-hte-Deprecated
      cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
          packed_weight);
  auto result = callOpByName(
      "quantized::linear", "", qx, packed_weight_ptr, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}
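
// Illustrative sketch: _wrapped_linear_prepack and
// _wrapped_quantized_linear_prepacked are intended to be used as a pair, with
// the packed weight saved as a constant (e.g. for AOTI) in between. Tensor
// names below are assumptions.
//
//   auto packed_t = _wrapped_linear_prepack(w_fp32, w_scale_t, w_zp_t, bias_fp32);
//   auto y = _wrapped_quantized_linear_prepacked(
//       x_fp32, x_scale_t, x_zp_t, packed_t, out_scale_t, out_zp_t, out_channel);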

at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias);

at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  const at::SymInt M = weight.sym_size(0);
  const at::SymInt N = weight.sym_size(1);
  auto Y = at::empty_symint({M, N}, weight.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel);

at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel) {
#ifdef USE_FBGEMM
  auto out_sizes = input.sym_sizes().vec();
  TORCH_CHECK(
      out_sizes.size() == 2,
      "The dimension of the input tensor should be equal to 2");
  out_sizes[out_sizes.size() - 1] = out_channel;

  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

namespace {

class QLinearPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<LinearPackedParamsBase> run(
      at::Tensor weight,
      std::optional<Tensor> bias) {
    auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM ||
        ctx.qEngine() == at::QEngine::X86) {
      return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
    }
#endif
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedLinearWeightsQnnp::prepack(
          std::move(weight), std::move(bias));
    }
#endif
#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
    }
#endif // #if AT_MKLDNN_ENABLED()
    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::linear_prepack ",
        toString(ctx.qEngine()));
  }
};

class QLinearPackWeightFp16 final {
 public:
  static c10::intrusive_ptr<LinearPackedParamsBase> run(
      at::Tensor weight,
      std::optional<Tensor> bias) {
    auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
    // Temporarily convert the weight back to fp32; this needs to be fixed
    // once fbgemm's prepacking op accepts fp16 input directly.
    weight = weight.to(ScalarType::Float);
    if (ctx.qEngine() == at::QEngine::FBGEMM ||
        ctx.qEngine() == at::QEngine::X86) {
      return PackedLinearWeightFp16::prepack(
          std::move(weight), std::move(bias));
    }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      TORCH_CHECK(
          false,
          "quantized::linear_prepack_fp16 is currently "
          "not supported by QNNPACK");
    }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      TORCH_CHECK(
          false,
          "quantized::linear_prepack_fp16 is currently "
          "not supported by ONEDNN");
    }
#endif // #if AT_MKLDNN_ENABLED()
    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::linear_prepack_fp16 ",
        toString(ctx.qEngine()));
  }
};

class QLinearPackWeightInt8Legacy final {
 public:
  static Tensor run(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] at::Tensor weight,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] std::optional<Tensor> bias) {
    TORCH_CHECK(false,
        "This model uses an outdated version of quantized.linear_prepack. "
        "Please re-export your model using the newer definitions in torch.jit.quantized");
  }
};

class QLinearPackWeightFp16Legacy final {
 public:
  static Tensor run(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] at::Tensor weight,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] std::optional<Tensor> bias) {
    TORCH_CHECK(false,
        "This model uses an outdated version of quantized.linear_prepack_fp16. "
        "Please re-export your model using the newer definitions in torch.jit.quantized");
  }
};

class QLinearPackWeightInt8Onednn final {
 public:
  static at::Tensor run(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] at::Tensor weight, // Not QTensor
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] std::optional<torch::List<int64_t>> input_shape) {
#if AT_MKLDNN_ENABLED()
    return pack_weight_to_onednn_tensor(weight, input_shape);
#else
    TORCH_CHECK(false, "Unimplemented as onednn is not available.");
#endif
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_legacy"), TORCH_FN(QLinearPackWeightInt8Legacy::run));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
}

TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
}

TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"),
      _wrapped_linear_prepack);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"),
      _wrapped_quantized_linear_prepacked);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"),
      _wrapped_linear_prepack_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"),
      _wrapped_quantized_linear_prepacked_meta);
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8Onednn::run));
}

} // namespace
} // namespace at::native