xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/model_transformer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
17 
18 #include <deque>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/lite/delegates/gpu/common/model.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
Apply(const std::string & name,SequenceTransformation * transformation)29 bool ModelTransformer::Apply(const std::string& name,
30                              SequenceTransformation* transformation) {
31   // Seed transformations with starting node. Each node may start a chain of
32   // transformations.
33   for (auto input : graph_->inputs()) {
34     for (auto node : graph_->FindConsumers(input->id)) {
35       AddNodeToProcess(node);
36     }
37   }
38   while (!to_process_.empty()) {
39     auto node = graph_->GetNode(to_process_.front());
40     if (node) {
41       if (!ApplyStartingWithNode(name, transformation, node)) {
42         return false;
43       }
44     }
45     to_process_.pop_front();
46   }
47   processed_.clear();
48   return true;
49 }
50 
Apply(const std::string & name,NodeTransformation * transformation)51 bool ModelTransformer::Apply(const std::string& name,
52                              NodeTransformation* transformation) {
53   // Apply a transformation only to nodes that are present in the graph before
54   // transformation.
55   std::vector<NodeId> nodes;
56   for (auto node : graph_->nodes()) {
57     nodes.push_back(node->id);
58   }
59   for (auto node_id : nodes) {
60     auto node = graph_->GetNode(node_id);
61     if (!node) {
62       continue;
63     }
64     auto result = transformation->ApplyToNode(node, graph_);
65     last_transformation_message_ = result.message;
66     if (result.status == TransformStatus::INVALID) {
67       return false;
68     }
69   }
70   return true;
71 }
72 
last_transformation_message() const73 const std::string& ModelTransformer::last_transformation_message() const {
74   return last_transformation_message_;
75 }
76 
ApplyStartingWithNode(const std::string & name,SequenceTransformation * transformation,Node * begin)77 bool ModelTransformer::ApplyStartingWithNode(
78     const std::string& name, SequenceTransformation* transformation,
79     Node* begin) {
80   int expected_sequence_length = transformation->ExpectedSequenceLength();
81 
82   std::deque<NodeId> sequence;
83   std::vector<Node*> nodes;
84   nodes.reserve(transformation->ExpectedSequenceLength());
85   sequence.push_back(begin->id);
86 
87   // Go over nodes with sequence sliding window of size
88   // expected_sequence_length until a node with multiple dependents is found.
89   while (true) {
90     // Apply transformation if possible.
91     if (sequence.size() == expected_sequence_length) {
92       nodes.clear();
93       for (NodeId id : sequence) {
94         // Nodes present in sequence should be present in a graph. If they are
95         // not, then this transformation changes a graph but didn't say it.
96         Node* node = graph_->GetNode(id);
97         if (node == nullptr) {
98           return false;
99         }
100         nodes.push_back(node);
101       }
102 
103       NodeId first_in_sequence = sequence.front();
104       auto preceding_node =
105           graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
106       auto result = transformation->ApplyToNodesSequence(nodes, graph_);
107       last_transformation_message_ = result.message;
108       if (result.status == TransformStatus::INVALID) {
109         // graph is broken now.
110         return false;
111       }
112       if (result.status == TransformStatus::APPLIED) {
113         // Also remove first node of a sequence from a set of processed node.
114         // Out of all nodes in a sequence only first one may have been added
115         // to "processed" set because other nodes do not have more than one
116         // dependent. However, if a sequence is changed, then processing needs
117         // to be restarted again.
118         processed_.erase(first_in_sequence);
119         // Transformation was successful. Restart sequence from the node that
120         // precedes current sequence.
121         if (preceding_node) {
122           processed_.erase(preceding_node->id);
123           AddNodeToProcess(preceding_node);
124         } else {
125           // This is the first node in the graph. Re-seed transformation.
126           for (auto input : graph_->inputs()) {
127             for (auto node : graph_->FindConsumers(input->id)) {
128               AddNodeToProcess(node);
129             }
130           }
131         }
132         return true;
133       }
134     }
135 
136     // Try to extend current sequence.
137     Node* next_node_in_sequence = nullptr;
138     bool has_multiple_children = false;
139 
140     // Check that all outputs from last node are consumed by a single node.
141     for (auto output_value : graph_->FindOutputs(sequence.back())) {
142       for (auto dependent : graph_->FindConsumers(output_value->id)) {
143         if (has_multiple_children) {
144           AddNodeToProcess(dependent);
145         } else if (next_node_in_sequence == nullptr) {
146           next_node_in_sequence = dependent;
147         } else if (next_node_in_sequence != dependent) {
148           // There are more than two nodes depend on the output from end node,
149           // therefore here a sequence stops and new will start. Push all such
150           // nodes.
151           has_multiple_children = true;
152           AddNodeToProcess(dependent);
153           AddNodeToProcess(next_node_in_sequence);
154         }
155       }
156     }
157 
158     // Now check that next node has inputs only produced by the last node.
159     if (!has_multiple_children && next_node_in_sequence) {
160       for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
161         auto producer = graph_->FindProducer(input->id);
162         if (producer == nullptr || producer->id != sequence.back()) {
163           has_multiple_children = true;
164           AddNodeToProcess(next_node_in_sequence);
165           break;
166         }
167       }
168     }
169 
170     if (has_multiple_children || next_node_in_sequence == nullptr) {
171       // reached end of this transformation sequence.
172       return true;
173     }
174 
175     sequence.push_back(next_node_in_sequence->id);
176     // Decrease sequence until it matches expected length.
177     if (sequence.size() > expected_sequence_length) {
178       sequence.pop_front();
179     }
180   }
181   return true;
182 }
183 
184 }  // namespace gpu
185 }  // namespace tflite
186