#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/FakeQuantAffine.h>

// FakeQuantize Op for PerTensorAffine quantization scheme.

namespace at::native {

// Use REGISTER_DISPATCH to run the CPU and CUDA backends.
DEFINE_DISPATCH(fake_quant_tensor_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_tensor_stub);
DEFINE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub);

/* Fake-quantizes the 'inputs' tensor.

Args:
  self: Forward input tensor.
  dY: Backward input tensor (_backward op only).
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  quant_min: minimum quantized value
  quant_max: maximum quantized value

Returns:
  Fake-quantized tensor (same dtype as `self`).

*/
Tensor fake_quantize_per_tensor_affine(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::fake_quantize_per_tensor_affine_cachemask(
      self, scale, zero_point, quant_min, quant_max);
  return std::get<0>(res);
}

Tensor fake_quantize_per_tensor_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const auto res = at::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
      self, scale, zero_point, at::ones(1, self.options().dtype(at::kLong)), quant_min, quant_max);
  return std::get<0>(res);
}
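
// A minimal sketch of the element-wise transform the dispatched kernels apply
// (for illustration only; the actual implementation lives in the CPU/CUDA stubs
// registered above):
//
//   x_q  = clamp(round(x / scale) + zero_point, quant_min, quant_max)
//   x_fq = (x_q - zero_point) * scale
//
// Illustrative call (argument values are assumptions, not taken from this file):
//
//   auto y = at::fake_quantize_per_tensor_affine(
//       x, /*scale=*/0.1, /*zero_point=*/0, /*quant_min=*/0, /*quant_max=*/255);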

/* Fake-quantizes the 'inputs' tensor, saving a mask for the backward pass.

This is numerically equivalent to `fake_quantize_per_tensor_affine`,
but has a lower memory overhead in the backward pass.

Args:
  self: Forward input tensor.
  scale: scale of per tensor affine quantization
  zero_point: zero_point of per tensor affine quantization
  quant_min: minimum quantized value
  quant_max: maximum quantized value

Returns:
  Fake-quantized tensor (same dtype as `self`).
  Mask (bool dtype).
*/
std::tuple<Tensor, Tensor> fake_quantize_per_tensor_affine_cachemask(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");
  TORCH_CHECK(
      zero_point >= quant_min && zero_point <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
  fake_quant_tensor_cachemask_stub(
      self.device().type(), Y, mask, self, scale, zero_point, quant_min, quant_max);
  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  //   1 byte per element, we only need 1 bit per element).
  return std::make_tuple(Y, mask);
}
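
// Sketch of how the cached mask is meant to be used (illustrative only; in
// practice autograd wires the forward and backward together). The mask is true
// for elements whose quantized value was not clamped to quant_min/quant_max,
// so the straight-through-estimator backward reduces to an elementwise product:
//
//   auto [y, mask] = at::fake_quantize_per_tensor_affine_cachemask(
//       x, /*scale=*/0.1, /*zero_point=*/0, /*quant_min=*/0, /*quant_max=*/255);
//   // ... later, given an incoming gradient dy ...
//   auto dx = at::fake_quantize_per_tensor_affine_cachemask_backward(dy, mask);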

std::tuple<Tensor, Tensor> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    const Tensor& fake_quant_enabled,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or equal to `quant_max`.");
  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
  fake_quant_tensor_cachemask_tensor_qparams_stub(
      self.device().type(), Y, mask, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max);
  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  //   1 byte per element, we only need 1 bit per element).
  return std::make_tuple(Y, mask);
}

/* Backward pass for fake-quantizing the 'inputs' tensor, using the mask saved in the forward pass.

Args:
  dY: output grad.
  mask: mask tensor from the forward pass.

Returns:
  dX (input grad).
*/
Tensor fake_quantize_per_tensor_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
  TORCH_CHECK(mask.sym_numel() == dY.sym_numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel());
  if (dY.sym_numel() <= 0) {
    return dY;
  }
  // Note: no additional kernels needed, since mask is pre-computed
  // and we can use the existing tensor multiplication kernels.
  return dY * mask;
}

static int64_t _get_zero_point_from_tensor(
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    bool is_forward) {
  float zero_point_fp = zero_point[0].item<float>();
  zero_point_fp = is_forward ? std::nearbyint(zero_point_fp) : zero_point_fp + 0.5f;
  float zero_point_clamped = std::min(std::max(zero_point_fp, static_cast<float>(quant_min)),
                                       static_cast<float>(quant_max));
  return static_cast<int64_t>(zero_point_clamped);
}
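
// Worked example of the rounding above (values chosen for illustration only),
// with zero_point[0] == 2.7, quant_min == 0, quant_max == 255:
//   forward:  std::nearbyint(2.7f) == 3.0f  -> clamp -> cast -> 3
//   backward: 2.7f + 0.5f == 3.2f           -> clamp -> cast -> 3
// i.e. the forward path rounds to nearest, while the backward path adds 0.5
// and relies on the final integer cast to truncate.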

Tensor _fake_quantize_learnable_per_tensor_affine(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  float scale_val = scale[0].item<float>();
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, true);
  return native::fake_quantize_per_tensor_affine(
    self, scale_val, zero_point_val, quant_min, quant_max);
}

std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_backward(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    double grad_factor) {
  /* The gradients for scale and zero point are calculated as below:
     Let Xfq be the fake quantized version of X.
     Let Xq be the quantized version of X (clamped at qmin and qmax).
     Let Delta and z be the scale and the zero point.
     :math:
      \frac{d\Delta}{dx} =
        \begin{cases}
          q_{\min} - z & \text{ if } X_q = q_{\min} \\
          q_{\max} - z & \text{ if } X_q = q_{\max} \\
          (X_{fq} - X) / \Delta & \text{ else }
        \end{cases}

      \frac{dz}{dx} =
        \begin{cases}
          -\Delta & \text{ if } X_q = q_{\min} \text{ or } X_q = q_{\max} \\
          0 & \text{ else }
        \end{cases}
  */
  float scale_val = scale[0].item<float>();
  float inv_scale_val = 1.0f / scale_val;
  int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);

  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
  TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
  TORCH_CHECK(
      quant_min <= 0 && quant_max >= 0,
      "`quant_min` should be less than or equal to `quant_max`, "
      "and the quantization range should include 0.");
  TORCH_CHECK(
      zero_point_val >= quant_min && zero_point_val <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");
  if (X.numel() <= 0) {
    return std::make_tuple(X, scale, zero_point);
  }

  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
  auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);

  auto iter = TensorIteratorConfig()
    .add_output(dX)
    .add_output(dScale_vec)
    .add_output(dZeroPoint_vec)
    .add_input(X)
    .add_input(dY)
    .build();

  fake_quant_grad_learnable_tensor_stub(
    X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);

  // The per-element scale and zero point gradients are summed into single-element tensors before being returned.
  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());

  return std::make_tuple(dX, dScale, dZeroPoint);
}
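
// Illustrative usage of the learnable variant (tensor values and names are
// assumptions for the example, not taken from this file). `scale` and
// `zero_point` are single-element float tensors so that they can receive
// gradients; `grad_factor` rescales the qparam gradients in the backward:
//
//   auto scale_t = at::full({1}, 0.1);   // float scale tensor
//   auto zp_t    = at::zeros({1});       // float zero_point tensor
//   auto y = at::_fake_quantize_learnable_per_tensor_affine(
//       x, scale_t, zp_t, /*quant_min=*/0, /*quant_max=*/255, /*grad_factor=*/1.0);
//   auto [dx, dscale, dzp] = at::_fake_quantize_learnable_per_tensor_affine_backward(
//       dy, x, scale_t, zp_t, /*quant_min=*/0, /*quant_max=*/255, /*grad_factor=*/1.0);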

} // namespace at::native