#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/OnednnUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
#include <ATen/ops/quantize_per_tensor.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <string>
#include <type_traits>

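// This file implements the CPU kernels behind the dynamically quantized
// linear operators (quantized::linear_dynamic, quantized::linear_relu_dynamic,
// and the fp16-weight variants). Activations stay in fp32 and are quantized
// on the fly; the prepacked weights are int8 (or fp16), and the result is
// dequantized back to fp32. Backend-specific implementations are provided for
// FBGEMM, QNNPACK, and oneDNN, selected at build/prepack time.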
int register_linear_params();

#ifdef USE_FBGEMM
template <bool ReluFused>
at::Tensor PackedLinearWeight::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  using at::Tensor;
  // fp32 * int8 -> fp32 (with quantization on activation, and dequantization
  // on the result).

  // We make a strong guarantee that models using these operators will have
  // the same numerics across different machines. Therefore, we do not provide
  // a fallback path and rather fail loudly if we cannot run FBGEMM.
  TORCH_CHECK(
      fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");

  // TODO: contiguous is called for further jit optimizations.
  auto input_contig = input.contiguous();
  const auto* input_ptr = input_contig.const_data_ptr<float>();

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
  // matrices, respectively.
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  int64_t M = size_to_dim_(input.dim() - 1, input.sizes());

  auto packB = w.get();

  int64_t N = static_cast<int64_t>(packB->numCols());
  int64_t K = input.size(input.dim() - 1);
  TORCH_CHECK(
      K == static_cast<int64_t>(packB->numRows()),
      "The number of rows in the packB should be equal to K: " +
          std::to_string(K));

  // Calculate statistics for quantization of the input Tensor
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_min, x_max;
  fbgemm::FindMinMax(
      /*m=*/input_ptr,
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());

  // Input tensor is quantized as 8-bit unsigned values
  static constexpr int precision = 8;
  static constexpr bool is_signed = false;

  // Calculate scale and zero point for quantization of input tensor
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
      /*qmax=*/
      is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  q_params.precision = precision;
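  // For reference, ChooseQuantizationParams picks an affine mapping roughly as
  //   scale      ~ (x_max - x_min) / (qmax - qmin)
  //   zero_point ~ qmin - round(x_min / scale), clamped to [qmin, qmax]
  // (with extra handling for degenerate ranges and the reduce_range flag).
  // This is only a descriptive sketch of the helper above, not a separate
  // code path.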

  // ReQuantizeForFloat requires pointers to the zero point values,
  // since in the case of rowwise quantization these will be arrays rather
  // than scalars. But in this case, we're doing whole-tensor quantization so
  // we just pass a pointer to the scalar values (and internally
  // ReQuantizeForFloat won't index past 0).

  const float* bias_ptr = nullptr;
  at::Tensor bias_vec;
  if (bias_.has_value()) {
    bias_vec = bias_.value();
    TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias_vec.size(0) == N,
        "bias should have N elements: " + std::to_string(N));
    // TODO: contiguous is called for further jit optimizations.
    auto bias_contig = bias_vec.contiguous();
    bias_ptr = bias_contig.data_ptr<float>();
  }
  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input. Here are two examples:
  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = N;
  // Allocate output Tensor and a buffer for fbgemmPacked to use
  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));
  auto buffer = at::empty_like(
      output,
      output.options().dtype(at::kInt),
      LEGACY_CONTIGUOUS_MEMORY_FORMAT);
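  // `buffer` is the int32 accumulation buffer that fbgemmPacked writes the raw
  // integer GEMM results into before the output pipeline dequantizes them into
  // `output` as float.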

  int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    // This operation does the following:
    // 1) Quantizes the input matrix given the statistics we've calculated
    //    above.
    // 2) Creates a "row buffer" vector with offset values that must be
    //    added to the integer matrix multiplication operation to ensure
    //    correctness. This "row buffer" is also called the row offset, and
    //    it is needed when we use affine quantization for weights.
    // 3) Packs the resulting quantized matrix into vector-register and cache
    //    friendly tiles.
    //
    // Note this is not executed eagerly, but rather within the fbgemmPacked
    // call below.

    fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
        /*trans=*/fbgemm::matrix_op_t::NoTranspose,
        /*nRow=*/M,
        /*nCol=*/K,
        /*smat=*/input_ptr,
        /*ld=*/K,
        /*pmat=*/nullptr, // Currently, packA manages ownership of `pmat`.
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        /*scale=*/q_params.scale,
        /*zero_pt=*/q_params.zero_point);
    // TODO: Consider a way to pre-allocate and reuse
    // pmat buffer.

    // This is the end of the pipeline, pass the resulting matrix through.
    fbgemm::DoNothing<float, float> doNothingObj{};

    for (const auto task_id : c10::irange(begin, end)) {
      if (q_scheme == c10::kPerTensorAffine) {
        // Process the per tensor quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Dequantize the results into floating point.
        //  3) Add in the bias term.
        fbgemm::ReQuantizeForFloat<ReluFused> outputProcObj(
            /*nextop=*/doNothingObj,
            /*Aq_scale=*/q_params.scale,
            /*Bq_scale=*/w_scale.data(),
            /*Aq_zero_point=*/q_params.zero_point,
            /*Bq_zero_point=*/w_zp.data(),
            /*row_offsets=*/packA.getRowOffsetBuffer(),
            /*col_offsets=*/col_offsets.data(),
            /*bias=*/bias_ptr,
            /*nCol=*/N);

        // Do the GEMM
        fbgemm::fbgemmPacked(
            /*packA=*/packA,
            /*packB=*/*packB,
            /*C=*/output.data_ptr<float>(),
            /*C_buffer=*/buffer.data_ptr<int32_t>(),
            /*ldc=*/N,
            /*outProcess=*/outputProcObj,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);

      } else if (q_scheme == c10::kPerChannelAffine) {
        // Process the per channel quantization.
        //
        // After the uint8 * int8 matrix multiplication is performed, this
        // operation does:
        //  1) Add in row and column offsets to the rows and columns,
        //     respectively.
        //  2) Dequantize the results into floating point.
        //  3) Add in the bias term.
        fbgemm::ReQuantizeForFloat<
            ReluFused,
            fbgemm::QuantizationGranularity::OUT_CHANNEL>
            outputProcObj(
                /*nextop=*/doNothingObj,
                /*Aq_scale=*/q_params.scale,
                /*Bq_scale=*/w_scale.data(),
                /*Aq_zero_point=*/q_params.zero_point,
                /*Bq_zero_point=*/w_zp.data(),
                /*row_offsets=*/packA.getRowOffsetBuffer(),
                /*col_offsets=*/col_offsets.data(),
                /*bias=*/bias_ptr,
                /*nCol=*/N);

        // Do the GEMM
        fbgemm::fbgemmPacked(
            /*packA=*/packA,
            /*packB=*/*packB,
            /*C=*/output.data_ptr<float>(),
            /*C_buffer=*/buffer.data_ptr<int32_t>(),
            /*ldc=*/N,
            /*outProcess=*/outputProcObj,
            /*thread_id=*/task_id,
            /*num_threads=*/num_tasks);
      }
    }
  });

  return output;
}

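// Thin public wrappers: the only difference between apply_dynamic and
// apply_dynamic_relu is the compile-time ReluFused flag passed to the shared
// implementation above.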
at::Tensor PackedLinearWeight::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(
      std::move(input), reduce_range);
}

at::Tensor PackedLinearWeight::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input), reduce_range);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK
template <bool ReluFused>
at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  if (reduce_range) {
    TORCH_WARN_ONCE("Currently, qnnpack incorrectly ignores reduce_range when it is set to true; this may change in a future release.");
  }

  using at::Tensor;
  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  auto input_contig = input.contiguous();
  // C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
  // matrices, respectively.

  // Weight packing is not thread safe
  std::lock_guard<std::mutex> lock(qnnp_mutex_);
  auto packB = w.get();
  size_t rows_w = bias_.size(0);
  size_t cols_w = input_contig.size(input_contig.dim() - 1);

  at::Tensor bias_vec = bias_;

  TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)");

  auto bias_contig = bias_vec.contiguous();
  const float* bias_ptr = bias_contig.const_data_ptr<float>();

  // Calculate statistics for quantization of input Tensor
  // TODO: optimized kernel
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_min;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float x_max;
  if (input.numel() > 0) {
    x_min = input_contig.min().item<float>();
    x_max = input_contig.max().item<float>();
  } else {
    // On empty input, no output data will be generated,
    // so use arbitrary qparams.
    x_min = 0;
    x_max = 0;
  }

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255);
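  // Note: the QNNPACK path always quantizes activations to the full uint8
  // range [0, 255]; as warned above, reduce_range is currently not honored
  // here.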
  float* weight_scales_data = w_scales.data_ptr<float>();

  if (!input_scale.has_value() || input_scale.value() != q_params.scale) {
    generate_requantization_scales(
        // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        w_scales,
        q_params.scale,
        1.f,
        requantization_scales);
  }

  if (!input_scale.has_value()) {
    // Get the original weight and adjust it to uint8 from int8
    auto weight_contig = orig_weight;

    // TODO(kimishpatel), we are allocating affine_quantized regardless of per
    // channel or not. This allocation is actually used only for packing weight
    // and thus will be freed. Still we should be consistent. Fix this.
    Tensor qnnp_weight = at::_empty_affine_quantized(
        weight_contig.sizes(),
        at::device(c10::kCPU).dtype(c10::kQUInt8),
        weight_scales_data[0],
        w_zero_points[0]);
    auto* qnnp_w_data = qnnp_weight.data_ptr<c10::quint8>();
    int8_t* w_data = (int8_t*)weight_contig.data_ptr<c10::qint8>();
    auto wt_numel = weight_contig.numel();
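    // QNNPACK consumes unsigned 8-bit weights, while the prepacked weight is
    // stored as signed int8; the loop below adds 128 to each value to
    // reinterpret the same data in the uint8 domain.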
    for (const auto i : c10::irange(wt_numel)) {
      qnnp_w_data[i] = static_cast<c10::quint8>(w_data[i] + 128);
    }

    // Pass in nullptr for bias, as we pass FP32 bias to run function.
    w.reset();
    w = std::make_unique<qnnpack::PackBMatrix>(
        cols_w /* input_channels */,
        rows_w /* output_channels */,
        w_zero_points.data(),
        requantization_scales.data(),
        (uint8_t*)qnnp_w_data,
        nullptr);
    packB = w.get();
    if (at::globalContext().releaseWeightsWhenPrepacking()) {
      // On mobile, we release the original weight by resetting the
      // intrusive_ptr. Calling unpack after this will throw an assertion.
      orig_weight.reset();
    }
  }

  // Update the input scale so that we do not pack the weights again,
  // and to avoid repopulating the requant scales if the scale has not changed.
  input_scale = q_params.scale;

  // Quantize input
  Tensor q_input = at::quantize_per_tensor(
      input_contig, q_params.scale, q_params.zero_point, c10::kQUInt8);

  // The resulting matrix here is 2-D, let's view it with the original
  // left hand dimensions of the input. Here are two examples:
  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
  std::vector<int64_t> out_sizes = input.sizes().vec();
  out_sizes.back() = rows_w;

  auto output = at::empty(out_sizes, input.options().dtype(at::kFloat));

  size_t rows_input = 1;
  size_t cols_input = input_contig.size(input_contig.dim() - 1);
  for (const auto i : c10::irange(input_contig.dim() - 1)) {
    rows_input *= input_contig.size(i);
  }
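  // rows_input collapses all leading dimensions into the effective batch size
  // (the GEMM's M), while cols_input is the number of input channels (K).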
  pytorch_qnnp_status runStatus = qnnpack::qnnpackLinearDynamic(
      rows_input /* batch_size */,
      cols_input /* input_channels */,
      rows_w /* output_channels */,
      q_input.q_zero_point(),
      w_zero_points.data(),
      /* for dynamic should really be called dequant scale */
      requantization_scales.data(),
      (uint8_t*)q_input.data_ptr<c10::quint8>(),
      cols_input /* input_stride */,
      packB->getPackedWeights(),
      bias_ptr,
      output.data_ptr<float>(),
      rows_w /* output_stride */,
      caffe2::pthreadpool_() /* threadpool */);

  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK Linear operator");

  // Call the relu operator here until qlinear dynamic in QNNPACK
  // supports it natively.
  if (ReluFused) {
    output.relu_();
  }
  return output;
}

at::Tensor PackedLinearWeightsQnnp::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(std::move(input), reduce_range);
}

at::Tensor PackedLinearWeightsQnnp::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(std::move(input), reduce_range);
}

#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM

template <bool ReluFused>
at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
    const at::Tensor& input,
    at::Tensor& output) {
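  // fp16 dynamic path: activations and outputs stay in fp32; only the
  // prepacked weight matrix is stored in fp16. The GEMM below multiplies the
  // fp32 activations by the fp16 weights and writes fp32 results; the bias
  // (if any) is added afterwards in fp32.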
  const at::Tensor input_contig = input.contiguous();
  const float* input_ptr = input_contig.const_data_ptr<float>();

  auto& packed_weight_fp16 = *w;

  TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
  TORCH_CHECK(input.dim() >= 2);

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
  const int64_t N = packed_weight_fp16.numCols();
  std::vector<int64_t> output_sizes = input.sizes().vec();
  TORCH_CHECK(!output_sizes.empty())
  output_sizes.back() = N;
  // Resize output Tensor
  output.resize_(output_sizes);

  auto output_data = output.data_ptr<float>();

  int num_tasks = at::get_num_threads();
  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
    for (const auto task_id : c10::irange(begin, end)) {
      // Call the fp16 gemm interface
      fbgemm::cblas_gemm_compute(
          /*transa=*/fbgemm::matrix_op_t::NoTranspose,
          /*m=*/static_cast<int>(M),
          /*A=*/input_ptr,
          /*Bp=*/packed_weight_fp16,
          /*beta=*/0.0f,
          /*C=*/output_data,
          /*thread_id=*/static_cast<int>(task_id),
          /*num_threads=*/num_tasks);
    }
  });

  // Add bias term
  if (bias_.has_value()) {
    TORCH_CHECK(bias_->dim() == 1);
    output.add_(*bias_);
  }

  return output;
}

at::Tensor PackedLinearWeightFp16::apply_dynamic(
    at::Tensor input,
    bool /* reduce_range */) {
  at::Tensor output = at::empty({0}, input.options().dtype(at::kFloat));
  return apply_dynamic_impl</*ReluFused=*/false>(input, output);
}

at::Tensor PackedLinearWeightFp16::apply_dynamic_relu(
    at::Tensor input,
    bool /* reduce_range */) {
  at::Tensor output = at::empty({0}, input.options().dtype(at::kFloat));
  return apply_dynamic_impl</*ReluFused=*/true>(input, output);
}

at::Tensor& PackedLinearWeightFp16::apply_dynamic_out(
    const at::Tensor& input,
    at::Tensor& output,
    bool /* reduce_range */) {
  TORCH_CHECK((output.device() == c10::kCPU) && (output.dtype() == at::kFloat));
  return apply_dynamic_impl<false>(input, output);
}

at::Tensor& PackedLinearWeightFp16::apply_dynamic_relu_out(
    const at::Tensor& input,
    at::Tensor& output,
    bool /* reduce_range */) {
  TORCH_CHECK((output.device() == c10::kCPU) && (output.dtype() == at::kFloat));
  return apply_dynamic_impl<true>(input, output);
}

void PackedLinearWeightFp16::set_bias(std::optional<at::Tensor> bias) {
  bias_ = std::move(bias);
}

#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
template <bool ReluFused>
at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
    at::Tensor input,
    bool reduce_range) {
  // Dynamic: fp32 * int8 -> fp32
  using at::Tensor;
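  // oneDNN dynamic path overview: the fp32 activation is quantized to 8 bits
  // (signed on aarch64+ACL builds, unsigned otherwise), the matmul primitive
  // consumes the int8 weights and applies the dequantization (and ReLU, when
  // ReluFused, via a fused post-op), and the fp32 result is written directly
  // into the output tensor.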

  TORCH_CHECK(
      input.dim() >= 2,
      "The dimension of input tensor should be larger than or equal to 2");
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
      "qlinear_dynamic (ONEDNN): data type of input should be float.");

  // Input -> uint8
  auto input_contig = input.contiguous();
  const int64_t dim = input.dim();
  auto input_reshaped =
      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
  auto input_dims = input_reshaped.sizes().vec();
  auto input_data_type = dnnl::memory::data_type::f32;
  auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
  ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
  ideep::tensor x;
  x.init(input_desc, input_contig.data_ptr());
  // Find quantization parameters
  float x_max = 0, x_min = 0;
#ifdef USE_FBGEMM
  // Use FBGEMM's FindMinMax if available since it's faster
  fbgemm::FindMinMax(
      /*m=*/input_contig.data_ptr<float>(),
      /*min=*/&x_min,
      /*max=*/&x_max,
      /*len=*/input.numel());
#else
  if (input_contig.numel() > 0) {
    auto [t_min, t_max] = at::aminmax(input_contig);
    x_max = t_max.item<float>();
    x_min = t_min.item<float>();
  }
#endif

#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
  // oneDNN+ACL has optimized kernels for s8s8 matmul, so input is signed
  using input_qtype = int8_t;
#else
  using input_qtype = uint8_t;
#endif

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/std::numeric_limits<input_qtype>::min(),
      /*qmax=*/std::numeric_limits<input_qtype>::max(),
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);
  const std::vector<int32_t>& src_zero_point = std::vector<int32_t>(1, q_params.zero_point);
  // weights, dst
  auto w = *(weight_.get());
  auto dst_dims = {x.get_dim(0), w.get_dim(1)};
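  // ideep scales act as multipliers (the inverse of the PyTorch quantization
  // scale), which is why 1.0/q_params.scale is passed for the source below.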
  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale);
  const ideep::scale_t& weights_scales = w.get_scale();
  // Compute -> f32
  // Use ideep::matmul_forward instead of ideep::inner_product_forward,
  // since the latter does not support asymmetric quantization
  // Allocate output Tensor
  at::Tensor output = at::empty(dst_dims, input.options().dtype(at::kFloat));
  if (output.numel() == 0) return output;
  ideep::tensor y({dst_dims, ideep::tensor::data_type::f32,
                   {output.strides().cbegin(), output.strides().cend()}},
                  output.data_ptr());
  bool with_bias = bias_.has_value();
  if (with_bias) {
    // Bias might be modified outside (e.g. by quantization bias correction).
    // If so, update the prepacked bias as well.
    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
    }
  }
  const auto& b = with_bias ? bias_.value() : ideep::tensor();
  // Primitive cache is initialized when called for the first time
  // and won't be updated afterwards.
  int num_threads = at::get_num_threads();
  PrimitiveCacheKey cache_key = std::make_tuple(
      q_params.scale, q_params.zero_point, input_dims, 1.0, 0, num_threads, /*accum scale*/1.0, /*accum zero point*/0);
  c10::call_once(*cache_initialized_flag, [&](){
    LinearParams params;
    ideep::matmul_forward::prepare</*is_dynamic=*/true>(
        params, x, w, b, y,
        src_scales, weights_scales, ideep::scale_t(),
        src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr,
        ideep::tensor::data_type::f32, std::is_signed_v<input_qtype> ? ideep::s8s8 : ideep::u8s8);
    get_cache() = LinearPrimitiveCache(cache_key, params);
    w = w.reorder_if_differ_in(params.pd.weights_desc());
  });
  if (get_cache().hit_dynamic(cache_key)) {
    LinearParams& params = get_cache().get_param();
    ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point);
  } else {
    ideep::matmul_forward::compute(x, w, b, y,
        src_scales, weights_scales, ideep::scale_t(),
        src_zero_point, ideep::zero_point_t(),
        1.0f, 1.0f, op_attr);
  }
  auto out_sizes = input.sizes().vec();
  out_sizes.back() = w.get_dim(1);
  if (output.sizes().vec() == out_sizes)
    return output;
  return output.reshape(out_sizes);
}

at::Tensor PackedLinearWeightsOnednn::apply_dynamic(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/false>(
      std::move(input), reduce_range);
}

at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
    at::Tensor input,
    bool reduce_range) {
  return apply_dynamic_impl</*ReluFused=*/true>(
      std::move(input), reduce_range);
}

#endif // #if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace {

template <bool ReluFused>
class QLinearDynamicInt8 final {
 public:
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
      bool reduce_range) {
    if (ReluFused) {
      return packed_weight->apply_dynamic_relu(std::move(input), reduce_range);
    } else {
      return packed_weight->apply_dynamic(std::move(input), reduce_range);
    }
  }
};
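// QLinearDynamicInt8 simply forwards to the backend-specific packed-weight
// implementation (FBGEMM, QNNPACK, or oneDNN) that was chosen when the weight
// was prepacked.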

template <bool ReluFused>
class QLinearDynamicFp16 final {
 public:
#ifdef USE_FBGEMM
  static at::Tensor run(
      at::Tensor input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    auto output = packed_weight->apply_dynamic(std::move(input));

    // Call the relu operator here until fp16 linear dynamic in FBGEMM
    // supports it natively.
    if (ReluFused) {
      output.relu_();
    }
    return output;
  }
#else // USE_FBGEMM
  static at::Tensor run(
      at::Tensor /* input */,
      const c10::intrusive_ptr<LinearPackedParamsBase>& /* packed_weight */) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

class QLinearUnpackedDynamicFp16 final {
 public:
#ifdef USE_FBGEMM
  static at::Tensor run(
      at::Tensor input,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    TORCH_CHECK(
        weight.dim() == 2,
        "The dimension of weight tensor should be equal to 2");

    auto packed_weight = PackedLinearWeightFp16::prepack(weight, bias);
    auto output = packed_weight->apply_dynamic(std::move(input));

    return output;
  }

  static at::Tensor meta(
      at::Tensor input,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");

    TORCH_CHECK(
        weight.dim() == 2,
        "The dimension of weight tensor should be equal to 2");

    auto out_channel = weight.sym_sizes().vec()[0];
    auto out_sizes = input.sym_sizes().vec();
    out_sizes[out_sizes.size() - 1] = out_channel;

    return at::empty_symint(out_sizes, input.options());
  }
#else // USE_FBGEMM
  static at::Tensor run(
      at::Tensor /* input */,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    // We make a strong guarantee that models using these operators will have
    // the same numerics across different machines. Therefore, we do not provide
    // a fallback path and rather fail loudly if we cannot run FBGEMM.
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }

  static at::Tensor meta(
      at::Tensor /* input */,
      const at::Tensor& weight,
      const at::Tensor& bias) {
    TORCH_CHECK(
        false, "This PyTorch installation was not built with FBGEMM operators");
  }
#endif // USE_FBGEMM
};

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor& weight) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor& weight) {
#ifdef USE_FBGEMM
  // Strictly speaking this is not correct. However we do not know the exact
  // size of the packed matrix, as it is maintained by the packed object
  // itself, so we return a small placeholder tensor here.
  return at::empty({8}, weight.options().dtype(at::kByte));
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  // For the meta function, we need users to provide the dimension explicitly
  // as we don't have access to the weight.
  auto out_sizes = input.sym_sizes().vec();
  if (out_channel == -1) {
    out_sizes.pop_back();
  } else {
    out_sizes.back() = out_channel;
  }
  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}


TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"),
      TORCH_FN(QLinearDynamicInt8<true>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"),
      TORCH_FN(QLinearDynamicFp16<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16_unpacked_weight"),
      TORCH_FN(QLinearUnpackedDynamicFp16::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"),
      TORCH_FN(QLinearDynamicFp16<true>::run));
}
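// These CPU registrations are what torch.ops.quantized.linear_dynamic and the
// related relu/fp16 variants resolve to at runtime (for example, from the
// dynamically quantized Linear modules).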

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16_unpacked_weight"),
      TORCH_FN(QLinearUnpackedDynamicFp16::meta));
}

TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  register_linear_params();
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight_meta);
}

} // namespace
} // namespace native
} // namespace at