xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/model_transformer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
18 
19 #include <deque>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 
27 namespace tflite {
28 namespace gpu {
29 
30 struct TransformationContext {
31   GraphFloat32* graph;
32 };
33 
// Outcome category of a single transformation attempt.
enum class TransformStatus {
  // The transformation was not applied because a trivial precondition did not
  // match.
  //
  // This differs from DECLINED below, which is used when a transformation
  // could have been applied but was not because of some issue, and which comes
  // with an in-depth explanation.
  SKIPPED,

  // The transformation was declined; the model was not modified.
  DECLINED,

  // The transformation was applied successfully.
  APPLIED,

  // The transformation may have been partially applied and left the model in
  // an invalid state. This error should be considered unrecoverable.
  INVALID,
};

// Result of a transformation attempt: a status plus a free-form message
// describing it.
struct TransformResult {
  TransformStatus status;
  std::string message;

  // Two results are equal when both the status and the message match exactly.
  bool operator==(const TransformResult& result) const {
    return status == result.status && message == result.message;
  }

  // C++17 does not synthesize operator!= from operator==, so provide it
  // explicitly to keep the comparison operators symmetric.
  bool operator!=(const TransformResult& result) const {
    return !(*this == result);
  }
};
60 
61 // Class responsible for applying a transformation to a single node.
62 class NodeTransformation {
63  public:
64   virtual ~NodeTransformation() = default;
65 
66   virtual TransformResult ApplyToNode(Node* node, GraphFloat32* graph) = 0;
67 };
68 
69 // Class responsible for applying a transformation to a sequence of nodes.
70 // Nodes are guaranteed to depend on each other without extra dependents being
71 // spilled.
72 class SequenceTransformation {
73  public:
74   virtual ~SequenceTransformation() = default;
75 
76   // @return number of nodes in a sequence to apply this transformation.
77   virtual int ExpectedSequenceLength() const = 0;
78 
79   // Applies transformations to a sequence of nodes. Transformation
80   // implementation is free manipulate with sequence nodes including adding
81   // and/or deleting nodes. if there were updates to nodes in the end and/or
82   // beginning of the sequence, then referential consistency should be
83   // maintained by updating relevant references in nodes that precede this
84   // sequence or depend on a last node of the sequence.
85   virtual TransformResult ApplyToNodesSequence(
86       const std::vector<Node*>& sequence, GraphFloat32* graph) = 0;
87 };
88 
89 // Performs model transformations.
90 class ModelTransformer {
91  public:
ModelTransformer(GraphFloat32 * graph)92   explicit ModelTransformer(GraphFloat32* graph) : graph_(graph) {}
93 
94   // @return false if a graph is in the broken states can not be used any more
95   bool Apply(const std::string& name, SequenceTransformation* transformation);
96 
97   // @return false if a graph is in the broken states can not be used any more
98   bool Apply(const std::string& name, NodeTransformation* transformation);
99 
100   // @return last recorded error for graph transformations.
101   const std::string& last_transformation_message() const;
102 
103  private:
104   bool ApplyStartingWithNode(const std::string& name,
105                              SequenceTransformation* transformation,
106                              Node* begin);
107 
AddNodeToProcess(Node * node)108   void AddNodeToProcess(Node* node) {
109     if (node && processed_.insert(node->id).second) {
110       to_process_.push_back(node->id);
111     }
112   }
113 
114   GraphFloat32* graph_;
115 
116   // TODO(b/163423950): Clean up messaging mechanism.
117   std::string last_transformation_message_;
118   std::deque<NodeId> to_process_;
119   absl::flat_hash_set<NodeId> processed_;
120 };
121 
122 }  // namespace gpu
123 }  // namespace tflite
124 
125 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
126