// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/PackedParams.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <tuple>

6 struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
7   virtual at::Tensor apply(
8       at::Tensor input,
9       double output_scale,
10       int64_t output_zero_point) = 0;
11   virtual at::Tensor apply_relu(
12       at::Tensor input,
13       double output_scale,
14       int64_t output_zero_point) = 0;
15 
16   // out variant of LinearPackedParamsBase::apply
apply_outLinearPackedParamsBase17   virtual at::Tensor& apply_out(
18       const at::Tensor& /*input*/,
19       double /*output_scale*/,
20       int64_t /*output_zero_point*/,
21       at::Tensor& output) {
22     throw std::runtime_error(
23         "apply_out is not implemented for this packed "
24         "parameter type");
25     return output;
26   }
27 
apply_relu_outLinearPackedParamsBase28   virtual at::Tensor& apply_relu_out(
29       const at::Tensor& /*input*/,
30       double /*output_scale*/,
31       int64_t /*output_zero_point*/,
32       at::Tensor& output) {
33     throw std::runtime_error(
34         "apply_relu_out is not implemented for this packed "
35         "parameter type");
36     return output;
37   }
38 
39   // Corresponding pattern (the ops with `*` are part of the pattern that
40   // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
41   // input -> q* -> dq* -> linear* ->
42   //         qweight -> dq* /
43   //
44   // After fusion:
45   // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
46   //         qweight /
47   //
48   // Additional Note: the weight is packed as well
49   // Params:
50   //    X: float32 Tensor, will be quantized to quint8 in the op
51   //    W_prepack: packed qint8 quantized weight and bias
52   // Returns:
53   //    Y: float32 Tensor
apply_with_input_q_dq_qweight_dq_output_fp32LinearPackedParamsBase54   virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
55       at::Tensor input,
56       double input_scale,
57       int64_t input_zero_point) {
58     throw std::runtime_error(
59         "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
60         "parameter type");
61     return {};
62   }
63 
64   // Corresponding pattern (the ops with `*` are part of the pattern that
65   // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
66   // input -> q* -> dq* -> linear* -> relu* ->
67   //         qweight -> dq* /
68   //
69   // After fusion:
70   // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
71   //         qweight /
72   //
73   // Additional Note: the weight is packed as well
74   // Params:
75   //    input: float32 Tensor, will be quantized to quint8 in the op
76   // Returns:
77   //    float32 Tensor
apply_with_input_q_dq_qweight_dq_relu_output_fp32LinearPackedParamsBase78   virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
79       at::Tensor input,
80       double input_scale,
81       int64_t input_zero_point) {
82     throw std::runtime_error(
83         "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
84         "parameter type");
85     return {};
86   }
87 
88   virtual at::Tensor apply_dynamic(
89       at::Tensor input,
90       bool reduce_range = false) = 0;
91   virtual at::Tensor apply_dynamic_relu(
92       at::Tensor input,
93       bool reduce_range = false) = 0;
94 
apply_dynamic_outLinearPackedParamsBase95   virtual at::Tensor& apply_dynamic_out(
96       const at::Tensor& /* input */,
97       at::Tensor& output,
98       bool /* reduce_range */) {
99     throw std::runtime_error(
100         "apply_dynamic_out is not implemented for this packed "
101         "parameter type");
102     return output;
103   }
apply_dynamic_relu_outLinearPackedParamsBase104   virtual at::Tensor& apply_dynamic_relu_out(
105       const at::Tensor& /* input */,
106       at::Tensor& output,
107       bool /* reduce_range */) {
108     throw std::runtime_error(
109         "apply_dynamic_relu_out is not implemented for this packed "
110         "parameter type");
111     return output;
112   }
113 
114   virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;
115 
116   virtual std::optional<at::Tensor> bias() = 0;
117 
set_biasLinearPackedParamsBase118   virtual void set_bias(std::optional<at::Tensor> /*bias*/) {
119     throw std::runtime_error(
120         "set_bias is not implemented for this packed "
121         "parameter type");
122   }
123 };
124 
// Base class for the packed parameters of a quantized (possibly transposed)
// convolution, templated on the number of spatial dimensions (2 = conv2d,
// the default; 3 = conv3d). Backends derive from this and implement all
// pure-virtual members; the conv geometry accessors (stride/padding/...)
// expose the configuration the weights were packed with.
template <int kSpatialDim = 2>
struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
  // Quantized convolution on quantized `input`; the result is quantized
  // with `output_scale` / `output_zero_point`.
  virtual at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;
  // Same as apply(), with ReLU fused into the output.
  virtual at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;
  // Dynamic quantization: float input, activation quantization params
  // computed internally, float32 output. `reduce_range` limits activations
  // to 7 bits.
  virtual at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) = 0;

  // Unpacks to the original (weight, optional bias) pair.
  virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;

  // Convolution geometry used when the parameters were packed.
  virtual torch::List<int64_t> stride() const = 0;
  virtual torch::List<int64_t> padding() const = 0;
  // Only meaningful for transposed convolution.
  virtual torch::List<int64_t> output_padding() const = 0;
  virtual torch::List<int64_t> dilation() const = 0;
  virtual int64_t groups() const = 0;
  // True if these params are for a transposed (deconvolution) op.
  virtual bool transpose() const = 0;
};
148