xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise.h"
17 
18 #include <cstdint>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "tensorflow/lite/delegates/gpu/common/operations.h"
28 #include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
29 #include "tensorflow/lite/delegates/gpu/common/util.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace {
// Returns the GPU kernel source for the fused pointwise operation.
//
// Each work item (X, Y, S) writes one FLT4 of dst_tensor (four output
// channels). The four (x, y) weight offsets for that slice are fetched as two
// int4 texels from `offsets` (texels S * 2 and S * 2 + 1, each holding two
// offset pairs). For every source slice, the kernel accumulates
// dot(src, weights_tensor[X + dx, Y + dy]) per output channel and finally
// divides by src_tensor.Channels(). In the last source slice, lanes past
// src_tensor.Channels() are zeroed on the src side before the dot products.
std::string GenerateCode() {
  return R"(
MAIN_FUNCTION($0) {
  int X = GLOBAL_ID_0;
  int Y = GLOBAL_ID_1;
  int S = GLOBAL_ID_2;
  if (X >= args.dst_tensor.Width() ||
      Y >= args.dst_tensor.Height() ||
      S >= args.dst_tensor.Slices()) return;
  int4 offset0 = args.offsets.Read(S * 2 + 0, 0);
  int4 offset1 = args.offsets.Read(S * 2 + 1, 0);
  ACCUM_FLT4 res = INIT_ACCUM_FLT4(0.0f);
  FLT4 last_mask;
  int last_src_ch = (args.src_tensor.Slices() - 1) * 4;
  last_mask.x = INIT_FLT(1.0f);
  last_mask.y = last_src_ch + 1 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  last_mask.z = last_src_ch + 2 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  last_mask.w = last_src_ch + 3 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  for (int s = 0; s < args.src_tensor.Slices(); ++s) {
    FLT4 src = args.src_tensor.Read(X, Y, s);
    FLT4 w0 = args.weights_tensor.Read(X + offset0.x, Y + offset0.y, s);
    FLT4 w1 = args.weights_tensor.Read(X + offset0.z, Y + offset0.w, s);
    FLT4 w2 = args.weights_tensor.Read(X + offset1.x, Y + offset1.y, s);
    FLT4 w3 = args.weights_tensor.Read(X + offset1.z, Y + offset1.w, s);
    FLT4 mask = INIT_FLT4(1.0f);
    if (s == (args.src_tensor.Slices() - 1)) {
      mask = last_mask;
    }
    src *= mask;
    res.x += dot(src, w0);
    res.y += dot(src, w1);
    res.z += dot(src, w2);
    res.w += dot(src, w3);
  }
  FLT4 result = TO_FLT4(res) / INIT_FLT(args.src_tensor.Channels());
  args.dst_tensor.Write(result, X, Y, S);
})";
}
73 
74 struct NodeContext {
75   Node* node;
76   std::vector<Value*> inputs;
77   std::vector<Value*> outputs;
78 };
79 
IsNode(const GraphFloat32 & graph,OperationType op_type,int inputs_count,int outputs_count,Node * node,NodeContext * node_context)80 absl::Status IsNode(const GraphFloat32& graph, OperationType op_type,
81                     int inputs_count, int outputs_count, Node* node,
82                     NodeContext* node_context) {
83   const std::string op_desc = ToString(op_type);
84   node_context->node = node;
85   if (node_context->node == nullptr) {
86     return absl::NotFoundError(absl::StrCat("Invalid ", op_desc, " node."));
87   }
88   if (OperationTypeFromString(node_context->node->operation.type) != op_type) {
89     return absl::InternalError(
90         absl::StrCat("Not correct node type. Expected ", op_desc, ", received ",
91                      node_context->node->operation.type));
92   }
93   node_context->inputs = graph.FindInputs(node_context->node->id);
94   node_context->outputs = graph.FindOutputs(node_context->node->id);
95   if (inputs_count != -1) {
96     if (node_context->inputs.size() != inputs_count) {
97       return absl::InternalError(
98           absl::StrCat("Expected ", inputs_count, " input in a ", op_desc,
99                        " node. Node has ", node_context->inputs.size()));
100     }
101   }
102   if (node_context->outputs.size() != outputs_count) {
103     return absl::InternalError(
104         absl::StrCat("Expected ", outputs_count, " output in a ", op_desc,
105                      " node. Node has ", node_context->outputs.size()));
106   }
107   return absl::OkStatus();
108 }
109 
IsMeanNode(const GraphFloat32 & graph,Node * node,NodeContext * node_context)110 absl::Status IsMeanNode(const GraphFloat32& graph, Node* node,
111                         NodeContext* node_context) {
112   RETURN_IF_ERROR(IsNode(graph, OperationType::MEAN, 1, 1, node, node_context));
113   auto mean_attr =
114       absl::any_cast<MeanAttributes>(node_context->node->operation.attributes);
115   if (mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
116     return absl::InternalError("Expected mean node with channels reduction.");
117   }
118   return absl::OkStatus();
119 }
120 
IsMulNode(const GraphFloat32 & graph,Node * node,NodeContext * node_context)121 absl::Status IsMulNode(const GraphFloat32& graph, Node* node,
122                        NodeContext* node_context) {
123   RETURN_IF_ERROR(IsNode(graph, OperationType::MUL, 2, 1, node, node_context));
124   if (node_context->inputs[0]->tensor.shape !=
125       node_context->inputs[1]->tensor.shape) {
126     return absl::InternalError("Expected mul node with 2 equal tensors.");
127   }
128   return absl::OkStatus();
129 }
130 
IsSliceNode(const GraphFloat32 & graph,Node * node,NodeContext * node_context)131 absl::Status IsSliceNode(const GraphFloat32& graph, Node* node,
132                          NodeContext* node_context) {
133   RETURN_IF_ERROR(
134       IsNode(graph, OperationType::SLICE, 1, 1, node, node_context));
135   auto slice_attr =
136       absl::any_cast<SliceAttributes>(node_context->node->operation.attributes);
137   if (slice_attr.strides != BHWC(1, 1, 1, 1)) {
138     return absl::InternalError("Not valid attributes in slice node.");
139   }
140   return absl::OkStatus();
141 }
142 
IsConcatNode(const GraphFloat32 & graph,Node * node,NodeContext * node_context)143 absl::Status IsConcatNode(const GraphFloat32& graph, Node* node,
144                           NodeContext* node_context) {
145   RETURN_IF_ERROR(
146       IsNode(graph, OperationType::CONCAT, -1, 1, node, node_context));
147   auto concat_attr = absl::any_cast<ConcatAttributes>(
148       node_context->node->operation.attributes);
149   if (concat_attr.axis != Axis::CHANNELS) {
150     return absl::InternalError("Not valid attributes in concat node.");
151   }
152   return absl::OkStatus();
153 }
154 
GetOffset(const GraphFloat32 & graph,NodeId concat_input_node,NodeId second_commom_input_id,int * offset_x,int * offset_y,std::set<NodeId> * consumed_nodes)155 absl::Status GetOffset(const GraphFloat32& graph, NodeId concat_input_node,
156                        NodeId second_commom_input_id, int* offset_x,
157                        int* offset_y, std::set<NodeId>* consumed_nodes) {
158   NodeContext mean_node, mul_node, slice_node;
159   RETURN_IF_ERROR(
160       IsMeanNode(graph, graph.FindProducer(concat_input_node), &mean_node));
161   RETURN_IF_ERROR(
162       IsMulNode(graph, graph.FindProducer(mean_node.inputs[0]->id), &mul_node));
163   const ValueId slice_output_id =
164       mul_node.inputs[0]->id == second_commom_input_id ? mul_node.inputs[1]->id
165                                                        : mul_node.inputs[0]->id;
166   RETURN_IF_ERROR(
167       IsSliceNode(graph, graph.FindProducer(slice_output_id), &slice_node));
168   auto slice_attr =
169       absl::any_cast<SliceAttributes>(slice_node.node->operation.attributes);
170   *offset_x = slice_attr.starts.w;
171   *offset_y = slice_attr.starts.h;
172   consumed_nodes->insert(mean_node.node->id);
173   consumed_nodes->insert(mul_node.node->id);
174   consumed_nodes->insert(slice_node.node->id);
175   return absl::OkStatus();
176 }
177 
178 }  // namespace
179 
CreateConvPointwise(const OperationDef & definition,const ConvPointwiseAttributes & attr)180 GPUOperation CreateConvPointwise(const OperationDef& definition,
181                                  const ConvPointwiseAttributes& attr) {
182   const int dst_channels = attr.offsets.size();
183   const int dst_depth = DivideRoundUp(dst_channels, 4);
184   std::vector<int32_t> offsets_data(dst_depth * 2 * 4, 0);
185   for (int i = 0; i < attr.offsets.size(); ++i) {
186     offsets_data[i * 2 + 0] = attr.offsets[i].x;
187     offsets_data[i * 2 + 1] = attr.offsets[i].y;
188   }
189   for (int i = attr.offsets.size(); i < offsets_data.size() / 2; ++i) {
190     offsets_data[i * 2 + 0] = attr.offsets.back().x;
191     offsets_data[i * 2 + 1] = attr.offsets.back().y;
192   }
193 
194   GPUOperation op(definition);
195   op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
196   op.AddSrcTensor("weights_tensor", definition.src_tensors[1]);
197   op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
198   op.code_ = GenerateCode();
199   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
200 
201   TensorDescriptor desc = CreateConstantHWVec4TensorDescriptor(
202       DataType::INT32, TensorStorageType::TEXTURE_2D, dst_depth * 2, 1,
203       reinterpret_cast<uint8_t*>(offsets_data.data()));
204   op.args_.AddObject("offsets", std::make_unique<TensorDescriptor>(desc));
205   return op;
206 }
207 
TryFusedPointwiseConv(const GraphFloat32 & graph,NodeId first_node_id,CalculationsPrecision precision,const std::map<ValueId,TensorDescriptor> & tensor_descriptors,std::set<NodeId> * consumed_nodes,GPUOperationsSubgraph * gpu_subgraph)208 absl::Status TryFusedPointwiseConv(
209     const GraphFloat32& graph, NodeId first_node_id,
210     CalculationsPrecision precision,
211     const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
212     std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
213   NodeContext slice_node;
214   RETURN_IF_ERROR(
215       IsSliceNode(graph, graph.GetNode(first_node_id), &slice_node));
216   const auto& first_commom_input = slice_node.inputs[0];
217   auto slice_consumers = graph.FindConsumers(slice_node.outputs[0]->id);
218   if (slice_consumers.size() != 1) {
219     return absl::NotFoundError("FusedPointwiseConv not suitable.");
220   }
221   NodeContext mul_node;
222   RETURN_IF_ERROR(IsMulNode(graph, slice_consumers[0], &mul_node));
223   const auto& second_commom_input =
224       mul_node.inputs[0]->id == slice_node.outputs[0]->id ? mul_node.inputs[1]
225                                                           : mul_node.inputs[0];
226   auto mul_consumers = graph.FindConsumers(mul_node.outputs[0]->id);
227   if (mul_consumers.size() != 1) {
228     return absl::NotFoundError("FusedPointwiseConv not suitable.");
229   }
230   NodeContext mean_node;
231   RETURN_IF_ERROR(IsMeanNode(graph, mul_consumers[0], &mean_node));
232   auto mean_consumers = graph.FindConsumers(mean_node.outputs[0]->id);
233   if (mean_consumers.size() != 1) {
234     return absl::NotFoundError("FusedPointwiseConv not suitable.");
235   }
236   NodeContext concat_node;
237   RETURN_IF_ERROR(IsConcatNode(graph, mean_consumers[0], &concat_node));
238   ConvPointwiseAttributes op_attr;
239   std::set<NodeId> temp_consumed_nodes;
240   for (const auto& concat_input : concat_node.inputs) {
241     int offset_x, offset_y;
242     RETURN_IF_ERROR(GetOffset(graph, concat_input->id, second_commom_input->id,
243                               &offset_x, &offset_y, &temp_consumed_nodes));
244     op_attr.offsets.push_back(int2(offset_x, offset_y));
245   }
246   consumed_nodes->insert(temp_consumed_nodes.begin(),
247                          temp_consumed_nodes.end());
248   consumed_nodes->insert(concat_node.node->id);
249   OperationDef op_def;
250   op_def.precision = precision;
251   auto it = tensor_descriptors.find(second_commom_input->id);
252   if (it != tensor_descriptors.end()) {
253     op_def.src_tensors.push_back(it->second);
254   }
255   it = tensor_descriptors.find(first_commom_input->id);
256   if (it != tensor_descriptors.end()) {
257     op_def.src_tensors.push_back(it->second);
258   }
259   it = tensor_descriptors.find(concat_node.outputs[0]->id);
260   if (it != tensor_descriptors.end()) {
261     op_def.dst_tensors.push_back(it->second);
262   }
263   std::unique_ptr<GPUOperation>* gpu_op =
264       InitSingleOpSubgraph({second_commom_input, first_commom_input},
265                            {concat_node.outputs[0]}, gpu_subgraph);
266   auto operation = CreateConvPointwise(op_def, op_attr);
267   *gpu_op = std::make_unique<GPUOperation>(std::move(operation));
268   return absl::OkStatus();
269 }
270 
271 }  // namespace gpu
272 }  // namespace tflite
273