#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/Context.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/quantized/Quantizer.h>
#include <torch/custom_class.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_saturate_weight_to_fp16.h>
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
#include <ATen/ops/dequantize.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>

int register_linear_params();

#ifdef USE_FBGEMM
namespace {
// Calculate the column offsets.
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by
// PackAWithQuantRowOffset are only the sum of the A rows.
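// For example, with per-tensor quantization and weight zero point zp, column i
// of the (N x K) row-major B gets col_offsets[i] = sum_j B[i * K + j] - zp * K.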
void calc_col_offsets_transpose(
    int K,
    int N,
    const int8_t* Bint8,
    int32_t* B_zero_point,
    int32_t* col_offsets,
    c10::QScheme qtype) {
  for (const auto i : c10::irange(N)) {
    int32_t sum = 0;
    for (const auto j : c10::irange(K)) {
      sum += Bint8[i * K + j];
    }
    if (qtype == c10::kPerTensorAffine) {
      col_offsets[i] = sum - B_zero_point[0] * K;
    } else {
      col_offsets[i] = sum - B_zero_point[i] * K;
    }
  }
}
} // namespace

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeight::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    std::optional<at::Tensor> bias) {
  TORCH_CHECK(
      weight.dim() == 2,
      "The weight tensor for quantized::linear_prepack (fbgemm) should"
      " be 2-dimensional.");

  auto N = weight.size(0);
  auto K = weight.size(1);

  // TODO: contiguous is called for further JIT optimizations.
  auto weight_contig = weight.contiguous();
  const auto qtype = weight.qscheme();
  std::vector<int32_t> weight_zero_points_int32(1, 0);
  if (qtype == c10::kPerTensorAffine) {
    weight_zero_points_int32[0] = {static_cast<int32_t>(weight.q_zero_point())};
  } else if (qtype == c10::kPerChannelAffine) {
    weight_zero_points_int32.resize(N, 0);
    for (const auto i : c10::irange(N)) {
      weight_zero_points_int32[i] =
          weight.q_per_channel_zero_points()[i].item<int32_t>();
    }
  }
  std::vector<float> weight_scales_float(1, 0.0);
  if (qtype == c10::kPerTensorAffine) {
    weight_scales_float[0] = {static_cast<float>(weight.q_scale())};
  } else if (qtype == c10::kPerChannelAffine) {
    weight_scales_float.resize(N, 0.0);
    for (const auto i : c10::irange(N)) {
      weight_scales_float[i] = weight.q_per_channel_scales()[i].item<float>();
    }
  }

  int8_t* weight_ptr_int8 =
      reinterpret_cast<int8_t*>(weight_contig.data_ptr<c10::qint8>());

  std::vector<int32_t> col_offsets(N);
  calc_col_offsets_transpose(
      /*K=*/static_cast<int>(K),
      /*N=*/static_cast<int>(N),
      /*Bint8=*/weight_ptr_int8,
      /*B_zero_point=*/weight_zero_points_int32.data(),
      /*col_offsets=*/col_offsets.data(),
      /*qtype=*/qtype);

  std::optional<at::Tensor> bias_contig;
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == N,
        "bias should have N elements: " + std::to_string(N));
    bias_contig = bias->contiguous();
  }
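  // The quantized weight is stored (N, K) row-major (out_features x in_features);
  // with matrix_op_t::Transpose and ld = K, fbgemm treats it as B^T and packs
  // the K x N "B" matrix expected by its int8 GEMM.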
  auto ret_ptr = c10::make_intrusive<PackedLinearWeight>(
      std::make_unique<fbgemm::PackBMatrix<int8_t>>(
          /*trans=*/fbgemm::matrix_op_t::Transpose,
          /*nRow=*/K,
          /*nCol=*/N,
          /*smat=*/weight_ptr_int8,
          /*ld=*/K,
          /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
          /*groups=*/1),
      bias_contig,
      col_offsets,
      weight_scales_float,
      weight_zero_points_int32,
      qtype);
  return ret_ptr;
}
#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsQnnp::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    std::optional<at::Tensor> bias_in) {
  TORCH_CHECK(
      weight.dim() == 2,
      "quantized::linear_prepack (qnnpack): Weight tensor rank should be == 2");

  int64_t rows_w = weight.size(0);
  at::Tensor bias_fp32;
  if (bias_in.has_value()) {
    bias_fp32 = bias_in.value();
  } else {
    bias_fp32 = at::zeros(rows_w, weight.options().dtype(at::kFloat));
  }
  TORCH_CHECK(
      !bias_fp32.defined() ||
          (bias_fp32.ndimension() == 1 && bias_fp32.size(0) == rows_w),
      "quantized::linear_prepack (qnnpack): Given weight of size ",
      weight.sizes(),
      ", expected bias to be 1-dimensional with ",
      rows_w,
      " elements",
      ", but got bias of size ",
      bias_fp32.sizes(),
      " instead");

  at::Tensor weight_contig = weight.contiguous();
  auto [w_zero_points, w_scales] =
      make_zero_points_and_scales_tensor(weight_contig);

  at::native::initQNNPACK();

  // We set the pre-packed linear weights to nullptr below as we call pre-pack
  // during the first invocation of operator run. Refer to Linear.cpp for more
  // details. TODO Update to actually call pre-pack here once bias is removed
  // from pre-packing step.
  auto wt_ptr = c10::make_intrusive<PackedLinearWeightsQnnp>(
      nullptr,
      weight_contig, /* int8_t weight */
      bias_fp32.contiguous(), /* fp32 bias */
      std::nullopt, /* input_scale */
      w_scales,
      std::move(w_zero_points));
  return wt_ptr;
}
#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor weight,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::optional<at::Tensor> bias) {

  weight = at::_saturate_weight_to_fp16(weight);

  const int64_t K = weight.size(1);
  const int64_t N = weight.size(0);
  at::Tensor weight_contig = weight.contiguous();
  float* weight_contig_ptr = weight_contig.data_ptr<float>();

  // TODO(mingzhe09088):
  // Consider using a functor here in PackedGemmMatrixFP16
  // Comments from (XQ): Not entirely sure this make_unique is safe.
  // make_unique is created with regular "new", and freed through
  // TypeMetaData::deleteFn in this function. This is perfectly fine if the
  // tensors are created and freed within this translation unit. It might be
  // very problematic if that tensor flows across dll boundaries.
  auto ptr = c10::make_intrusive<PackedLinearWeightFp16>(
      std::make_unique<fbgemm::PackedGemmMatrixFP16>(
          fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr),
      bias);
  return ptr;
}
#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
    at::Tensor weight,
    std::optional<at::Tensor> bias) {
  TORCH_CHECK(
      weight.dim() == 2,
      "The weight tensor for quantized::linear_prepack (onednn) should"
      " be 2-dimensional.");
  // Weight
  std::vector<int64_t> dims = weight.sizes().vec();
  auto N = weight.size(0);
  std::vector<int32_t> wgt_zero_points;
  ideep::scale_t wgt_scales;
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    TORCH_CHECK(
        weight.q_zero_point() == 0,
        "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
        " whose zero point must be 0, but got ", weight.q_zero_point());
    wgt_zero_points = std::vector<int32_t>(1, weight.q_zero_point());
    wgt_scales = ideep::scale_t(1, 1.0 / weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal
  } else if (qtype == c10::kPerChannelAffine) {
    wgt_zero_points.resize(N);
    wgt_scales.resize(N);
    for (int i = 0; i < N; ++i) {
      wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item<int32_t>();
      TORCH_CHECK(
          wgt_zero_points[i] == 0,
          "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight,"
          " whose zero point must be 0, but got ", wgt_zero_points[i], ", at index ", i);
      wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item<float>(); // Scales of ONEDNN and PyTorch are reciprocal
    }
  } else {
    TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype));
  }

  // Prepack weight
  auto weight_copy = weight.clone();
  ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr());
  wgt.transpose_(0, 1); // ONEDNN requires transposed weight
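  // (PyTorch stores linear weights as (out_features, in_features) = (N, K);
  // oneDNN's matmul consumes the weight as a (K, N) matrix.)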
  auto src_dims = ideep::dims(); // Unknown when prepacking
  ideep::attr_t op_attr;
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), src_dims, dnnl::memory::data_type::s8,
                                                             dnnl::memory::data_type::u8, op_attr);
  ideep::tensor exp_wgt(w_desc);
  exp_wgt.feed_from(wgt);
  ideep::tensor* packed_weight_p = new ideep::tensor(std::move(exp_wgt));
  packed_weight_p->set_scale(wgt_scales);
  packed_weight_p->set_zero_point(wgt_zero_points);
  std::unique_ptr<ideep::tensor> weight_ptr(packed_weight_p);
  // Bias
  std::optional<ideep::tensor> onednn_bias{std::nullopt};
  if (bias.has_value()) {
    auto& b = bias.value();
    auto bias_size = b.sizes().vec();
    bias_size.insert(bias_size.begin(), 1);
    TORCH_CHECK(
        bias_size[1] == weight_ptr->get_dim(1),
        "bias should have N elements: ",
        std::to_string(weight_ptr->get_dim(1)),
        ", but got ", bias_size[1]);
    auto bias_desc = ideep::tensor::desc(bias_size, dnnl::memory::data_type::f32);
    ideep::tensor packed_bias;
    packed_bias.init(bias_desc, b.data_ptr());
    onednn_bias = std::optional<ideep::tensor>(packed_bias);
  }
  auto ret_ptr = c10::make_intrusive<PackedLinearWeightsOnednn>(
      PackedLinearWeightsOnednn{
          std::move(weight_ptr),
          onednn_bias,
          weight,
          bias});
  return ret_ptr;
}

inline at::Tensor pack_weight_to_onednn_tensor(
    const at::Tensor& weight,
    std::optional<torch::List<int64_t>>& input_shape) {
  std::vector<int64_t> w_dims = weight.sizes().vec();
  ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr());
  wei.transpose_(0, 1); // oneDNN requires transposed weight
  ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
  ideep::attr_t op_attr;
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  auto w_desc = ideep::matmul_forward::expected_weights_desc(
      wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr);
  ideep::tensor expected_weight(w_desc);
  expected_weight.feed_from(wei);
  auto packed_weight = at::native::new_with_itensor_mkldnn(
      std::move(expected_weight),
      c10::optTypeMetaToScalarType(weight.options().dtype_opt()),
      weight.options().device_opt());
  return packed_weight;
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at::native {

at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
  Tensor weight_contig = weight.contiguous();
  float* weight_contig_ptr = weight_contig.data_ptr<float>();
  quant_utils::HandleWeightsSaturation(weight.size(0) * weight.size(1), weight_contig_ptr);
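  // Note: HandleWeightsSaturation saturates (clamps to the fp16-representable
  // range) in place in weight_contig's storage. When `weight` is already
  // contiguous, weight_contig aliases `weight`, so the returned tensor reflects
  // the clamping; otherwise the clamped values only live in the temporary copy.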
  return weight;
}

template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
  return {std::forward<Inputs>(inputs)...};
}

template <class... Args>
inline std::vector<c10::IValue> callOpByHandle(
    const c10::OperatorHandle& op,
    Args... args) {
  auto stack = makeStack(std::forward<Args>(args)...);
  c10::Dispatcher::singleton().callBoxed(op, &stack);
  return stack;
}

template <class... Args>
inline std::vector<c10::IValue> callOpByName(
    const char* func_name,
    const char* overload_name,
    Args... args) {
  const std::optional<c10::OperatorHandle> op_handle =
      c10::Dispatcher::singleton().findSchema({func_name, overload_name});
  assert(op_handle.has_value());
  return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
}

at::Tensor wrapped_quantized_linear(
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op does four things:
  // 1. Use quantize_per_tensor to quantize the input
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // 3. Use quantized::linear to do the int8 linear quantized computation
  // 4. Use dequantize to dequantize the result of quantized::linear
  // We provide this wrapper op so that the whole sequence can be traced as a
  // single op, bypassing the issue from torch.export.
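  // Note: the activation is quantized to quint8 and the weight to qint8,
  // matching the uint8-activation / int8-weight convention of FBGEMM's int8 GEMM.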
#ifdef USE_FBGEMM
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);
  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);

  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();

  auto result = callOpByName(
      "quantized::linear", "", qx, packed_params, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_quantized_linear_meta(
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor wrapped_quantized_linear_meta(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    at::Tensor input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
#ifdef USE_FBGEMM
  const at::SymInt M = input.sym_size(0);
  const at::SymInt N = weight.sym_size(0);
  auto Y = at::empty_symint({M, N}, input.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias);

at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
    const at::Tensor& weight_scale,
    const at::Tensor& weight_zero_point,
    const at::Tensor& bias) {
  // This op does two things:
  // 1. Use quantize_per_tensor to quantize the weight
  // 2. Use quantized::linear_prepack to prepack the weight and bias
  // We provide this wrapper op so that the quantized weight can be saved as a
  // constant for AOTI.
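  // The resulting LinearPackedParamsBase is wrapped into an opaque Tensor via
  // cpp_custom_type_hack below, so it can be stored and passed around like a
  // regular tensor constant.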
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  auto qw = at::quantize_per_tensor(
      weight, weight_scale, weight_zero_point, c10::ScalarType::QInt8);

  auto op = Dispatcher::singleton()
                .findSchemaOrThrow("quantized::linear_prepack", "")
                .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                    at::Tensor, std::optional<at::Tensor>)>();
  auto packed_params = op.call(qw, bias);

  auto unique_ptr_wrapper =
      std::make_unique<decltype(packed_params)>(std::move(packed_params));
  auto ret = cpp_custom_type_hack::create(
      std::move(unique_ptr_wrapper), weight.options());
  return ret;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel);

at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
    const at::Tensor& input_zero_point,
    const at::Tensor& packed_weight,
    const at::Tensor& output_scale,
    const at::Tensor& output_zero_point,
    [[maybe_unused]] const int64_t out_channel) {
  // This op is similar to wrapped_quantized_linear, but it takes the prepacked weight
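  // The packed_weight argument is the opaque tensor produced by
  // _wrapped_linear_prepack; it is unwrapped back into a
  // LinearPackedParamsBase via cpp_custom_type_hack below.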
#ifdef USE_FBGEMM
  auto qx = at::quantize_per_tensor(
      input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
  const auto scale_val = output_scale.item().toFloat();
  const auto zero_point_val = output_zero_point.item().toLong();
  auto packed_weight_ptr =
      // @lint-ignore CLANGTIDY facebook-hte-Deprecated
      cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
          packed_weight);
  auto result = callOpByName(
      "quantized::linear", "", qx, packed_weight_ptr, scale_val, zero_point_val);

  return at::dequantize(result[0].toTensor());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias);

at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
    [[maybe_unused]] const at::Tensor& weight_scale,
    [[maybe_unused]] const at::Tensor& weight_zero_point,
    [[maybe_unused]] const at::Tensor& bias) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  const at::SymInt M = weight.sym_size(0);
  const at::SymInt N = weight.sym_size(1);
  auto Y = at::empty_symint({M, N}, weight.options().dtype(at::kFloat));
  return Y;
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel);

at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
    [[maybe_unused]] const at::Tensor& input_scale,
    [[maybe_unused]] const at::Tensor& input_zero_point,
    [[maybe_unused]] const at::Tensor& packed_weight,
    [[maybe_unused]] const at::Tensor& output_scale,
    [[maybe_unused]] const at::Tensor& output_zero_point,
    const int64_t out_channel) {
#ifdef USE_FBGEMM
  auto out_sizes = input.sym_sizes().vec();
  TORCH_CHECK(
      out_sizes.size() == 2,
      "The dimension of the input tensor should be equal to 2");
  out_sizes[out_sizes.size() - 1] = out_channel;

  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

namespace {

class QLinearPackWeightInt8 final {
 public:
  static c10::intrusive_ptr<LinearPackedParamsBase> run(
      at::Tensor weight,
      std::optional<Tensor> bias) {
    auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
    if (ctx.qEngine() == at::QEngine::FBGEMM ||
        ctx.qEngine() == at::QEngine::X86) {
      return PackedLinearWeight::prepack(std::move(weight), std::move(bias));
    }
#endif
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedLinearWeightsQnnp::prepack(
          std::move(weight), std::move(bias));
    }
#endif
#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      return PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias));
    }
#endif // #if AT_MKLDNN_ENABLED()
    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::linear_prepack ",
        toString(ctx.qEngine()));
  }
};

class QLinearPackWeightFp16 final {
 public:
  static c10::intrusive_ptr<LinearPackedParamsBase> run(
      at::Tensor weight,
      std::optional<Tensor> bias) {
    auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
    // Temporarily convert weight back to fp32; this needs to be fixed
    // after fbgemm fixes the interface for their prepacking op (take fp16 input).
    weight = weight.to(ScalarType::Float);
    if (ctx.qEngine() == at::QEngine::FBGEMM ||
        ctx.qEngine() == at::QEngine::X86) {
      return PackedLinearWeightFp16::prepack(
          std::move(weight), std::move(bias));
    }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      TORCH_CHECK(
          false,
          "quantized::linear_prepack_fp16 is currently "
          "not supported by QNNPACK");
    }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
    if (ctx.qEngine() == at::QEngine::ONEDNN) {
      TORCH_CHECK(
          false,
          "quantized::linear_prepack_fp16 is currently "
          "not supported by ONEDNN");
    }
#endif // #if AT_MKLDNN_ENABLED()
    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::linear_prepack_fp16 ",
        toString(ctx.qEngine()));
  }
};

class QLinearPackWeightInt8Legacy final {
 public:
  static Tensor run(
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] at::Tensor weight,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] std::optional<Tensor> bias) {
    TORCH_CHECK(false,
        "This model uses an outdated version of quantized.linear_prepack. "
        "Please re-export your model using the newer definitions in torch.jit.quantized");
  }
};

class QLinearPackWeightFp16Legacy final {
 public:
  static Tensor run(
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] at::Tensor weight,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] std::optional<Tensor> bias) {
    TORCH_CHECK(false,
        "This model uses an outdated version of quantized.linear_prepack_fp16. "
        "Please re-export your model using the newer definitions in torch.jit.quantized");
  }
};

class QLinearPackWeightInt8Onednn final {
 public:
  static at::Tensor run(
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] at::Tensor weight, // Not QTensor
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      [[maybe_unused]] std::optional<torch::List<int64_t>> input_shape) {
#if AT_MKLDNN_ENABLED()
    return pack_weight_to_onednn_tensor(weight, input_shape);
#else
    TORCH_CHECK(false, "Unimplemented as onednn is not available.");
#endif
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_legacy"), TORCH_FN(QLinearPackWeightInt8Legacy::run));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
}

TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
}

TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run));
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"),
      _wrapped_linear_prepack);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"),
      _wrapped_quantized_linear_prepacked);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"),
      _wrapped_linear_prepack_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"),
      _wrapped_quantized_linear_prepacked_meta);
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8Onednn::run));
}

} // namespace
} // namespace at::native