/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h"

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {

absl::Status CheckIfValidNodeOfType(const Node* node,
                                    OperationType required_type) {
  if (node == nullptr) {
    return absl::NotFoundError("Invalid node.");
  }
  if (OperationTypeFromString(node->operation.type) != required_type) {
    return absl::NotFoundError("Type mismatch.");
  }
  return absl::OkStatus();
}

absl::Status GetElementwiseScalarValue(const Node* node, float* result) {
  auto attr = absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
  const float* value = absl::get_if<float>(&attr.param);
  if (!value) {
    return absl::NotFoundError("Not a scalar value inside attributes.");
  }
  *result = *value;
  return absl::OkStatus();
}

absl::Status GetNextSingleNode(const GraphFloat32& graph, const Node& node,
                               OperationType next_type, Node** next_node) {
  auto consumers = graph.FindConsumers(graph.FindOutputs(node.id)[0]->id);
  if (consumers.size() != 1) {
    return absl::NotFoundError("Not a single consumer.");
  }
  RETURN_IF_ERROR(CheckIfValidNodeOfType(consumers[0], next_type));
  *next_node = consumers[0];
  return absl::OkStatus();
}

std::string GetReduceCode(const std::string& src_value,
                          const std::string& dst_value, int3 work_group_size,
                          bool two_step) {
  int reduction_size = work_group_size.z;
  std::string mem_name = work_group_size.x * work_group_size.y != 1
                             ? "shared_mem[LOCAL_ID_1][LOCAL_ID_0]"
                             : "shared_mem";
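  // For small reductions the summation is fully unrolled; larger ones fall
  // through to the tree reduction emitted below.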
  if (reduction_size <= 8) {
    std::string result;
    result += "  {  // reduction\n";
    result += "    " + mem_name + "[local_id] = " + src_value + ";\n";
    result += "    LOCAL_MEM_BARRIER;\n";
    result += "    " + dst_value + " = " + mem_name + "[0];\n";
    for (int i = 1; i < reduction_size; ++i) {
      result += "    " + dst_value + " += " + mem_name + "[" +
                std::to_string(i) + "];\n";
    }
    if (two_step) {
      result += "    LOCAL_MEM_BARRIER;\n";
    }
    result += "  }\n";
    return result;
  } else {
    // In the reduction step, add the upper half of the still-to-be-summed
    // vector to the lower half, while taking care of odd sizes and rounding.
    // E.g.:
    // Number of items still to be summed before: 5
    // Local memory before: [a, b, c, d, e];
    // Local memory after: [a+d, b+e, c, d, e];
    // Threads doing work: id < 2 = floor(5/2)
    // Offset to the added items: 3 = ceil(5/2)
    // Number of items still to be summed after: 3 = ceil(5/2)
    return absl::Substitute(R"(
  {  // reduction, all threads inside workgroup must execute this code
    $3[local_id] = $1;
    LOCAL_MEM_BARRIER;
    // The number of items that still need to be summed
    int reduction_size = $0;
    while (reduction_size > 1) {
      int active_thread_limit = reduction_size / 2;
      int offset = (reduction_size + 1) / 2;
      if (local_id < active_thread_limit) {
        $1 += $3[local_id + offset];
        $3[local_id] = $1;
      }
      LOCAL_MEM_BARRIER;
      reduction_size = offset;
    }
    $2 = $3[0];
  }
)",
                            reduction_size, src_value, dst_value, mem_name);
  }
}

std::string ZeroClampVec4Code(const std::string& slice_name,
                              const std::string& channels_name,
                              const std::string& value_name) {
  return absl::Substitute(R"(
    // no need to check first element, always valid
    if ($0 * 4 + 1 >= $1) { $2.y = 0.0f; }
    if ($0 * 4 + 2 >= $1) { $2.z = 0.0f; }
    if ($0 * 4 + 3 >= $1) { $2.w = 0.0f; }
)",
                          slice_name, channels_name, value_name);
}
}  // namespace

MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition,
                                                 const GpuInfo& gpu_info,
                                                 const BHWC& shape,
                                                 float variance_bias,
                                                 bool two_step)
    : GPUOperation(definition) {
  const int tensor_slices = DivideRoundUp(shape.c, 4);
  int desired_work_group_size = gpu_info.GetMaxWorkGroupSizeForZ();
  if (gpu_info.IsMali()) {
    // Don't use more than 64 work items per work group on ARM Mali. Mali
    // implements local memory in global memory, so larger work groups incur a
    // severe performance penalty.
    desired_work_group_size = 64;
  }
  if (gpu_info.IsAdreno()) {
    AdrenoInfo info = gpu_info.adreno_info;
    desired_work_group_size = 256;
    if (info.IsAdreno3xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno320 ||
          info.adreno_gpu == AdrenoGpu::kAdreno330) {
        desired_work_group_size = 128;
      } else {
        desired_work_group_size = 64;
      }
    } else if (info.IsAdreno4xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno430) {
        desired_work_group_size = 256;
      } else {
        desired_work_group_size = 128;
      }
    } else if (info.IsAdreno5xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno530 ||
          info.adreno_gpu == AdrenoGpu::kAdreno540) {
        desired_work_group_size = 256;
      } else {
        desired_work_group_size = 128;
      }
    }
  }
  if (gpu_info.IsPowerVR()) {
    desired_work_group_size = 64;
  }
  if (gpu_info.IsApple()) {
    desired_work_group_size = 64;
  }
  if (gpu_info.IsAMD()) {
    desired_work_group_size = 512;
  }
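  // With a single spatial position the whole work group reduces over slices;
  // otherwise spatial work is also spread over the X/Y work group dimensions.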
  if (shape.w * shape.h == 1) {
    desired_work_group_size =
        std::min(desired_work_group_size, gpu_info.GetMaxWorkGroupSizeForZ());
    while (desired_work_group_size >= tensor_slices * 2) {
      desired_work_group_size /= 2;
    }
    work_group_size_.x = 1;
    work_group_size_.y = 1;
    work_group_size_.z = desired_work_group_size;
  } else {
    if (tensor_slices >= 16) {
      work_group_size_.z = 8;
    } else if (tensor_slices >= 10) {
      work_group_size_.z = 4;
    } else {
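      // Chosen reduction depth (work_group_size_.z) for small slice counts.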
      std::map<int, int> slices_to_group_size = {
          {1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 3},
          {6, 3}, {7, 4}, {8, 4}, {9, 3},
      };
      work_group_size_.z = slices_to_group_size[tensor_slices];
    }
    desired_work_group_size =
        std::min(desired_work_group_size, gpu_info.GetMaxWorkGroupTotalSize());
    work_group_size_.x = 1;
    work_group_size_.y =
        desired_work_group_size / AlignByN(work_group_size_.z, 4);
    while (work_group_size_.y > work_group_size_.x) {
      work_group_size_.y /= 2;
      work_group_size_.x *= 2;
    }
  }
  args_.AddFloat("variance_bias", variance_bias);
  args_.AddFloat("inv_ch_count", 1.0f / shape.c);
  code_ = GetNormalizationCode(gpu_info, shape.c % 4 == 0, two_step);
}

std::string MeanStdDevNormalization::GetNormalizationCode(
    const GpuInfo& gpu_info, bool channels_x4, bool two_step) {
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  AddDstTensor("dst_tensor", definition_.dst_tensors[0]);

  std::string c;
  if (gpu_info.IsApiOpenCl()) {
    c += "__attribute__((reqd_work_group_size(" +
         std::to_string(work_group_size_.x) + ", " +
         std::to_string(work_group_size_.y) + ", " +
         std::to_string(work_group_size_.z) + ")))\n";
  }
  c += "MAIN_FUNCTION($0) {\n";
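  // In one-step mode each thread reduces a (sum, sum of squares) pair, so the
  // shared accumulator is float2; two-step mode reduces one scalar at a time.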
  std::string accum_type = two_step ? "float" : "float2";
  if (work_group_size_.x * work_group_size_.y == 1) {
    c += "__local " + accum_type + " shared_mem[" +
         std::to_string(work_group_size_.z) + "];\n";
  } else {
    c += "__local " + accum_type + " shared_mem[" +
         std::to_string(work_group_size_.x) + "][" +
         std::to_string(work_group_size_.y) + "][" +
         std::to_string(work_group_size_.z) + "];\n";
  }
  if (definition_.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += "  int Y = GLOBAL_ID_1;\n";
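  // Each thread accumulates partial sums over a strided subset of slices;
  // the work group then combines them with the shared-memory reduction.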
  if (!two_step) {
    c += "  float4 private_sum4_sq = INIT_FLOAT4(0.0f);\n";
  }
  c += R"(
  float4 private_sum4 = INIT_FLOAT4(0.0f);
  int local_id = LOCAL_ID_2;
  int reduction_group_size = GROUP_SIZE_2;
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    int x_clamped = min(X, args.src_tensor.Width() - 1);
    int y_clamped = min(Y, args.src_tensor.Height() - 1);
    float4 t = args.src_tensor.Read<float>(x_clamped, y_clamped, S);)";
  if (!channels_x4) {
    c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "t");
  }
  if (two_step) {
    c += "    private_sum4 += t;\n";
    c += "  }\n";
    c += "  float private_sum = dot(private_sum4, INIT_FLOAT4(1.0f));\n";
    c += "  float sum;\n";
  } else {
    c += "    private_sum4 += t;\n";
    c += "    private_sum4_sq += t * t;\n";
    c += "  }\n";
    c += "  float2 private_sum;\n";
    c += "  private_sum.x = dot(private_sum4, INIT_FLOAT4(1.0f));\n";
    c += "  private_sum.y = dot(private_sum4_sq, INIT_FLOAT4(1.0f));\n";
    c += "  float2 sum;\n";
  }
  c += GetReduceCode("private_sum", "sum", work_group_size_, two_step);
  if (two_step) {
    c += R"(
  // Calculate the mean
  float mean = sum * args.inv_ch_count;
  // Calculate the squared sum of the difference from the mean.
  float4 private_sum_diff_sq4 = INIT_FLOAT4(0.0f);
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    int x_clamped = min(X, args.src_tensor.Width() - 1);
    int y_clamped = min(Y, args.src_tensor.Height() - 1);
    float4 t = args.src_tensor.Read<float>(x_clamped, y_clamped, S);
    float4 diff = t - mean;)";
    if (!channels_x4) {
      c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "diff");
    }
    c += R"(
    private_sum_diff_sq4 += diff * diff;
  }
  // Reduce
  float private_sum_diff_sq = dot(private_sum_diff_sq4, INIT_FLOAT4(1.0f));
  float sum_diff_sq;
)";
    c += GetReduceCode("private_sum_diff_sq", "sum_diff_sq", work_group_size_,
                       two_step);
    c += "  float variance = sum_diff_sq * args.inv_ch_count;\n";
  } else {
    c += "  float mean = sum.x * args.inv_ch_count;\n";
    c += "  float mean_sq = sum.y * args.inv_ch_count;\n";
    c += "  float variance = mean_sq - mean * mean;\n";
  }
  c += R"(
  // no more shared memory usage, 'useless' threads can exit now
  if (X >= args.dst_tensor.Width()) { return; }
  if (Y >= args.dst_tensor.Height()) { return; }
  // Calculate 1/stddev (with the 'regularizing constant' as in tensor_utils.cc)
  float stddev_inv = rsqrt(variance + args.variance_bias);
  // Calculate (t-mean)/stddev for each element
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    float4 t = args.src_tensor.Read<float>(X, Y, S);
    FLT4 result = TO_FLT4((t - mean) * stddev_inv);
    args.dst_tensor.Write(result, X, Y, S);
  }
})";
  return c;
}

int3 MeanStdDevNormalization::GetGridSize() const {
  // To avoid dealing with global reductions, we restrict the grid size in the
  // reduction dimension (Z) to the work group size.
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height();
  const int grid_z = work_group_size_.z;
  return int3(grid_x, grid_y, grid_z);
}

MeanStdDevNormalization CreateMeanStdDevNormalization(
    const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape,
    float variance_bias, bool two_step) {
  return MeanStdDevNormalization(definition, gpu_info, shape, variance_bias,
                                 two_step);
}

absl::Status TryMeanStdDevNormalization(
    const GpuInfo& gpu_info, CalculationsPrecision precision,
    const GraphFloat32& graph, NodeId first_node_id,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
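  // Matches the subgraph computing
  //   t = input - MEAN(input)
  //   output = t * RSQRT(MEAN(t * t) + bias)
  // so it can be fused into a single MeanStdDevNormalization operation.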
  Node* first_mean_node = graph.GetNode(first_node_id);
  RETURN_IF_ERROR(CheckIfValidNodeOfType(first_mean_node, OperationType::MEAN));
  auto first_mean_attr =
      absl::any_cast<MeanAttributes>(first_mean_node->operation.attributes);
  if (first_mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  Node* sub_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *first_mean_node, OperationType::SUB,
                                    &sub_node));
  auto sub_inputs = graph.FindInputs(sub_node->id);
  if (sub_inputs.size() != 2) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  } else {
    // checking structure
    //       input
    //       /    \
    //      |    mean
    //       \    /
    //     subtraction
    Node* sub_first_parent = graph.FindProducer(sub_inputs[0]->id);
    Node* sub_second_parent = graph.FindProducer(sub_inputs[1]->id);
    if (sub_second_parent != first_mean_node) {
      return absl::NotFoundError("MeanStdDevNormalization not suitable.");
    }
    auto mean_inputs = graph.FindInputs(first_mean_node->id);
    Node* mean_parent = graph.FindProducer(mean_inputs[0]->id);
    if (mean_parent != sub_first_parent) {
      return absl::NotFoundError("MeanStdDevNormalization not suitable.");
    }
  }
  auto sub_output = graph.FindOutputs(sub_node->id)[0]->id;
  auto consumers = graph.FindConsumers(sub_output);
  if (consumers.size() != 2) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  Node* square_node = consumers[0];
  Node* sub_child_mul_node = consumers[1];
  if (!CheckIfValidNodeOfType(square_node, OperationType::SQUARE).ok()) {
    square_node = consumers[1];
    sub_child_mul_node = consumers[0];
  }
  RETURN_IF_ERROR(CheckIfValidNodeOfType(square_node, OperationType::SQUARE));
  RETURN_IF_ERROR(
      CheckIfValidNodeOfType(sub_child_mul_node, OperationType::MUL));
  Node* second_mean_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *square_node, OperationType::MEAN,
                                    &second_mean_node));
  auto second_mean_attr =
      absl::any_cast<MeanAttributes>(second_mean_node->operation.attributes);
  if (second_mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  Node* add_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *second_mean_node,
                                    OperationType::ADD, &add_node));
  float add_value;
  RETURN_IF_ERROR(GetElementwiseScalarValue(add_node, &add_value));
  Node* rsqrt_node;
  RETURN_IF_ERROR(
      GetNextSingleNode(graph, *add_node, OperationType::RSQRT, &rsqrt_node));
  Node* mul_node;
  RETURN_IF_ERROR(
      GetNextSingleNode(graph, *rsqrt_node, OperationType::MUL, &mul_node));
  if (sub_child_mul_node != mul_node) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }

  OperationDef op_def;
  op_def.precision = precision;
  auto input_id = graph.FindInputs(first_mean_node->id)[0]->id;
  auto it = tensor_descriptors.find(input_id);
  if (it != tensor_descriptors.end()) {
    op_def.src_tensors.push_back(it->second);
  }
  auto output_id = graph.FindInputs(mul_node->id)[0]->id;
  it = tensor_descriptors.find(output_id);
  if (it != tensor_descriptors.end()) {
    op_def.dst_tensors.push_back(it->second);
  }

  auto subgraph_inputs = graph.FindInputs(first_mean_node->id);
  auto subgraph_outputs = graph.FindOutputs(mul_node->id);
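  // Replace the matched nodes with one fused MeanStdDevNormalization op.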
  std::unique_ptr<GPUOperation>* gpu_op =
      InitSingleOpSubgraph(subgraph_inputs, subgraph_outputs, gpu_subgraph);
  *gpu_op =
      std::make_unique<MeanStdDevNormalization>(CreateMeanStdDevNormalization(
          op_def, gpu_info, subgraph_inputs[0]->tensor.shape, add_value,
          /*two_step*/ false));

  consumed_nodes->insert(first_mean_node->id);
  consumed_nodes->insert(sub_node->id);
  consumed_nodes->insert(square_node->id);
  consumed_nodes->insert(second_mean_node->id);
  consumed_nodes->insert(add_node->id);
  consumed_nodes->insert(rsqrt_node->id);
  consumed_nodes->insert(mul_node->id);

  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite