xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
17 
18 #include <any>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/types/any.h"
26 #include "absl/types/variant.h"
27 #include "tensorflow/lite/delegates/gpu/common/model.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/types.h"
30 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
31 
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 namespace {
36 
MakeValueReplacement(int n,int k)37 std::pair<std::string, std::string> MakeValueReplacement(int n, int k) {
38   return {absl::StrCat("value_", n), absl::StrCat("value_", k)};
39 }
40 
MakeDataReplacement(int n,int k)41 std::pair<std::string, std::string> MakeDataReplacement(int n, int k) {
42   return {absl::StrCat("input_data_", n), absl::StrCat("input_data_", k)};
43 }
44 
45 }  // namespace
46 
ApplyToNode(Node * node,GraphFloat32 * graph)47 TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
48   auto& node_attr =
49       std::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
50   auto& node_code = node_attr.code;
51 
52   if (node_code.input != IOStructure::AUTO) {
53     return {TransformStatus::SKIPPED, ""};
54   }
55   uint3 workgroup = node_code.workgroup;
56 
57   auto node_outputs = graph->FindOutputs(node->id);
58 
59   // Check which inputs could be fused into the current node.
60   std::vector<std::pair<Node*, int>> nodes_to_fuse;
61   std::vector<std::pair<ValueId, int>> input_values;
62   int input_num = -1;
63   for (auto input_value : graph->FindInputs(node->id)) {
64     input_num++;
65     const ValueId input_id = input_value->id;
66     input_values.push_back({input_id, input_num});
67 
68     if (graph->FindConsumers(input_id).size() > 1) {
69       continue;  // input is consumed by >1 nodes
70     }
71     Node* input_producer = graph->FindProducer(input_id);
72     if (input_producer == nullptr) {
73       continue;  // graph's input
74     }
75     if (graph->FindOutputs(input_producer->id).size() != 1) {
76       continue;  // input node has more than one output
77     }
78     auto& input_producer_attr = std::any_cast<const CompiledNodeAttributes&>(
79         input_producer->operation.attributes);
80     if (input_producer_attr.code.output != IOStructure::AUTO) {
81       continue;
82     }
83     if (input_producer_attr.code.workload != node_code.workload &&
84         uint3() != input_producer_attr.code.workload) {
85       continue;
86     }
87     if (input_producer_attr.code.workgroup != uint3()) {
88       // New fused node should fuse only a single shader that has pre-defined
89       // workgroup. Such shader is considered "heavy". Do not fuse two heavy
90       // shaders into one.
91       // TODO(eignasheva): make sure it still works.
92       if (workgroup != uint3()) {
93         continue;
94       }
95       workgroup = input_producer_attr.code.workgroup;
96     }
97     nodes_to_fuse.push_back({input_producer, input_num});
98     input_values.pop_back();  // this value will not be used as input.
99   }
100   if (nodes_to_fuse.empty()) {
101     return {TransformStatus::SKIPPED, ""};
102   }
103 
104   // Skip fusions which will result in duplicate inputs, e.g. diamond shapes.
105   {
106     absl::flat_hash_set<ValueId> all_inputs;
107     for (const auto& node_to_fuse : nodes_to_fuse) {
108       for (const auto& input : graph->FindInputs(node_to_fuse.first->id)) {
109         if (all_inputs.find(input->id) != all_inputs.end()) {
110           return {TransformStatus::SKIPPED, ""};
111         }
112         all_inputs.insert(input->id);
113       }
114     }
115     for (const auto& input : graph->FindInputs(node->id)) {
116       if (all_inputs.find(input->id) != all_inputs.end()) {
117         return {TransformStatus::SKIPPED, ""};
118       }
119       all_inputs.insert(input->id);
120     }
121   }
122 
123   // Break connections between current node and its inputs.
124   for (auto value : graph->FindInputs(node->id)) {
125     if (!graph->RemoveConsumer(node->id, value->id).ok()) {
126       return {TransformStatus::INVALID, ""};
127     }
128   }
129 
130   std::string operation_type;
131   std::string source_code;
132   std::string values;
133 
134   // Node source code need to be appended later to the end.
135   std::swap(source_code, node_code.source_code);
136 
137   // Indicates value_k that is beyond originally declared [0..n] values,
138   // therefore, it can be used by newly added dependencies.
139   int extra_input_num = input_num;
140   input_num = 0;
141 
142   // Fuse all nodes into one.
143   for (auto input_and_num : nodes_to_fuse) {
144     auto& input = input_and_num.first;
145     auto& attr =
146         std::any_cast<CompiledNodeAttributes&>(input->operation.attributes);
147     auto super_inputs = graph->FindInputs(input->id);
148 
149     // Replace all internal references in the input source code. For example:
150     // source code "value_0 = max(0, value_0);" will be rewritten into
151     // "value_2 = max(0, value_2);"
152     std::vector<std::pair<std::string, std::string>> replacements;
153     for (int i = 0; i < super_inputs.size(); ++i) {
154       // Node source code uses value_N to access output value from the fused
155       // node. Use correct reference.
156       //
157       // Here value_N does not correspond to input_N anymore. Instead it tracks
158       // value_n and input_m independently. Value_index uses an index needed
159       // for the "final" shader, while input_num preserves the order of inputs.
160       // For example:
161       //    Shader A: input_0, input_1
162       //    value_0 = value_0 > value_1 ? value_0 : value_1;
163       //
164       //    Shader B:  input_0
165       //    value_0 = max(0, value_0);
166       //
167       //    AddShader: input_0, input_1
168       //    value_0 = value_0 + value_1;
169       //
170       //    Fused shader is going to have 3 inputs: input_0 (A), input_1 (A),
171       //    input_2 (B). But Shader B need to store result in value_1, because
172       //    AddShader refers to it as 'value_1'. So, fused shader will look as
173       //    follows:
174       //
175       //    // Shader A
176       //    vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z];
177       //    vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z];
178       //    value_0 = value_0 > value_2 ? value_0 : value_2;
179       //
180       //    // Shader B
181       //    vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z];
182       //    value_1 = max(0, value_1);
183       //
184       //    // AddShader
185       //    value_0 = value_0 + value_1;
186       //
187       //    output_data_0.data[gid.x, gid.y, gid.z] = value_0;
188       int value_index = i == 0 ? input_and_num.second : ++extra_input_num;
189       replacements.push_back(MakeValueReplacement(i, value_index));
190       replacements.push_back(MakeDataReplacement(i, input_num));
191 
192       // Declare input values based on the input structure of the merged node.
193       // This code copies what shader_codegen would do automatically.
194       if (attr.code.input == IOStructure::AUTO) {
195         absl::StrAppend(&values, "  value_", value_index, " = $input_data_",
196                         input_num, "[gid.x, gid.y, gid.z]$;\n");
197       }
198 
199       if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) {
200         return {TransformStatus::INVALID, ""};
201       }
202       input_num++;
203     }
204 
205     // Also rename all _h and _w parameters to the new names.
206     for (auto& param : attr.code.parameters) {
207       param.name = absl::StrReplaceAll(param.name, replacements);
208     }
209     attr.code.source_code =
210         absl::StrReplaceAll(attr.code.source_code, replacements);
211 
212     // Merge all objects, parameters and source code.
213     if (!MergeCode(&attr, &node_attr).ok()) {
214       return {TransformStatus::INVALID, "Unable to merge the code"};
215     }
216     absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code,
217                     "\n}");
218 
219     if (!operation_type.empty()) {
220       operation_type += ",";
221     }
222     operation_type += input->operation.type;
223 
224     if (!graph->DeleteNode(input->id).ok()) {
225       return {TransformStatus::INVALID, ""};
226     }
227   }
228 
229   // Add back all inputs that are used directly by the fused node.
230   for (int i = 0; i < input_values.size(); i++) {
231     if (node_code.input == IOStructure::AUTO) {
232       absl::StrAppend(&values, "  value_", input_values[i].second,
233                       " = $input_data_", input_num,
234                       "[gid.x, gid.y, gid.z]$;\n");
235     }
236     if (!graph->AddConsumer(node->id, input_values[i].first).ok()) {
237       return {TransformStatus::INVALID, ""};
238     }
239     input_num++;
240   }
241 
242   node_code.input = IOStructure::ONLY_DEFINITIONS;
243 
244   absl::StrAppend(&node->operation.type, "(", operation_type, ")");
245   node_code.source_code =
246       absl::StrCat(values, node_code.source_code, "{//FUSED",
247                    node->operation.type, "\n", source_code, "\n}");
248 
249   return {TransformStatus::APPLIED, ""};
250 }
251 
252 }  // namespace gl
253 }  // namespace gpu
254 }  // namespace tflite
255