1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <c10/core/QScheme.h> 5 6 #ifdef USE_FBGEMM 7 #include <fbgemm/Fbgemm.h> 8 #include <fbgemm/FbgemmSparse.h> 9 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h> 10 11 namespace ao { 12 namespace sparse { 13 14 struct TORCH_API PackedLinearWeight 15 : public LinearPackedParamsBase { PackedLinearWeightPackedLinearWeight16 PackedLinearWeight(std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w, 17 std::optional<at::Tensor> bias, 18 std::vector<int32_t> col_offsets, 19 std::vector<float> w_scale, 20 std::vector<int32_t> w_zp, 21 c10::QScheme q_scheme, 22 const int64_t out_features_block_size /* block sparsity size across output_features */, 23 const int64_t in_features_block_size /* block sparsity size across input_features */) 24 : LinearPackedParamsBase( 25 out_features_block_size, 26 in_features_block_size), 27 w(std::move(w)), 28 bias_(std::move(bias)), 29 col_offsets(std::move(col_offsets)), 30 w_scale(std::move(w_scale)), 31 w_zp(std::move(w_zp)), 32 q_scheme(q_scheme) {} 33 std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w; 34 std::optional<at::Tensor> bias_; 35 std::vector<int32_t> col_offsets; 36 std::vector<float> w_scale; 37 std::vector<int32_t> w_zp; 38 c10::QScheme q_scheme; 39 40 at::Tensor apply( 41 const at::Tensor& input, 42 double output_scale, 43 int64_t output_zero_point) override; 44 at::Tensor apply_relu( 45 const at::Tensor& input, 46 double output_scale, 47 int64_t output_zero_point) override; 48 apply_dynamicPackedLinearWeight49 at::Tensor apply_dynamic(const at::Tensor& input) override { 50 TORCH_INTERNAL_ASSERT( 51 false, 52 "Sparse quantized dynamic linear with fused relu is not yet " 53 "supported on qnnpack backend."); 54 return at::Tensor(); 55 } apply_dynamic_reluPackedLinearWeight56 at::Tensor apply_dynamic_relu(const at::Tensor& input) override { 57 TORCH_INTERNAL_ASSERT( 58 false, 59 "Sparse quantized dynamic linear with fused relu is not yet " 60 "supported on qnnpack backend."); 61 return at::Tensor(); 62 } 63 64 LinearPackedSerializationType unpack() override; 65 66 BCSRSerializationType serialize() override; 67 68 static c10::intrusive_ptr<LinearPackedParamsBase> deserialize( 69 const BCSRSerializationType& serialized); 70 biasPackedLinearWeight71 std::optional<at::Tensor> bias() override { 72 return bias_; 73 } 74 75 static c10::intrusive_ptr<LinearPackedParamsBase> prepack( 76 const at::Tensor& weight, 77 const std::optional<at::Tensor>& bias, 78 const int64_t out_features_block_size, 79 const int64_t in_features_block_size); 80 81 private: 82 template <bool ReluFused> 83 at::Tensor apply_impl( 84 const at::Tensor& input, 85 double output_scale, 86 int64_t output_zero_point); 87 }; 88 89 }} // namespace ao::sparse 90 91 #endif // USE_FBGEMM 92 93 namespace ao { 94 namespace sparse { 95 int register_linear_params(); 96 }} // namespace ao::sparse 97