/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise.h"

#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace {
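// Generates the kernel for the fused pointwise convolution. For each output
// location (X, Y, S) it reads two int4 offsets (four (x, y) pairs) for output
// slice S, accumulates dot products of the source slice with four weight
// vectors read from weights_tensor at those offsets, masks out padded channels
// in the last source slice, and writes the sum divided by the number of
// source channels.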
std::string GenerateCode() {
  std::string c = R"(
MAIN_FUNCTION($0) {
  int X = GLOBAL_ID_0;
  int Y = GLOBAL_ID_1;
  int S = GLOBAL_ID_2;
  if (X >= args.dst_tensor.Width() ||
      Y >= args.dst_tensor.Height() ||
      S >= args.dst_tensor.Slices()) return;
  int4 offset0 = args.offsets.Read(S * 2 + 0, 0);
  int4 offset1 = args.offsets.Read(S * 2 + 1, 0);
  ACCUM_FLT4 res = INIT_ACCUM_FLT4(0.0f);
  FLT4 last_mask;
  int last_src_ch = (args.src_tensor.Slices() - 1) * 4;
  last_mask.x = INIT_FLT(1.0f);
  last_mask.y = last_src_ch + 1 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  last_mask.z = last_src_ch + 2 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  last_mask.w = last_src_ch + 3 < args.src_tensor.Channels() ? INIT_FLT(1.0f) : INIT_FLT(0.0f);
  for (int s = 0; s < args.src_tensor.Slices(); ++s) {
    FLT4 src = args.src_tensor.Read(X, Y, s);
    FLT4 w0 = args.weights_tensor.Read(X + offset0.x, Y + offset0.y, s);
    FLT4 w1 = args.weights_tensor.Read(X + offset0.z, Y + offset0.w, s);
    FLT4 w2 = args.weights_tensor.Read(X + offset1.x, Y + offset1.y, s);
    FLT4 w3 = args.weights_tensor.Read(X + offset1.z, Y + offset1.w, s);
    FLT4 mask = INIT_FLT4(1.0f);
    if (s == (args.src_tensor.Slices() - 1)) {
      mask = last_mask;
    }
    src *= mask;
    res.x += dot(src, w0);
    res.y += dot(src, w1);
    res.z += dot(src, w2);
    res.w += dot(src, w3);
  }
  FLT4 result = TO_FLT4(res) / INIT_FLT(args.src_tensor.Channels());
  args.dst_tensor.Write(result, X, Y, S);
})";
  return c;
}

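// A graph node together with its input and output values.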
struct NodeContext {
  Node* node;
  std::vector<Value*> inputs;
  std::vector<Value*> outputs;
};

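// Checks that `node` is a non-null node of `op_type` with the expected number
// of inputs and outputs (an `inputs_count` of -1 skips the input check) and
// fills `node_context` with the node and its inputs/outputs.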
absl::Status IsNode(const GraphFloat32& graph, OperationType op_type,
                    int inputs_count, int outputs_count, Node* node,
                    NodeContext* node_context) {
  const std::string op_desc = ToString(op_type);
  node_context->node = node;
  if (node_context->node == nullptr) {
    return absl::NotFoundError(absl::StrCat("Invalid ", op_desc, " node."));
  }
  if (OperationTypeFromString(node_context->node->operation.type) != op_type) {
    return absl::InternalError(
        absl::StrCat("Incorrect node type. Expected ", op_desc, ", received ",
                     node_context->node->operation.type));
  }
  node_context->inputs = graph.FindInputs(node_context->node->id);
  node_context->outputs = graph.FindOutputs(node_context->node->id);
  if (inputs_count != -1) {
    if (node_context->inputs.size() != inputs_count) {
      return absl::InternalError(
          absl::StrCat("Expected ", inputs_count, " input(s) in a ", op_desc,
                       " node. Node has ", node_context->inputs.size()));
    }
  }
  if (node_context->outputs.size() != outputs_count) {
    return absl::InternalError(
        absl::StrCat("Expected ", outputs_count, " output(s) in a ", op_desc,
                     " node. Node has ", node_context->outputs.size()));
  }
  return absl::OkStatus();
}

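// Checks that `node` is a MEAN node that reduces over the channels axis.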
absl::Status IsMeanNode(const GraphFloat32& graph, Node* node,
                        NodeContext* node_context) {
  RETURN_IF_ERROR(IsNode(graph, OperationType::MEAN, 1, 1, node, node_context));
  auto mean_attr =
      absl::any_cast<MeanAttributes>(node_context->node->operation.attributes);
  if (mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
    return absl::InternalError("Expected mean node with channels reduction.");
  }
  return absl::OkStatus();
}

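// Checks that `node` is a MUL node whose two inputs have identical shapes.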
absl::Status IsMulNode(const GraphFloat32& graph, Node* node,
                       NodeContext* node_context) {
  RETURN_IF_ERROR(IsNode(graph, OperationType::MUL, 2, 1, node, node_context));
  if (node_context->inputs[0]->tensor.shape !=
      node_context->inputs[1]->tensor.shape) {
    return absl::InternalError(
        "Expected mul node with two inputs of equal shape.");
  }
  return absl::OkStatus();
}

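// Checks that `node` is a SLICE node with unit strides.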
absl::Status IsSliceNode(const GraphFloat32& graph, Node* node,
                         NodeContext* node_context) {
  RETURN_IF_ERROR(
      IsNode(graph, OperationType::SLICE, 1, 1, node, node_context));
  auto slice_attr =
      absl::any_cast<SliceAttributes>(node_context->node->operation.attributes);
  if (slice_attr.strides != BHWC(1, 1, 1, 1)) {
    return absl::InternalError("Invalid attributes in slice node.");
  }
  return absl::OkStatus();
}

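// Checks that `node` is a CONCAT node that concatenates along the channels
// axis.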
absl::Status IsConcatNode(const GraphFloat32& graph, Node* node,
                          NodeContext* node_context) {
  RETURN_IF_ERROR(
      IsNode(graph, OperationType::CONCAT, -1, 1, node, node_context));
  auto concat_attr = absl::any_cast<ConcatAttributes>(
      node_context->node->operation.attributes);
  if (concat_attr.axis != Axis::CHANNELS) {
    return absl::InternalError("Invalid attributes in concat node.");
  }
  return absl::OkStatus();
}

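// Walks the producer chain CONCAT input <- MEAN <- MUL <- SLICE, extracts the
// slice start (width, height) as the pointwise offset, and marks the matched
// nodes as consumed.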
absl::Status GetOffset(const GraphFloat32& graph, NodeId concat_input_node,
                       NodeId second_common_input_id, int* offset_x,
                       int* offset_y, std::set<NodeId>* consumed_nodes) {
  NodeContext mean_node, mul_node, slice_node;
  RETURN_IF_ERROR(
      IsMeanNode(graph, graph.FindProducer(concat_input_node), &mean_node));
  RETURN_IF_ERROR(
      IsMulNode(graph, graph.FindProducer(mean_node.inputs[0]->id), &mul_node));
  const ValueId slice_output_id =
      mul_node.inputs[0]->id == second_common_input_id ? mul_node.inputs[1]->id
                                                       : mul_node.inputs[0]->id;
  RETURN_IF_ERROR(
      IsSliceNode(graph, graph.FindProducer(slice_output_id), &slice_node));
  auto slice_attr =
      absl::any_cast<SliceAttributes>(slice_node.node->operation.attributes);
  *offset_x = slice_attr.starts.w;
  *offset_y = slice_attr.starts.h;
  consumed_nodes->insert(mean_node.node->id);
  consumed_nodes->insert(mul_node.node->id);
  consumed_nodes->insert(slice_node.node->id);
  return absl::OkStatus();
}

}  // namespace

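// Creates the fused pointwise convolution operation. The per-output-slice
// (x, y) offsets are packed into a constant INT32 texture, two int4 values per
// output slice and padded with the last offset, and the generated kernel code
// is attached.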
GPUOperation CreateConvPointwise(const OperationDef& definition,
                                 const ConvPointwiseAttributes& attr) {
  const int dst_channels = attr.offsets.size();
  const int dst_depth = DivideRoundUp(dst_channels, 4);
  std::vector<int32_t> offsets_data(dst_depth * 2 * 4, 0);
  for (int i = 0; i < attr.offsets.size(); ++i) {
    offsets_data[i * 2 + 0] = attr.offsets[i].x;
    offsets_data[i * 2 + 1] = attr.offsets[i].y;
  }
  for (int i = attr.offsets.size(); i < offsets_data.size() / 2; ++i) {
    offsets_data[i * 2 + 0] = attr.offsets.back().x;
    offsets_data[i * 2 + 1] = attr.offsets.back().y;
  }

  GPUOperation op(definition);
  op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  op.AddSrcTensor("weights_tensor", definition.src_tensors[1]);
  op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
  op.code_ = GenerateCode();
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor desc = CreateConstantHWVec4TensorDescriptor(
      DataType::INT32, TensorStorageType::TEXTURE_2D, dst_depth * 2, 1,
      reinterpret_cast<uint8_t*>(offsets_data.data()));
  op.args_.AddObject("offsets", std::make_unique<TensorDescriptor>(desc));
  return op;
}

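// Tries to recognize the SLICE -> MUL -> MEAN -> CONCAT subgraph that
// implements a pointwise convolution, starting from `first_node_id`. On
// success the matched nodes are added to `consumed_nodes` and a single
// ConvPointwise operation is emitted into `gpu_subgraph`.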
absl::Status TryFusedPointwiseConv(
    const GraphFloat32& graph, NodeId first_node_id,
    CalculationsPrecision precision,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
  NodeContext slice_node;
  RETURN_IF_ERROR(
      IsSliceNode(graph, graph.GetNode(first_node_id), &slice_node));
  const auto& first_common_input = slice_node.inputs[0];
  auto slice_consumers = graph.FindConsumers(slice_node.outputs[0]->id);
  if (slice_consumers.size() != 1) {
    return absl::NotFoundError("FusedPointwiseConv not suitable.");
  }
  NodeContext mul_node;
  RETURN_IF_ERROR(IsMulNode(graph, slice_consumers[0], &mul_node));
  const auto& second_common_input =
      mul_node.inputs[0]->id == slice_node.outputs[0]->id ? mul_node.inputs[1]
                                                          : mul_node.inputs[0];
  auto mul_consumers = graph.FindConsumers(mul_node.outputs[0]->id);
  if (mul_consumers.size() != 1) {
    return absl::NotFoundError("FusedPointwiseConv not suitable.");
  }
  NodeContext mean_node;
  RETURN_IF_ERROR(IsMeanNode(graph, mul_consumers[0], &mean_node));
  auto mean_consumers = graph.FindConsumers(mean_node.outputs[0]->id);
  if (mean_consumers.size() != 1) {
    return absl::NotFoundError("FusedPointwiseConv not suitable.");
  }
  NodeContext concat_node;
  RETURN_IF_ERROR(IsConcatNode(graph, mean_consumers[0], &concat_node));
  ConvPointwiseAttributes op_attr;
  std::set<NodeId> temp_consumed_nodes;
  for (const auto& concat_input : concat_node.inputs) {
    int offset_x, offset_y;
    RETURN_IF_ERROR(GetOffset(graph, concat_input->id, second_common_input->id,
                              &offset_x, &offset_y, &temp_consumed_nodes));
    op_attr.offsets.push_back(int2(offset_x, offset_y));
  }
  consumed_nodes->insert(temp_consumed_nodes.begin(),
                         temp_consumed_nodes.end());
  consumed_nodes->insert(concat_node.node->id);
  OperationDef op_def;
  op_def.precision = precision;
  auto it = tensor_descriptors.find(second_common_input->id);
  if (it != tensor_descriptors.end()) {
    op_def.src_tensors.push_back(it->second);
  }
  it = tensor_descriptors.find(first_common_input->id);
  if (it != tensor_descriptors.end()) {
    op_def.src_tensors.push_back(it->second);
  }
  it = tensor_descriptors.find(concat_node.outputs[0]->id);
  if (it != tensor_descriptors.end()) {
    op_def.dst_tensors.push_back(it->second);
  }
  std::unique_ptr<GPUOperation>* gpu_op =
      InitSingleOpSubgraph({second_common_input, first_common_input},
                           {concat_node.outputs[0]}, gpu_subgraph);
  auto operation = CreateConvPointwise(op_def, op_attr);
  *gpu_op = std::make_unique<GPUOperation>(std::move(operation));
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite