xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/task/arguments.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
18 
19 #include <map>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/lite/delegates/gpu/common/access_type.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
27 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 #include "tensorflow/lite/delegates/gpu/common/util.h"
30 
31 namespace tflite {
32 namespace gpu {
namespace cl {
// Forward declaration only; the full definition lives elsewhere. It is needed
// here because Arguments declares cl::CLArguments as a friend.
class CLArguments;
}  // namespace cl
36 
37 class ArgumentsBinder {
38  public:
39   virtual absl::Status SetInt(const std::string& name, int value) = 0;
40   virtual absl::Status SetFloat(const std::string& name, float value) = 0;
41   virtual absl::Status SetHalf(const std::string& name, half value) = 0;
42   virtual ~ArgumentsBinder() = default;
43 };
44 
45 class Arguments : public ArgumentsBinder {
46  public:
47   Arguments() = default;
48   ~Arguments() override = default;
49 
50   // Move only
51   Arguments(Arguments&& args) = default;
52   Arguments& operator=(Arguments&& args) = default;
53   Arguments(const Arguments&) = delete;
54   Arguments& operator=(const Arguments&) = delete;
55 
56   void AddFloat(const std::string& name, float value = 0.0f);
57   void AddHalf(const std::string& name, half value = half(0.0f));
58   void AddInt(const std::string& name, int value = 0);
59   absl::Status SetInt(const std::string& name, int value) override;
60   absl::Status SetFloat(const std::string& name, float value) override;
61   absl::Status SetHalf(const std::string& name, half value) override;
62   void AddObjectRef(const std::string& name, AccessType access_type,
63                     GPUObjectDescriptorPtr&& descriptor_ptr);
64   void AddObject(const std::string& name,
65                  GPUObjectDescriptorPtr&& descriptor_ptr);
66 
67   void RenameArgs(const std::string& postfix, std::string* code) const;
68   absl::Status Merge(Arguments&& args, const std::string& postfix,
69                      const std::vector<std::string>& exception_names = {});
70 
71   absl::Status GetDescriptor(const std::string& name,
72                              GPUObjectDescriptor** descriptor) const;
73 
74   int GetReadTexturesCount(const GpuInfo& gpu_info) const;
75   int GetWriteTexturesCount(const GpuInfo& gpu_info) const;
76 
77   void ReleaseCPURepresentation();
78 
79   void GetActiveArguments(const std::string& code);
80 
81   void SetStateValueForAllObjects(const std::string& key,
82                                   const std::string& value);
83 
84   struct IntValue {
85     int value;
86 
87     // many uniforms generated automatically and not used
88     // to reduce amount of data transferred we adding this optimization
89     bool active = false;
90   };
91   struct FloatValue {
92     float value;
93 
94     // many uniforms generated automatically and not used
95     // to reduce amount of data transferred we adding this optimization
96     bool active = false;
97   };
98   struct HalfValue {
99     half value;
100 
101     // many uniforms generated automatically and not used
102     // to reduce amount of data transferred we adding this optimization
103     bool active = false;
104   };
105 
GetIntValues()106   const std::map<std::string, IntValue>& GetIntValues() const {
107     return int_values_;
108   }
GetFloatValues()109   const std::map<std::string, FloatValue>& GetFloatValues() const {
110     return float_values_;
111   }
GetHalfValues()112   const std::map<std::string, HalfValue>& GetHalfValues() const {
113     return half_values_;
114   }
115 
GetObjectRefs()116   const std::map<std::string, GPUObjectDescriptorPtr>& GetObjectRefs() const {
117     return object_refs_;
118   }
GetObjects()119   const std::map<std::string, GPUObjectDescriptorPtr>& GetObjects() const {
120     return objects_;
121   }
MoveObjectRefs(std::map<std::string,GPUObjectDescriptorPtr> * result)122   void MoveObjectRefs(std::map<std::string, GPUObjectDescriptorPtr>* result) {
123     *result = std::move(object_refs_);
124   }
125 
126   absl::Status Compile(const GpuInfo& gpu_info,
127                        const std::map<std::string, std::string>& linkables,
128                        std::string* code);
129 
130   absl::Status ResolveConstExprPass(const GpuInfo& gpu_info,
131                                     std::string* code) const;
132 
133   absl::Status ResolveConstExpr(const GpuInfo& gpu_info,
134                                 const std::string& object_name,
135                                 const std::string& const_expr,
136                                 std::string* result) const;
137 
138   absl::Status ResolveSelectorsPass(
139       const GpuInfo& gpu_info,
140       const std::map<std::string, std::string>& linkables,
141       std::string* code) const;
142 
143   absl::Status ResolveSelector(
144       const GpuInfo& gpu_info,
145       const std::map<std::string, std::string>& linkables,
146       const std::string& object_name, const std::string& selector,
147       const std::vector<std::string>& function_args,
148       const std::vector<std::string>& template_args, std::string* result) const;
149 
150   void ResolveObjectNames(const std::string& object_name,
151                           const std::vector<std::string>& member_names,
152                           std::string* code) const;
153   absl::Status AddObjectsScalarArgs(const GpuInfo& gpu_info);
154   void ResolveArgsPass(std::string* code) const;
155 
156  private:
157   friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
158       const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
159   friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
160                              Arguments* args);
161 
162   absl::Status ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
163                                                std::string* code);
164 
165   friend class cl::CLArguments;
166 
167   static constexpr char kArgsPrefix[] = "args.";
168 
169   std::map<std::string, IntValue> int_values_;
170   std::map<std::string, FloatValue> float_values_;
171   std::map<std::string, HalfValue> half_values_;
172 
173   std::map<std::string, GPUObjectDescriptorPtr> object_refs_;
174   std::map<std::string, GPUObjectDescriptorPtr> objects_;
175 };
176 
177 }  // namespace gpu
178 }  // namespace tflite
179 
180 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
181