/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

#include <algorithm>
#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/quantization_utils.h"

#ifdef USE_NEON
namespace {

// Single pass mean and variance.
// Shape of `input` is [rows x cols]; the shape of both `mean` and `variance`
// is [cols].
// Note: `mean` and `variance` are computed over the raw quantized input
// values (they are not scaled into the float range).
// The following is a straightforward implementation of the parallel algorithm
// described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
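//
// A sketch of the merge step used below: for two chunks A and B with element
// counts nA and nB, means xA and xB, and centered second moments M2A and M2B,
//   delta = xB - xA
//   x_AB  = (nA * xA + nB * xB) / (nA + nB)
//   M2_AB = M2A + M2B + delta^2 * nA * nB / (nA + nB)
// and variance = M2 / n once every chunk has been merged.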
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // The implementation operates on 16 columns at a time.
  // Assumes cols % 16 == 0.
  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Vector registers to track the running sum across the rows. Since there
    // are 16 columns, we need four uint32x4 registers.
    uint32x4_t sum[4] = {0};

    float nA = 0.0f;
    // Running average and centered second moment.
    float32x4_t xA[4] = {0.0f};
    float32x4_t M2A[4] = {0.0f};

    const uint8_t* inp_ptr = input + col_offset;
    // Go over the rows in chunks of 256. This is so that we can use 16-bit
    // adds to do the accumulation.
    for (uint32_t row = 0; row < rows; row += 256) {
      // Running sum and sum of squares for this chunk of up to 256 rows.
      uint32x4_t sub_sum[4] = {0};
      uint32x4_t sub_sq_sum[4] = {0};
      const uint32_t limit = std::min(rows, row + 256);
      const float nB = limit - row;
      for (uint32_t subrow = row; subrow < limit; ++subrow) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

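        // Widen the 16 uint8 column values for 32-bit accumulation:
        // u8x16 -> two u16x8 halves -> four u16x4 groups.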
        const uint8x8_t v_high = vget_high_u8(v);
        const uint8x8_t v_low = vget_low_u8(v);

        const uint16x8_t v_high_u16 = vmovl_u8(v_high);
        const uint16x8_t v_low_u16 = vmovl_u8(v_low);

        const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
        const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
        const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
        const uint16x4_t v_low_low = vget_low_u16(v_low_u16);

        sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
        sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
        sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
        sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);

        sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
        sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
        sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
        sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
      }

      // Update the full running sum and moment from the ones for this chunk
      // of rows.
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);
        const float nX = nA + nB;
        // xB is the average of up to 256 elements.
        const float32x4_t xB =
            vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB);

        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);

        // M2B = sum(x^2) - sum(x)^2 / nB, over this chunk's elements.
        const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
                                          vmulq_n_f32(sub_sum_sq, 1.0f / nB));
        const float32x4_t last_term =
            vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
        // M2A = old M2A + M2B + delta^2 * nA * nB / nX
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
      }
      // All rows up to `limit` have now been merged into the running stats.
      nA = limit;
    }

    // Write the final mean and variance for the 16 columns. sum[0] / M2A[0]
    // hold columns col_offset + 12..15 and sum[3] / M2A[3] hold columns
    // col_offset + 0..3, hence the reversed store order.
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}
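
// For reference, a scalar sketch of what ColMeanAndVariance computes for one
// column c (illustrative only; the vectorized code above produces the same
// result, up to float rounding, via the chunked merge):
//   uint64_t s = 0, sq = 0;
//   for (uint32_t r = 0; r < rows; ++r) {
//     const uint32_t v = input[r * cols + c];
//     s += v;
//     sq += v * v;
//   }
//   mean[c] = static_cast<float>(s) / rows;
//   variance[c] = static_cast<float>(sq) / rows - mean[c] * mean[c];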

// Compute min and max of (input - mean) / sqrt(variance + epsilon).
// This is done in a separate pass so that the normalized value can be
// temporarily computed in floating point precision and not stored anywhere.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  // Track the extrema of the normalized values. lowest() is the most
  // negative float (min() would be the smallest positive one).
  float v_maximum = std::numeric_limits<float>::lowest();
  float v_minimum = std::numeric_limits<float>::max();
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Load the per-column statistics in reverse lane-group order so that
    // mean[i] and variance[i] line up with v_float[i] computed below
    // (v_float[0] holds columns col_offset + 12..15), matching InstanceNorm.
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {
        vld1q_f32(variance_ptr + col_offset + 12),
        vld1q_f32(variance_ptr + col_offset + 8),
        vld1q_f32(variance_ptr + col_offset + 4),
        vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
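    // Note: vrsqrteq_f32 only yields an initial reciprocal-square-root
    // estimate and no Newton-Raphson refinement step (vrsqrtsq_f32) is
    // applied, so this path may differ slightly from the rsqrt used in the
    // unoptimized path below.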

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
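        // Horizontal reduction: two pairwise max/min operations collapse the
        // four lanes so lane 0 holds the max/min of the whole vector.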
        const float32x2_t high = vget_high_f32(normed);
        const float32x2_t low = vget_low_f32(normed);
        float32x2_t tmp_max = vpmax_f32(low, high);
        tmp_max = vpmax_f32(tmp_max, tmp_max);
        v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
        float32x2_t tmp_min = vpmin_f32(low, high);
        tmp_min = vpmin_f32(tmp_min, tmp_min);
        v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
      }
    }
  }
  *minimum = v_minimum;
  *maximum = v_maximum;
}

// Compute (input - mean) / sqrt(variance + epsilon) in floating point,
// quantize it in the range (minimum, maximum) and store the result as quint8.
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  const float32x4_t out_min = vdupq_n_f32(minimum);
  const float out_scale = 255.0f / (maximum - minimum);
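  // Each normalized value x is mapped to (x - minimum) * out_scale and then
  // saturated into [0, 255] by the narrowing moves in the store path below.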

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
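    // As in the loads above, index 0 of mean/variance/inv_stddev corresponds
    // to columns col_offset + 12..15, matching v_float[0] in the row loop.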
    const uint8_t* inp_ptr = input + col_offset;
    uint8_t* out_ptr = output + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      uint16x4_t normed_uint16[4];
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const int32x4_t normed_int32 =
            vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale));
        normed_uint16[i] = vqmovun_s32(normed_int32);
      }
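      // Note: vcvtq_s32_f32 truncates toward zero rather than rounding to
      // nearest. The stores below undo the reverse lane-group order
      // (normed_uint16[3] holds columns 0..3), and vqmovun_s32 / vqmovn_u16
      // saturate the result into [0, 255].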
      vst1_u8(out_ptr,
              vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2])));
      vst1_u8(out_ptr + 8,
              vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0])));
      out_ptr += cols;
    }
  }
}

}  // end namespace
#endif  // USE_NEON

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class QuantizedInstanceNorm : public OpKernel {
 public:
  explicit QuantizedInstanceNorm(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_separation", &min_separation_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_range_given", &output_range_given_));
    if (output_range_given_) {
      OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
      OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
      OP_REQUIRES(context, given_y_min_ < given_y_max_,
                  errors::InvalidArgument(
                      "given_y_min must be less than given_y_max : ",
                      given_y_min_, " >= ", given_y_max_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    const Tensor& x_min = context->input(1);
    const Tensor& x_max = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_min.shape()),
                errors::InvalidArgument("`x_min` must be rank 0 but is rank ",
                                        x_min.dims()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_max.shape()),
                errors::InvalidArgument("`x_max` must be rank 0 but is rank ",
                                        x_max.dims()));
    float input_min = x_min.scalar<float>()();
    float input_max = x_max.scalar<float>()();
    float input_scale = (input_max - input_min) / 255.0f;

    OP_REQUIRES(context, input_min < input_max,
                errors::InvalidArgument(
                    "input_min must be less than input_max : ", input_min,
                    " >= ", input_max));

    auto input_tensor = input.tensor<quint8, 4>();
    auto N = input_tensor.dimension(0);
    auto H = input_tensor.dimension(1);
    auto W = input_tensor.dimension(2);
    auto C = input_tensor.dimension(3);
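
    // Instance normalization standardizes each (batch, channel) pair of the
    // NHWC input independently over its H * W spatial positions.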

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));

    typedef TTypes<float>::Tensor::Index Index;

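    // `reduction_indices` collapses the H and W dimensions; the index lists
    // below reshape the resulting [N, C] statistics to [N, 1, 1, C] and
    // broadcast them back over [N, H, W, C].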
    const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
        reduction_indices;
    Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
        broadcast_spec;
    broadcast_spec.set(1, H);
    broadcast_spec.set(2, W);
    Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
        expand_spec;
    expand_spec.set(0, N);
    expand_spec.set(3, C);

    Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
    Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);

#ifdef USE_NEON
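    // The NEON path requires a single batch (the kernels above treat the
    // H * W spatial positions as rows) and a channel count that is a
    // multiple of 16.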
    if (N == 1 && (C % 16 == 0)) {
      VLOG(2) << "Calling optimized";
      ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                         H * W, C, float_mean.data(), float_variance.data());

      float minimum = given_y_min_, maximum = given_y_max_;
      if (!output_range_given_) {
        MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                  C, float_mean.data(), float_variance.data(),
                  variance_epsilon_, &minimum, &maximum);
      }

      if (maximum - minimum < min_separation_) {
        maximum = minimum + min_separation_;
      }

      InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                   C, float_mean.data(), float_variance.data(),
                   variance_epsilon_, minimum, maximum,
                   reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
      output_min->scalar<float>()() = minimum;
      output_max->scalar<float>()() = maximum;
    } else  // NOLINT(readability/braces)
#endif
    {
      VLOG(2) << "Calling unoptimized";
      float_mean = input_tensor.cast<float>().reduce(
          reduction_indices, Eigen::internal::MeanReducer<float>());

      float_variance =
          (input_scale *
           ((input_tensor.cast<float>() -
             float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
              .square()
              .reduce(reduction_indices, Eigen::internal::MeanReducer<float>());
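
      // Note: float_mean is in quantized units while float_variance is
      // computed in the dequantized domain (hence the input_scale factor), so
      // variance_epsilon_ is added to the scaled variance here; the NEON path
      // adds it to the unscaled variance instead.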

      Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
          input_scale *
          (input_tensor.cast<float>() -
           float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
          (float_variance + variance_epsilon_)
              .rsqrt()
              .reshape(expand_spec)
              .broadcast(broadcast_spec);

      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;

      if (!output_range_given_) {
        normed_min = instance_normed.minimum();
        normed_max = instance_normed.maximum();
      } else {
        normed_min() = given_y_min_;
        normed_max() = given_y_max_;
      }

      if (normed_max() - normed_min() < min_separation_) {
        normed_max() = normed_min() + min_separation_;
      }

      FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
      auto instance_normed_quantized =
          QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);
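      // QUANTIZE_WITH_EIGEN (see quantization_utils.h) maps floats in
      // [normed_min, normed_max] onto the full quint8 range.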

      output->tensor<quint8, 4>().device(
          context->template eigen_device<CPUDevice>()) =
          instance_normed_quantized;
      output_min->flat<float>()(0) = normed_min();
      output_max->flat<float>()(0) = normed_max();
    }
  }

 private:
  float variance_epsilon_;
  float min_separation_;
  bool output_range_given_;
  float given_y_min_;
  float given_y_max_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T"),
                        QuantizedInstanceNorm);

}  // namespace tensorflow