#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/FakeQuantAffine.h>

// FakeQuantize Op for PerTensorAffine quantization scheme.

namespace at::native {

// These stubs are defined here; the corresponding CPU and CUDA kernels
// register their implementations via REGISTER_DISPATCH.
DEFINE_DISPATCH(fake_quant_tensor_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_tensor_stub);
DEFINE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub);
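
// For context (a sketch, not code in this file): each stub above is
// implemented per backend, and the kernel translation units register
// themselves via REGISTER_DISPATCH, roughly like:
//
//   REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
//                     &fake_quantize_tensor_cachemask_kernel);
//
// (the exact kernel function names live in the CPU/CUDA kernel files).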
/* Fake-quantizes the 'self' tensor.

Args:
  self: Forward input tensor.
  dY: Backward input tensor (_backward op only).
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  quant_min: minimum quantized value
  quant_max: maximum quantized value

Returns:
  Fake-quantized tensor (same dtype as `self`).
*/
Tensor fake_quantize_per_tensor_affine(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::fake_quantize_per_tensor_affine_cachemask(
      self, scale, zero_point, quant_min, quant_max);
  return std::get<0>(res);
}

Tensor fake_quantize_per_tensor_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
      self, scale, zero_point, at::ones(1, self.options().dtype(at::kLong)), quant_min, quant_max);
  return std::get<0>(res);
}
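// Illustrative usage of the two overloads above (a sketch; the tensor-qparams
// dtypes shown are assumptions, they follow whatever the registered kernel
// accepts):
//
//   at::Tensor x = at::randn({2, 3});
//   // scalar qparams:
//   at::Tensor y1 = at::fake_quantize_per_tensor_affine(
//       x, /*scale=*/0.1, /*zero_point=*/0, /*quant_min=*/-128, /*quant_max=*/127);
//   // tensor qparams (1-element tensors):
//   at::Tensor y2 = at::fake_quantize_per_tensor_affine(
//       x, at::full({1}, 0.1f), at::full({1}, 0, at::kInt), -128, 127);
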
/* Fake-quantizes the 'self' tensor, saving a mask for the backward pass.

This is numerically equivalent to `fake_quantize_per_tensor_affine`,
but has a lower memory overhead in the backward pass.

Args:
  self: Forward input tensor.
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  quant_min: minimum quantized value
  quant_max: maximum quantized value

Returns:
  Fake-quantized tensor (same dtype as `self`).
  Mask (bool dtype).
*/
std::tuple<Tensor, Tensor> fake_quantize_per_tensor_affine_cachemask(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");
  TORCH_CHECK(
      zero_point >= quant_min && zero_point <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
  fake_quant_tensor_cachemask_stub(
      self.device().type(), Y, mask, self, scale, zero_point, quant_min, quant_max);
  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  // 1 byte per element, we only need 1 bit per element).
  return std::make_tuple(Y, mask);
}
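
// Sketch of how the forward/backward pair is intended to be used (variable
// names are illustrative; the autograd layer normally does this wiring):
//
//   auto [y, mask] = at::fake_quantize_per_tensor_affine_cachemask(
//       x, /*scale=*/0.1, /*zero_point=*/0, /*quant_min=*/-128, /*quant_max=*/127);
//   // ... later, in the backward pass:
//   at::Tensor dx = at::fake_quantize_per_tensor_affine_cachemask_backward(dy, mask);
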
std::tuple<Tensor, Tensor> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    const Tensor& fake_quant_enabled,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");
  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
  fake_quant_tensor_cachemask_tensor_qparams_stub(
      self.device().type(), Y, mask, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max);
  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  // 1 byte per element, we only need 1 bit per element).
  return std::make_tuple(Y, mask);
}
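// Note: `fake_quant_enabled` is a 1-element integer tensor that the kernel
// reads as an on/off flag; the public tensor-qparams overload above always
// passes at::ones(1, ...), i.e. fake quantization is unconditionally enabled
// when called through that entry point.
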
/* Backward path for fake-quantize of the 'self' tensor, using the cached mask.

Args:
  dY: gradient of the output.
  mask: bool mask tensor saved by the forward pass.

Returns:
  dX (gradient of the input).
*/
Tensor fake_quantize_per_tensor_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "`mask` must be a bool tensor.");
  TORCH_CHECK(mask.sym_numel() == dY.sym_numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel());
  if (dY.sym_numel() <= 0) {
    return dY;
  }
  // Note: no additional kernels are needed, since the mask was precomputed in
  // the forward pass and we can reuse the existing tensor multiplication kernels.
  return dY * mask;
}
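
// In effect this is a straight-through estimator: mask[i] records whether
// element i stayed inside [quant_min, quant_max] in the forward pass, so
// `dY * mask` passes gradients through for unclamped elements and zeroes
// them for clamped ones.
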
static int64_t _get_zero_point_from_tensor(
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    bool is_forward) {
  float zero_point_fp = zero_point[0].item<float>();
  // The forward path rounds to nearest via std::nearbyint; the backward path
  // instead adds 0.5 before the truncating cast below.
  zero_point_fp = is_forward ? std::nearbyint(zero_point_fp) : zero_point_fp + 0.5f;
  float zero_point_clamped = std::min(std::max(zero_point_fp, static_cast<float>(quant_min)),
                                      static_cast<float>(quant_max));
  return static_cast<int64_t>(zero_point_clamped);
}
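
// Worked example of the rounding difference (illustrative value): for
// zero_point[0] == 2.5f, the forward path yields std::nearbyint(2.5f) == 2
// (round-half-to-even under the default rounding mode), while the backward
// path yields static_cast<int64_t>(2.5f + 0.5f) == 3 after clamping.
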
Tensor _fake_quantize_learnable_per_tensor_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  float scale_val = scale[0].item<float>();
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, true);
  // `grad_factor` only affects the backward pass; the forward computation is
  // identical to the non-learnable op.
  return native::fake_quantize_per_tensor_affine(
      self, scale_val, zero_point_val, quant_min, quant_max);
}
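
// Illustrative call (a sketch; `scale` and `zero_point` must be readable as
// 1-element float tensors, since the op uses scale[0] and zero_point[0]):
//
//   at::Tensor s = at::full({1}, 0.1f);
//   at::Tensor z = at::full({1}, 0.0f);
//   at::Tensor y = at::_fake_quantize_learnable_per_tensor_affine(
//       x, s, z, /*quant_min=*/-128, /*quant_max=*/127, /*grad_factor=*/1.0);
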
std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  /* The per-element gradients for scale and zero point are calculated as below:
     Let Xfq be the fake quantized version of X.
     Let Xq be the quantized version of X (clamped at qmin and qmax).
     Let Delta and z be the scale and the zero point.
     :math:
      \frac{\partial X_{fq}}{\partial \Delta} =
        \begin{cases}
          q_{\min} - z & \text{ if } X_q = q_{\min} \\
          q_{\max} - z & \text{ if } X_q = q_{\max} \\
          (X_{fq} - X) / \Delta & \text{ else }
        \end{cases}

      \frac{\partial X_{fq}}{\partial z} =
        \begin{cases}
          -\Delta & \text{ if } X_q = q_{\min} \text{ or } X_q = q_{\max} \\
          0 & \text{ else }
        \end{cases}
  */
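  /* Worked instance of the formulas above (illustrative numbers): with
     Delta = 0.1, z = 0, qmin = -128, qmax = 127 and X = 0.234:
       Xq  = clamp(round(0.234 / 0.1) + 0, -128, 127) = 2   (not clamped)
       Xfq = (2 - 0) * 0.1 = 0.2
       dXfq/dDelta = (0.2 - 0.234) / 0.1 = -0.34
       dXfq/dz     = 0
     For a clamped element, e.g. X = 20.0 (so Xq = qmax = 127):
       dXfq/dDelta = 127 - 0 = 127,  dXfq/dz = -0.1 */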
  float scale_val = scale[0].item<float>();
  float inv_scale_val = 1.0f / scale_val;
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);

  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "`quant_min` must be no greater than `quant_max`, and the quantization range must include 0.");
  TORCH_CHECK(
      zero_point_val >= quant_min && zero_point_val <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");
  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X)
    .add_input(dY)
    .build();

  fake_quant_grad_learnable_tensor_stub(
    X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);

  // The total sums over the scale and zero point gradient vectors are what will be returned in the end.
  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());

  return std::make_tuple(dX, dScale, dZeroPoint);
}
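
// Illustrative backward call (a sketch): returns (dX, dScale, dZeroPoint),
// where dScale and dZeroPoint are 1-element tensors reduced over all elements:
//
//   auto [dX, dScale, dZeroPoint] =
//       at::_fake_quantize_learnable_per_tensor_affine_backward(
//           dY, X, s, z, /*quant_min=*/-128, /*quant_max=*/127, /*grad_factor=*/1.0);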

} // namespace at::native