#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
#include <ATen/ops/quantize_per_tensor.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <string>
#include <type_traits>

int register_linear_params();

#ifdef USE_FBGEMM
template <bool ReluFused>
at::Tensor PackedLinearWeight::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  using at::Tensor;
  // fp32 * int8 -> fp32 (with quantization on activation, and dequantization
  // on the result).

  // We make a strong guarantee that models using these operators will have
  // the same numerics across different machines. Therefore, we do not provide
  // a fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // TODO: contiguous is called for further jit optimizations.
  auto input_contig = input.contiguous();
  const auto* input_ptr = input_contig.const_data_ptr<float>();

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
  // matrices, respectively.
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
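  // size_to_dim_(k, sizes) multiplies the first k dimensions together, so all
  // leading dimensions are folded into M; e.g. a {b, M0, K} input is treated
  // as a (b * M0) x K matrix for the GEMM below.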

  auto packB = w.get();

  int64_t N = static_cast<int64_t>(packB->numCols());
  int64_t K = input.size(input.dim() - 1);
  TORCH_CHECK(
      K == static_cast<int64_t>(packB->numRows()),
      "The number of rows in the packB should be equal to K: " +
          std::to_string(K));

  // Calculate statistics for quantization of the input Tensor
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_min, x_max;
  fbgemm::FindMinMax(
      /*m=*/input_ptr,
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/
      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  q_params.precision = precision;
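
  // Rough illustration of the affine mapping chosen above (the helper also
  // nudges the range to include zero and may adjust the values): with
  // x_min = -1.0f, x_max = 3.0f, qmin = 0, qmax = 255 and reduce_range off,
  //   scale      ~= (x_max - x_min) / (qmax - qmin) = 4.0 / 255 ~= 0.0157
  //   zero_point ~= qmin - round(x_min / scale)     ~= 64
  // so the fp32 input is mapped to uint8 via q = round(x / scale) + zero_point.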

  // ReQuantizeForFloat requires pointers to the zero point values,
  // since in the case of rowwise quantization these will be arrays rather
  // than scalars. But in this case, we're doing whole-tensor quantization so
  // we just pass a pointer to the scale values (and internally
  // ReQuantizeForFloat won't index past 0).

  const float* bias_ptr = nullptr;
  at::Tensor bias_vec;
  if (bias_.has_value()) {
    bias_vec = bias_.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == N,
        "bias should have N elements: " + std::to_string(N));
    // TODO: contiguous is called for further jit optimizations.
    // Keep the contiguous tensor alive in bias_vec so the storage behind
    // bias_ptr remains valid when it is used in the GEMM below.
    bias_vec = bias_vec.contiguous();
    bias_ptr = bias_vec.data_ptr<float>();
  }
  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input. Here are two examples:
  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = N;
  // Allocate output Tensor and a buffer for fbgemmPacked to use
  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
  auto buffer = at::empty_like(
      output,
      output.options().dtype(at::kInt),
      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
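  // `buffer` holds the int32 accumulators that fbgemmPacked writes the raw
  // uint8 x int8 products into before the output pipeline dequantizes them
  // into `output` as fp32.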

  int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    // This operation does the following:
    // 1) Quantizes the input matrix given the statistics we've calculated
    //    above.
    // 2) Creates a "row buffer" vector with offset values that must be added
    //    to the integer matrix multiplication operation to ensure
    //    correctness. This "row buffer" is also called the row offset, and it
    //    is needed when we use affine quantization for weights.
    // 3) Packs the resulting quantized matrix into vector-register and cache
    //    friendly tiles.
    //
    // Note this is not executed eagerly, but rather within the fbgemmPacked
    // call below.

    fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
        /*trans=*/fbgemm::matrix_op_t::NoTranspose,
        /*nRow=*/M,
        /*nCol=*/K,
        /*smat=*/input_ptr,
        /*ld=*/K,
        /*pmat=*/nullptr, // Currently, packA manages ownership of `pmat`.
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        /*scale=*/q_params.scale,
        /*zero_pt=*/q_params.zero_point);
    // TODO: Consider a way to pre-allocate and reuse pmat buffer.

    // This is the end of the pipeline, pass the resulting matrix through.
    fbgemm::DoNothing<float, float> doNothingObj{};

    for (const auto task_id : c10::irange(begin, end)) {
      if (q_scheme == c10::kPerTensorAffine) {
        // Process the per tensor quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Dequantize the results into floating point.
        //  3) Add in the bias term.
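        //
        // Conceptually, the pipeline dequantizes each accumulator using
        // A_scale * B_scale, after correcting it with the row/column offset
        // terms required by asymmetric (affine) quantization, then adds the
        // bias (and applies ReLU when ReluFused). See fbgemm's
        // ReQuantizeForFloat for the exact offset handling.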
        fbgemm::ReQuantizeForFloat<ReluFused> outputProcObj(
            /*nextop=*/doNothingObj,
            /*Aq_scale=*/q_params.scale,
            /*Bq_scale=*/w_scale.data(),
            /*Aq_zero_point=*/q_params.zero_point,
            /*Bq_zero_point=*/w_zp.data(),
            /*row_offsets=*/packA.getRowOffsetBuffer(),
            /*col_offsets=*/col_offsets.data(),
            /*bias=*/bias_ptr,
            /*nCol=*/N);

        // Do the GEMM
        fbgemm::fbgemmPacked(
            /*packA=*/packA,
            /*packB=*/*packB,
            /*C=*/output.data_ptr<float>(),
            /*C_buffer=*/buffer.data_ptr<int32_t>(),
            /*ldc=*/N,
            /*outProcess=*/outputProcObj,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);

      } else if (q_scheme == c10::kPerChannelAffine) {
        // Process the per channel quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Dequantize the results into floating point.
        //  3) Add in the bias term.
        fbgemm::ReQuantizeForFloat<
            ReluFused,
            fbgemm::QuantizationGranularity::OUT_CHANNEL>
            outputProcObj(
                /*nextop=*/doNothingObj,
                /*Aq_scale=*/q_params.scale,
                /*Bq_scale=*/w_scale.data(),
                /*Aq_zero_point=*/q_params.zero_point,
                /*Bq_zero_point=*/w_zp.data(),
                /*row_offsets=*/packA.getRowOffsetBuffer(),
                /*col_offsets=*/col_offsets.data(),
                /*bias=*/bias_ptr,
                /*nCol=*/N);

        // Do the GEMM
        fbgemm::fbgemmPacked(
            /*packA=*/packA,
            /*packB=*/*packB,
            /*C=*/output.data_ptr<float>(),
            /*C_buffer=*/buffer.data_ptr<int32_t>(),
            /*ldc=*/N,
            /*outProcess=*/outputProcObj,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      }
    }
  });

  return output;
}

at::Tensor PackedLinearWeight::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(
      std::move(input), reduce_range);
}

at::Tensor PackedLinearWeight::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input), reduce_range);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
template <bool ReluFused>
at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  if (reduce_range) {
    TORCH_WARN_ONCE("Currently, qnnpack incorrectly ignores reduce_range when it is set to true; this may change in a future release.");
  }

  using at::Tensor;
  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  auto input_contig = input.contiguous();
  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
  // matrices, respectively.

  // Weight packing is not thread safe
  std::lock_guard<std::mutex> lock(qnnp_mutex_);
  auto packB = w.get();
  size_t rows_w = bias_.size(0);
  size_t cols_w = input_contig.size(input_contig.dim() - 1);

  at::Tensor bias_vec = bias_;

  TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");

  auto bias_contig = bias_vec.contiguous();
  const float* bias_ptr = bias_contig.const_data_ptr<float>();

  // Calculate statistics for quantization of input Tensor
  // TODO: optimized kernel
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_min;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_max;
  if (input.numel() > 0) {
    x_min = input_contig.min().item<float>();
    x_max = input_contig.max().item<float>();
  } else {
    // On empty input, no output data will be generated,
    // so use arbitrary qparams.
    x_min = 0;
    x_max = 0;
  }

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255);
  float* weight_scales_data = w_scales.data_ptr<float>();

  if (!input_scale.has_value() || input_scale.value() != q_params.scale) {
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales,
        q_params.scale,
        1.f,
        requantization_scales);
  }
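  // For this dynamic (fp32-output) path each per-channel "requantization"
  // scale works out to roughly input_scale * weight_scale / 1.0f, i.e. it is
  // really a dequantization scale (as also noted at the qnnpackLinearDynamic
  // call below).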

  if (!input_scale.has_value()) {
    // Get the original weight and adjust it to uint8 from int8
    auto weight_contig = orig_weight;

    // TODO(kimishpatel): we are allocating affine_quantized regardless of per
    // channel or not. This allocation is actually used only for packing weight
    // and thus will be freed. Still we should be consistent. Fix this.
    Tensor qnnp_weight = at::_empty_affine_quantized(
        weight_contig.sizes(),
        at::device(c10::kCPU).dtype(c10::kQUInt8),
        weight_scales_data[0],
        w_zero_points[0]);
    auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
    int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
    auto wt_numel = weight_contig.numel();
    for (const auto i : c10::irange(wt_numel)) {
      qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
    }
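    // The +128 shift above converts the signed int8 weights into the uint8
    // layout QNNPACK consumes; the packed w_zero_points are expected to
    // already account for this offset.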

    // Pass in nullptr for bias, as we pass FP32 bias to run function.
    w.reset();
    w = std::make_unique<qnnpack::PackBMatrix>(
        cols_w /* input_channels */,
        rows_w /* output_channels */,
        w_zero_points.data(),
        requantization_scales.data(),
        (uint8_t*)qnnp_w_data,
        nullptr);
    packB = w.get();
    if (at::globalContext().releaseWeightsWhenPrepacking()) {
      // On mobile, we release the original weight by resetting the
      // intrusive_ptr. Calling unpack after this will throw an assertion.
      orig_weight.reset();
    }
  }

  // Update the input scale so that the weights are not packed again,
  // and so that the requantization scales are not repopulated if the scale
  // has not changed.
  input_scale = q_params.scale;

  // Quantize input
  Tensor q_input = at::quantize_per_tensor(
      input_contig, q_params.scale, q_params.zero_point, c10::kQUInt8);

  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input. Here are two examples:
  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = rows_w;

  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));

  size_t rows_input = 1;
  size_t cols_input = input_contig.size(input_contig.dim() - 1);
  for (const auto i : c10::irange(input_contig.dim() - 1)) {
    rows_input *= input_contig.size(i);
  }
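  // rows_input is now the flattened batch size (the product of all dimensions
  // except the last), matching the leading dimensions of out_sizes above.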
  pytorch_qnnp_status runStatus = qnnpack::qnnpackLinearDynamic(
      rows_input /* batch_size */,
      cols_input /* input_channels */,
      rows_w /* output_channels */,
      q_input.q_zero_point(),
      w_zero_points.data(),
      /* for dynamic should really be called dequant scale */
      requantization_scales.data(),
      (uint8_t*)q_input.data_ptr<c10::quint8>(),
      cols_input /* input_stride */,
      packB->getPackedWeights(),
      bias_ptr,
      output.data_ptr<float>(),
      rows_w /* output_stride */,
      caffe2::pthreadpool_() /* threadpool */);

  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK Linear operator");

  // Call the relu operator here until qlinear dynamic in QNNPACK
  // supports it natively.
  if (ReluFused) {
    output.relu_();
  }
  return output;
}

at::Tensor PackedLinearWeightsQnnp::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(std::move(input), reduce_range);
}

at::Tensor PackedLinearWeightsQnnp::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input), reduce_range);
}

#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM

template <bool ReluFused>
at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
    const at::Tensor& input,
    at::Tensor& output) {
  const at::Tensor input_contig = input.contiguous();
  const float* input_ptr = input_contig.const_data_ptr<float>();

  auto& packed_weight_fp16 = *w;

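  // In this fp16 path the activation stays in fp32 and only the weight is
  // stored as (packed) fp16; fbgemm::cblas_gemm_compute is expected to
  // accumulate in fp32 and write an fp32 result.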
  TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows());
  TORCH_CHECK(input.dim() >= 2);

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
  const int64_t N = packed_weight_fp16.numCols();
  std::vector<int64_t> output_sizes = input.sizes().vec();
  TORCH_CHECK(!output_sizes.empty());
  output_sizes.back() = N;
  // Resize output Tensor
  output.resize_(output_sizes);

  auto output_data = output.data_ptr<float>();

  int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    for (const auto task_id : c10::irange(begin, end)) {
      // Call the fp16 gemm interface
      fbgemm::cblas_gemm_compute(
          /*transa=*/fbgemm::matrix_op_t::NoTranspose,
          /*m=*/static_cast<int>(M),
          /*A=*/input_ptr,
          /*Bp=*/packed_weight_fp16,
          /*beta=*/0.0f,
          /*C=*/output_data,
          /*thread_id=*/static_cast<int>(task_id),
          /*num_threads=*/num_tasks);
    }
  });

  // Add bias term
  if (bias_.has_value()) {
    TORCH_CHECK(bias_->dim() == 1);
    output.add_(*bias_);
  }

  return output;
}

at::Tensor PackedLinearWeightFp16::apply_dynamic(
    at::Tensor input,
    bool /* reduce_range */) {
  at::Tensor output = at::empty({0}, input.options().dtype(at::kFloat));
  return apply_dynamic_impl</*ReluFused=*/false>(input, output);
}

at::Tensor PackedLinearWeightFp16::apply_dynamic_relu(
    at::Tensor input,
    bool /* reduce_range */) {
  at::Tensor output = at::empty({0}, input.options().dtype(at::kFloat));
  return apply_dynamic_impl</*ReluFused=*/true>(input, output);
}

at::Tensor& PackedLinearWeightFp16::apply_dynamic_out(
    const at::Tensor& input,
    at::Tensor& output,
    bool /* reduce_range */) {
  TORCH_CHECK((output.device() == c10::kCPU) && (output.dtype() == at::kFloat));
  return apply_dynamic_impl<false>(input, output);
}

at::Tensor& PackedLinearWeightFp16::apply_dynamic_relu_out(
    const at::Tensor& input,
    at::Tensor& output,
    bool /* reduce_range */) {
  TORCH_CHECK((output.device() == c10::kCPU) && (output.dtype() == at::kFloat));
  return apply_dynamic_impl<true>(input, output);
}

void PackedLinearWeightFp16::set_bias(std::optional<at::Tensor> bias) {
  bias_ = std::move(bias);
}

#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
template <bool ReluFused>
at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  // Dynamic: fp32 * int8 -> fp32
  using at::Tensor;

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
      "qlinear_dynamic (ONEDNN): data type of input should be float.");

  // Input -> uint8
  auto input_contig = input.contiguous();
  const int64_t dim = input.dim();
  auto input_reshaped =
      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
  auto input_dims = input_reshaped.sizes().vec();
  auto input_data_type = dnnl::memory::data_type::f32;
  auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
  ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
  ideep::tensor x;
  x.init(input_desc, input_contig.data_ptr());
  // Find quantization parameters
  float x_max = 0, x_min = 0;
#ifdef USE_FBGEMM
  // Use FBGEMM's FindMinMax if available since it's faster
  fbgemm::FindMinMax(
      /*m=*/input_contig.data_ptr<float>(),
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());
#else
  if (input_contig.numel() > 0) {
    auto [t_min, t_max] = at::aminmax(input_contig);
    x_max = t_max.item<float>();
    x_min = t_min.item<float>();
  }
#endif

#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
  // oneDNN+ACL has optimized kernels for s8s8 matmul, so input is signed
  using input_qtype = int8_t;
#else
  using input_qtype = uint8_t;
#endif

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/std::numeric_limits<input_qtype>::min(),
      /*qmax=*/std::numeric_limits<input_qtype>::max(),
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);
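  // With uint8 activations this gives qmin/qmax = 0/255; on the oneDNN+ACL
  // s8s8 path above it gives -128/127 instead.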
  const std::vector<int32_t>& src_zero_point = std::vector<int32_t>(1, q_params.zero_point);
  // weights, dst
  auto w = *(weight_.get());
  auto dst_dims = {x.get_dim(0), w.get_dim(1)};
  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale);
  const ideep::scale_t& weights_scales = w.get_scale();
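  // Note: ideep/oneDNN scales act as quantization multipliers (q ~= x * scale),
  // which is why src_scales holds the reciprocal of the PyTorch-style scale
  // chosen above.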
  // Compute -> f32
  // Use ideep::matmul_forward instead of ideep::inner_product_forward,
  // since the latter does not support asymmetric quantization
  // Allocate output Tensor
  at::Tensor output = at::empty(dst_dims, input.options().dtype(at::kFloat));
  if (output.numel() == 0) return output;
  ideep::tensor y({dst_dims, ideep::tensor::data_type::f32,
                   {output.strides().cbegin(), output.strides().cend()}},
                  output.data_ptr());
  bool with_bias = bias_.has_value();
  if (with_bias) {
    // Bias might be modified outside (e.g. by quantization bias correction).
    // If so, update the prepacked bias as well.
    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
    }
  }
  const auto& b = with_bias ? bias_.value() : ideep::tensor();
  // Primitive cache is initialized when called for the first time
  // and won't be updated afterwards.
  int num_threads = at::get_num_threads();
  PrimitiveCacheKey cache_key = std::make_tuple(
      q_params.scale, q_params.zero_point, input_dims, 1.0, 0, num_threads, /*accum scale*/1.0, /*accum zero point*/0);
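  // The cached oneDNN matmul primitive is keyed on the input quantization
  // parameters, input shape and thread count, so it is created on the first
  // call (below) and reused whenever the key matches on later calls.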
  c10::call_once(*cache_initialized_flag, [&](){
      LinearParams params;
      ideep::matmul_forward::prepare</*is_dynamic=*/true>(
          params, x, w, b, y,
          src_scales, weights_scales, ideep::scale_t(),
          src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr,
          ideep::tensor::data_type::f32, std::is_signed_v<input_qtype> ? ideep::s8s8 : ideep::u8s8);
      get_cache() = LinearPrimitiveCache(cache_key, params);
      w = w.reorder_if_differ_in(params.pd.weights_desc());
  });
  if (get_cache().hit_dynamic(cache_key)) {
    LinearParams& params = get_cache().get_param();
    ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point);
  } else {
    ideep::matmul_forward::compute(x, w, b, y,
                                   src_scales, weights_scales, ideep::scale_t(),
                                   src_zero_point, ideep::zero_point_t(),
                                   1.0f, 1.0f, op_attr);
  }
  auto out_sizes = input.sizes().vec();
  out_sizes.back() = w.get_dim(1);
  if (output.sizes().vec() == out_sizes)
    return output;
  return output.reshape(out_sizes);
}

at::Tensor PackedLinearWeightsOnednn::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(
      std::move(input), reduce_range);
}

at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(
      std::move(input), reduce_range);
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace {

template <bool ReluFused>
class QLinearDynamicInt8 final {
 public:
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
      bool reduce_range) {
    if (ReluFused) {
      return packed_weight->apply_dynamic_relu(std::move(input), reduce_range);
    } else {
      return packed_weight->apply_dynamic(std::move(input), reduce_range);
    }
  }
};

template <bool ReluFused>
class QLinearDynamicFp16 final {
 public:
#ifdef USE_FBGEMM
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    auto output = packed_weight->apply_dynamic(std::move(input));

    // Call the relu operator here until fp16 linear dynamic in FBGEMM
    // supports it natively.
    if (ReluFused) {
      output.relu_();
    }
    return output;
  }
#else // USE_FBGEMM
  static at::Tensor run(
      at::Tensor /* input */,
      const c10::intrusive_ptr<LinearPackedParamsBase>& /* packed_weight */) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

class QLinearUnpackedDynamicFp16 final {
 public:
#ifdef USE_FBGEMM
  static at::Tensor run(
      at::Tensor input,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    TORCH_CHECK(
        weight.dim() == 2,
        "The dimension of weight tensor should be equal to 2");

    auto packed_weight = PackedLinearWeightFp16::prepack(weight, bias);
    auto output = packed_weight->apply_dynamic(std::move(input));

    return output;
  }

  static at::Tensor meta(
      at::Tensor input,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    TORCH_CHECK(
        weight.dim() == 2,
        "The dimension of weight tensor should be equal to 2");

    auto out_channel = weight.sym_sizes().vec()[0];
    auto out_sizes = input.sym_sizes().vec();
    out_sizes[out_sizes.size() - 1] = out_channel;

    return at::empty_symint(out_sizes, input.options());
  }
#else // USE_FBGEMM
  static at::Tensor run(
      at::Tensor /* input */,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }

  static at::Tensor meta(
      at::Tensor /* input */,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor& weight) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor& weight) {
#ifdef USE_FBGEMM
  // Strictly speaking this is not correct. However we do not know the exact
  // size of the packed matrix as it's being maintained by the object itself,
  // therefore we return the view we have here.
  return at::empty({8}, weight.options().dtype(at::kByte));
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  // For the meta function, we need users to provide the dimension explicitly
  // as we don't have access to the weight.
  auto out_sizes = input.sym_sizes().vec();
  if (out_channel == -1) {
    out_sizes.pop_back();
  } else {
    out_sizes.back() = out_channel;
  }
  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}


TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"),
      TORCH_FN(QLinearDynamicInt8<true>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"),
      TORCH_FN(QLinearDynamicFp16<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16_unpacked_weight"),
      TORCH_FN(QLinearUnpackedDynamicFp16::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"),
      TORCH_FN(QLinearDynamicFp16<true>::run));
}
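
// From Python these CPU kernels are typically reached through the dispatcher,
// e.g. (a rough sketch, assuming a weight prepacked via quantized::linear_prepack):
//   packed = torch.ops.quantized.linear_prepack(w_q, b)
//   y = torch.ops.quantized.linear_dynamic(x, packed, reduce_range)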

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16_unpacked_weight"),
      TORCH_FN(QLinearUnpackedDynamicFp16::meta));
}

TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight_meta);
}

} // namespace
} // namespace native
} // namespace at