1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
17
18 #include <any>
19 #include <string>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/types/any.h"
26 #include "absl/types/variant.h"
27 #include "tensorflow/lite/delegates/gpu/common/model.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/types.h"
30 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
31
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 namespace {
36
MakeValueReplacement(int n,int k)37 std::pair<std::string, std::string> MakeValueReplacement(int n, int k) {
38 return {absl::StrCat("value_", n), absl::StrCat("value_", k)};
39 }
40
MakeDataReplacement(int n,int k)41 std::pair<std::string, std::string> MakeDataReplacement(int n, int k) {
42 return {absl::StrCat("input_data_", n), absl::StrCat("input_data_", k)};
43 }
44
45 } // namespace
46
ApplyToNode(Node * node,GraphFloat32 * graph)47 TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
48 auto& node_attr =
49 std::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
50 auto& node_code = node_attr.code;
51
52 if (node_code.input != IOStructure::AUTO) {
53 return {TransformStatus::SKIPPED, ""};
54 }
55 uint3 workgroup = node_code.workgroup;
56
57 auto node_outputs = graph->FindOutputs(node->id);
58
59 // Check which inputs could be fused into the current node.
60 std::vector<std::pair<Node*, int>> nodes_to_fuse;
61 std::vector<std::pair<ValueId, int>> input_values;
62 int input_num = -1;
63 for (auto input_value : graph->FindInputs(node->id)) {
64 input_num++;
65 const ValueId input_id = input_value->id;
66 input_values.push_back({input_id, input_num});
67
68 if (graph->FindConsumers(input_id).size() > 1) {
69 continue; // input is consumed by >1 nodes
70 }
71 Node* input_producer = graph->FindProducer(input_id);
72 if (input_producer == nullptr) {
73 continue; // graph's input
74 }
75 if (graph->FindOutputs(input_producer->id).size() != 1) {
76 continue; // input node has more than one output
77 }
78 auto& input_producer_attr = std::any_cast<const CompiledNodeAttributes&>(
79 input_producer->operation.attributes);
80 if (input_producer_attr.code.output != IOStructure::AUTO) {
81 continue;
82 }
83 if (input_producer_attr.code.workload != node_code.workload &&
84 uint3() != input_producer_attr.code.workload) {
85 continue;
86 }
87 if (input_producer_attr.code.workgroup != uint3()) {
88 // New fused node should fuse only a single shader that has pre-defined
89 // workgroup. Such shader is considered "heavy". Do not fuse two heavy
90 // shaders into one.
91 // TODO(eignasheva): make sure it still works.
92 if (workgroup != uint3()) {
93 continue;
94 }
95 workgroup = input_producer_attr.code.workgroup;
96 }
97 nodes_to_fuse.push_back({input_producer, input_num});
98 input_values.pop_back(); // this value will not be used as input.
99 }
100 if (nodes_to_fuse.empty()) {
101 return {TransformStatus::SKIPPED, ""};
102 }
103
104 // Skip fusions which will result in duplicate inputs, e.g. diamond shapes.
105 {
106 absl::flat_hash_set<ValueId> all_inputs;
107 for (const auto& node_to_fuse : nodes_to_fuse) {
108 for (const auto& input : graph->FindInputs(node_to_fuse.first->id)) {
109 if (all_inputs.find(input->id) != all_inputs.end()) {
110 return {TransformStatus::SKIPPED, ""};
111 }
112 all_inputs.insert(input->id);
113 }
114 }
115 for (const auto& input : graph->FindInputs(node->id)) {
116 if (all_inputs.find(input->id) != all_inputs.end()) {
117 return {TransformStatus::SKIPPED, ""};
118 }
119 all_inputs.insert(input->id);
120 }
121 }
122
123 // Break connections between current node and its inputs.
124 for (auto value : graph->FindInputs(node->id)) {
125 if (!graph->RemoveConsumer(node->id, value->id).ok()) {
126 return {TransformStatus::INVALID, ""};
127 }
128 }
129
130 std::string operation_type;
131 std::string source_code;
132 std::string values;
133
134 // Node source code need to be appended later to the end.
135 std::swap(source_code, node_code.source_code);
136
137 // Indicates value_k that is beyond originally declared [0..n] values,
138 // therefore, it can be used by newly added dependencies.
139 int extra_input_num = input_num;
140 input_num = 0;
141
142 // Fuse all nodes into one.
143 for (auto input_and_num : nodes_to_fuse) {
144 auto& input = input_and_num.first;
145 auto& attr =
146 std::any_cast<CompiledNodeAttributes&>(input->operation.attributes);
147 auto super_inputs = graph->FindInputs(input->id);
148
149 // Replace all internal references in the input source code. For example:
150 // source code "value_0 = max(0, value_0);" will be rewritten into
151 // "value_2 = max(0, value_2);"
152 std::vector<std::pair<std::string, std::string>> replacements;
153 for (int i = 0; i < super_inputs.size(); ++i) {
154 // Node source code uses value_N to access output value from the fused
155 // node. Use correct reference.
156 //
157 // Here value_N does not correspond to input_N anymore. Instead it tracks
158 // value_n and input_m independently. Value_index uses an index needed
159 // for the "final" shader, while input_num preserves the order of inputs.
160 // For example:
161 // Shader A: input_0, input_1
162 // value_0 = value_0 > value_1 ? value_0 : value_1;
163 //
164 // Shader B: input_0
165 // value_0 = max(0, value_0);
166 //
167 // AddShader: input_0, input_1
168 // value_0 = value_0 + value_1;
169 //
170 // Fused shader is going to have 3 inputs: input_0 (A), input_1 (A),
171 // input_2 (B). But Shader B need to store result in value_1, because
172 // AddShader refers to it as 'value_1'. So, fused shader will look as
173 // follows:
174 //
175 // // Shader A
176 // vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z];
177 // vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z];
178 // value_0 = value_0 > value_2 ? value_0 : value_2;
179 //
180 // // Shader B
181 // vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z];
182 // value_1 = max(0, value_1);
183 //
184 // // AddShader
185 // value_0 = value_0 + value_1;
186 //
187 // output_data_0.data[gid.x, gid.y, gid.z] = value_0;
188 int value_index = i == 0 ? input_and_num.second : ++extra_input_num;
189 replacements.push_back(MakeValueReplacement(i, value_index));
190 replacements.push_back(MakeDataReplacement(i, input_num));
191
192 // Declare input values based on the input structure of the merged node.
193 // This code copies what shader_codegen would do automatically.
194 if (attr.code.input == IOStructure::AUTO) {
195 absl::StrAppend(&values, " value_", value_index, " = $input_data_",
196 input_num, "[gid.x, gid.y, gid.z]$;\n");
197 }
198
199 if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) {
200 return {TransformStatus::INVALID, ""};
201 }
202 input_num++;
203 }
204
205 // Also rename all _h and _w parameters to the new names.
206 for (auto& param : attr.code.parameters) {
207 param.name = absl::StrReplaceAll(param.name, replacements);
208 }
209 attr.code.source_code =
210 absl::StrReplaceAll(attr.code.source_code, replacements);
211
212 // Merge all objects, parameters and source code.
213 if (!MergeCode(&attr, &node_attr).ok()) {
214 return {TransformStatus::INVALID, "Unable to merge the code"};
215 }
216 absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code,
217 "\n}");
218
219 if (!operation_type.empty()) {
220 operation_type += ",";
221 }
222 operation_type += input->operation.type;
223
224 if (!graph->DeleteNode(input->id).ok()) {
225 return {TransformStatus::INVALID, ""};
226 }
227 }
228
229 // Add back all inputs that are used directly by the fused node.
230 for (int i = 0; i < input_values.size(); i++) {
231 if (node_code.input == IOStructure::AUTO) {
232 absl::StrAppend(&values, " value_", input_values[i].second,
233 " = $input_data_", input_num,
234 "[gid.x, gid.y, gid.z]$;\n");
235 }
236 if (!graph->AddConsumer(node->id, input_values[i].first).ok()) {
237 return {TransformStatus::INVALID, ""};
238 }
239 input_num++;
240 }
241
242 node_code.input = IOStructure::ONLY_DEFINITIONS;
243
244 absl::StrAppend(&node->operation.type, "(", operation_type, ")");
245 node_code.source_code =
246 absl::StrCat(values, node_code.source_code, "{//FUSED",
247 node->operation.type, "\n", source_code, "\n}");
248
249 return {TransformStatus::APPLIED, ""};
250 }
251
252 } // namespace gl
253 } // namespace gpu
254 } // namespace tflite
255