// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <c10/core/QScheme.h>
5 
6 #ifdef USE_FBGEMM
7 #include <fbgemm/Fbgemm.h>
8 #include <fbgemm/FbgemmSparse.h>
9 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
10 
11 namespace ao {
12 namespace sparse {
13 
14 struct TORCH_API PackedLinearWeight
15     : public LinearPackedParamsBase {
PackedLinearWeightPackedLinearWeight16   PackedLinearWeight(std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w,
17                      std::optional<at::Tensor> bias,
18                      std::vector<int32_t> col_offsets,
19                      std::vector<float> w_scale,
20                      std::vector<int32_t> w_zp,
21                      c10::QScheme q_scheme,
22                      const int64_t out_features_block_size /* block sparsity size across output_features */,
23                      const int64_t in_features_block_size /* block sparsity size across input_features */)
24       : LinearPackedParamsBase(
25             out_features_block_size,
26             in_features_block_size),
27         w(std::move(w)),
28         bias_(std::move(bias)),
29         col_offsets(std::move(col_offsets)),
30         w_scale(std::move(w_scale)),
31         w_zp(std::move(w_zp)),
32         q_scheme(q_scheme) {}
33   std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w;
34   std::optional<at::Tensor> bias_;
35   std::vector<int32_t> col_offsets;
36   std::vector<float> w_scale;
37   std::vector<int32_t> w_zp;
38   c10::QScheme q_scheme;
39 
40   at::Tensor apply(
41       const at::Tensor& input,
42       double output_scale,
43       int64_t output_zero_point) override;
44   at::Tensor apply_relu(
45       const at::Tensor& input,
46       double output_scale,
47       int64_t output_zero_point) override;
48 
apply_dynamicPackedLinearWeight49   at::Tensor apply_dynamic(const at::Tensor& input) override {
50     TORCH_INTERNAL_ASSERT(
51         false,
52         "Sparse quantized dynamic linear with fused relu is not yet "
53         "supported on qnnpack backend.");
54     return at::Tensor();
55   }
apply_dynamic_reluPackedLinearWeight56   at::Tensor apply_dynamic_relu(const at::Tensor& input) override {
57     TORCH_INTERNAL_ASSERT(
58         false,
59         "Sparse quantized dynamic linear with fused relu is not yet "
60         "supported on qnnpack backend.");
61     return at::Tensor();
62   }
63 
64   LinearPackedSerializationType unpack() override;
65 
66   BCSRSerializationType serialize() override;
67 
68   static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
69       const BCSRSerializationType& serialized);
70 
biasPackedLinearWeight71   std::optional<at::Tensor> bias() override {
72     return bias_;
73   }
74 
75   static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
76       const at::Tensor& weight,
77       const std::optional<at::Tensor>& bias,
78       const int64_t out_features_block_size,
79       const int64_t in_features_block_size);
80 
81  private:
82   template <bool ReluFused>
83   at::Tensor apply_impl(
84       const at::Tensor& input,
85       double output_scale,
86       int64_t output_zero_point);
87 };
88 
89 }}  // namespace ao::sparse
90 
91 #endif // USE_FBGEMM
92 
namespace ao {
namespace sparse {

// Declaration only; the definition (which presumably registers the sparse
// linear packed-params class with the runtime — verify in the .cpp) lives in
// the corresponding translation unit. Available regardless of USE_FBGEMM.
int register_linear_params();

}  // namespace sparse
}  // namespace ao
97