1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
17
18 #include <deque>
19 #include <string>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/lite/delegates/gpu/common/model.h"
25
26 namespace tflite {
27 namespace gpu {
28
Apply(const std::string & name,SequenceTransformation * transformation)29 bool ModelTransformer::Apply(const std::string& name,
30 SequenceTransformation* transformation) {
31 // Seed transformations with starting node. Each node may start a chain of
32 // transformations.
33 for (auto input : graph_->inputs()) {
34 for (auto node : graph_->FindConsumers(input->id)) {
35 AddNodeToProcess(node);
36 }
37 }
38 while (!to_process_.empty()) {
39 auto node = graph_->GetNode(to_process_.front());
40 if (node) {
41 if (!ApplyStartingWithNode(name, transformation, node)) {
42 return false;
43 }
44 }
45 to_process_.pop_front();
46 }
47 processed_.clear();
48 return true;
49 }
50
Apply(const std::string & name,NodeTransformation * transformation)51 bool ModelTransformer::Apply(const std::string& name,
52 NodeTransformation* transformation) {
53 // Apply a transformation only to nodes that are present in the graph before
54 // transformation.
55 std::vector<NodeId> nodes;
56 for (auto node : graph_->nodes()) {
57 nodes.push_back(node->id);
58 }
59 for (auto node_id : nodes) {
60 auto node = graph_->GetNode(node_id);
61 if (!node) {
62 continue;
63 }
64 auto result = transformation->ApplyToNode(node, graph_);
65 last_transformation_message_ = result.message;
66 if (result.status == TransformStatus::INVALID) {
67 return false;
68 }
69 }
70 return true;
71 }
72
last_transformation_message() const73 const std::string& ModelTransformer::last_transformation_message() const {
74 return last_transformation_message_;
75 }
76
ApplyStartingWithNode(const std::string & name,SequenceTransformation * transformation,Node * begin)77 bool ModelTransformer::ApplyStartingWithNode(
78 const std::string& name, SequenceTransformation* transformation,
79 Node* begin) {
80 int expected_sequence_length = transformation->ExpectedSequenceLength();
81
82 std::deque<NodeId> sequence;
83 std::vector<Node*> nodes;
84 nodes.reserve(transformation->ExpectedSequenceLength());
85 sequence.push_back(begin->id);
86
87 // Go over nodes with sequence sliding window of size
88 // expected_sequence_length until a node with multiple dependents is found.
89 while (true) {
90 // Apply transformation if possible.
91 if (sequence.size() == expected_sequence_length) {
92 nodes.clear();
93 for (NodeId id : sequence) {
94 // Nodes present in sequence should be present in a graph. If they are
95 // not, then this transformation changes a graph but didn't say it.
96 Node* node = graph_->GetNode(id);
97 if (node == nullptr) {
98 return false;
99 }
100 nodes.push_back(node);
101 }
102
103 NodeId first_in_sequence = sequence.front();
104 auto preceding_node =
105 graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
106 auto result = transformation->ApplyToNodesSequence(nodes, graph_);
107 last_transformation_message_ = result.message;
108 if (result.status == TransformStatus::INVALID) {
109 // graph is broken now.
110 return false;
111 }
112 if (result.status == TransformStatus::APPLIED) {
113 // Also remove first node of a sequence from a set of processed node.
114 // Out of all nodes in a sequence only first one may have been added
115 // to "processed" set because other nodes do not have more than one
116 // dependent. However, if a sequence is changed, then processing needs
117 // to be restarted again.
118 processed_.erase(first_in_sequence);
119 // Transformation was successful. Restart sequence from the node that
120 // precedes current sequence.
121 if (preceding_node) {
122 processed_.erase(preceding_node->id);
123 AddNodeToProcess(preceding_node);
124 } else {
125 // This is the first node in the graph. Re-seed transformation.
126 for (auto input : graph_->inputs()) {
127 for (auto node : graph_->FindConsumers(input->id)) {
128 AddNodeToProcess(node);
129 }
130 }
131 }
132 return true;
133 }
134 }
135
136 // Try to extend current sequence.
137 Node* next_node_in_sequence = nullptr;
138 bool has_multiple_children = false;
139
140 // Check that all outputs from last node are consumed by a single node.
141 for (auto output_value : graph_->FindOutputs(sequence.back())) {
142 for (auto dependent : graph_->FindConsumers(output_value->id)) {
143 if (has_multiple_children) {
144 AddNodeToProcess(dependent);
145 } else if (next_node_in_sequence == nullptr) {
146 next_node_in_sequence = dependent;
147 } else if (next_node_in_sequence != dependent) {
148 // There are more than two nodes depend on the output from end node,
149 // therefore here a sequence stops and new will start. Push all such
150 // nodes.
151 has_multiple_children = true;
152 AddNodeToProcess(dependent);
153 AddNodeToProcess(next_node_in_sequence);
154 }
155 }
156 }
157
158 // Now check that next node has inputs only produced by the last node.
159 if (!has_multiple_children && next_node_in_sequence) {
160 for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
161 auto producer = graph_->FindProducer(input->id);
162 if (producer == nullptr || producer->id != sequence.back()) {
163 has_multiple_children = true;
164 AddNodeToProcess(next_node_in_sequence);
165 break;
166 }
167 }
168 }
169
170 if (has_multiple_children || next_node_in_sequence == nullptr) {
171 // reached end of this transformation sequence.
172 return true;
173 }
174
175 sequence.push_back(next_node_in_sequence->id);
176 // Decrease sequence until it matches expected length.
177 if (sequence.size() > expected_sequence_length) {
178 sequence.pop_front();
179 }
180 }
181 return true;
182 }
183
184 } // namespace gpu
185 } // namespace tflite
186