1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/core/ivalue.h> 5 6 struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { 7 virtual at::Tensor apply( 8 at::Tensor input, 9 double output_scale, 10 int64_t output_zero_point) = 0; 11 virtual at::Tensor apply_relu( 12 at::Tensor input, 13 double output_scale, 14 int64_t output_zero_point) = 0; 15 16 // out variant of LinearPackedParamsBase::apply apply_outLinearPackedParamsBase17 virtual at::Tensor& apply_out( 18 const at::Tensor& /*input*/, 19 double /*output_scale*/, 20 int64_t /*output_zero_point*/, 21 at::Tensor& output) { 22 throw std::runtime_error( 23 "apply_out is not implemented for this packed " 24 "parameter type"); 25 return output; 26 } 27 apply_relu_outLinearPackedParamsBase28 virtual at::Tensor& apply_relu_out( 29 const at::Tensor& /*input*/, 30 double /*output_scale*/, 31 int64_t /*output_zero_point*/, 32 at::Tensor& output) { 33 throw std::runtime_error( 34 "apply_relu_out is not implemented for this packed " 35 "parameter type"); 36 return output; 37 } 38 39 // Corresponding pattern (the ops with `*` are part of the pattern that 40 // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): 41 // input -> q* -> dq* -> linear* -> 42 // qweight -> dq* / 43 // 44 // After fusion: 45 // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> 46 // qweight / 47 // 48 // Additional Note: the weight is packed as well 49 // Params: 50 // X: float32 Tensor, will be quantized to quint8 in the op 51 // W_prepack: packed qint8 quantized weight and bias 52 // Returns: 53 // Y: float32 Tensor apply_with_input_q_dq_qweight_dq_output_fp32LinearPackedParamsBase54 virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( 55 at::Tensor input, 56 double input_scale, 57 int64_t input_zero_point) { 58 throw std::runtime_error( 59 "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " 60 "parameter type"); 61 return {}; 62 } 63 64 // Corresponding pattern (the ops with `*` are part of the pattern that 65 // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): 66 // input -> q* -> dq* -> linear* -> relu* -> 67 // qweight -> dq* / 68 // 69 // After fusion: 70 // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> 71 // qweight / 72 // 73 // Additional Note: the weight is packed as well 74 // Params: 75 // input: float32 Tensor, will be quantized to quint8 in the op 76 // Returns: 77 // float32 Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32LinearPackedParamsBase78 virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( 79 at::Tensor input, 80 double input_scale, 81 int64_t input_zero_point) { 82 throw std::runtime_error( 83 "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " 84 "parameter type"); 85 return {}; 86 } 87 88 virtual at::Tensor apply_dynamic( 89 at::Tensor input, 90 bool reduce_range = false) = 0; 91 virtual at::Tensor apply_dynamic_relu( 92 at::Tensor input, 93 bool reduce_range = false) = 0; 94 apply_dynamic_outLinearPackedParamsBase95 virtual at::Tensor& apply_dynamic_out( 96 const at::Tensor& /* input */, 97 at::Tensor& output, 98 bool /* reduce_range */) { 99 throw std::runtime_error( 100 "apply_dynamic_out is not implemented for this packed " 101 "parameter type"); 102 return output; 103 } apply_dynamic_relu_outLinearPackedParamsBase104 virtual at::Tensor& apply_dynamic_relu_out( 105 const at::Tensor& /* input */, 106 at::Tensor& output, 107 bool /* reduce_range */) { 108 throw std::runtime_error( 109 "apply_dynamic_relu_out is not implemented for this packed " 110 "parameter type"); 111 return output; 112 } 113 114 virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0; 115 116 virtual std::optional<at::Tensor> bias() = 0; 117 set_biasLinearPackedParamsBase118 virtual void set_bias(std::optional<at::Tensor> /*bias*/) { 119 throw std::runtime_error( 120 "set_bias is not implemented for this packed " 121 "parameter type"); 122 } 123 }; 124 125 template <int kSpatialDim = 2> 126 struct ConvPackedParamsBase : public torch::jit::CustomClassHolder { 127 virtual at::Tensor apply( 128 const at::Tensor& input, 129 double output_scale, 130 int64_t output_zero_point) = 0; 131 virtual at::Tensor apply_relu( 132 const at::Tensor& input, 133 double output_scale, 134 int64_t output_zero_point) = 0; 135 virtual at::Tensor apply_dynamic( 136 const at::Tensor& input, 137 bool reduce_range) = 0; 138 139 virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0; 140 141 virtual torch::List<int64_t> stride() const = 0; 142 virtual torch::List<int64_t> padding() const = 0; 143 virtual torch::List<int64_t> output_padding() const = 0; 144 virtual torch::List<int64_t> dilation() const = 0; 145 virtual int64_t groups() const = 0; 146 virtual bool transpose() const = 0; 147 }; 148