1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 18 19 #include <map> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/lite/delegates/gpu/common/access_type.h" 25 #include "tensorflow/lite/delegates/gpu/common/status.h" 26 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h" 27 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h" 28 #include "tensorflow/lite/delegates/gpu/common/types.h" 29 #include "tensorflow/lite/delegates/gpu/common/util.h" 30 31 namespace tflite { 32 namespace gpu { 33 namespace cl { 34 class CLArguments; 35 } 36 37 class ArgumentsBinder { 38 public: 39 virtual absl::Status SetInt(const std::string& name, int value) = 0; 40 virtual absl::Status SetFloat(const std::string& name, float value) = 0; 41 virtual absl::Status SetHalf(const std::string& name, half value) = 0; 42 virtual ~ArgumentsBinder() = default; 43 }; 44 45 class Arguments : public ArgumentsBinder { 46 public: 47 Arguments() = default; 48 ~Arguments() override = default; 49 50 // Move only 51 Arguments(Arguments&& args) = default; 52 Arguments& operator=(Arguments&& args) = default; 53 Arguments(const Arguments&) = delete; 54 Arguments& operator=(const Arguments&) = delete; 55 56 void AddFloat(const std::string& name, float value = 0.0f); 57 void AddHalf(const std::string& name, half value = half(0.0f)); 58 void AddInt(const std::string& name, int value = 0); 59 absl::Status SetInt(const std::string& name, int value) override; 60 absl::Status SetFloat(const std::string& name, float value) override; 61 absl::Status SetHalf(const std::string& name, half value) override; 62 void AddObjectRef(const std::string& name, AccessType access_type, 63 GPUObjectDescriptorPtr&& descriptor_ptr); 64 void AddObject(const std::string& name, 65 GPUObjectDescriptorPtr&& descriptor_ptr); 66 67 void RenameArgs(const std::string& postfix, std::string* code) const; 68 absl::Status Merge(Arguments&& args, const std::string& postfix, 69 const std::vector<std::string>& exception_names = {}); 70 71 absl::Status GetDescriptor(const std::string& name, 72 GPUObjectDescriptor** descriptor) const; 73 74 int GetReadTexturesCount(const GpuInfo& gpu_info) const; 75 int GetWriteTexturesCount(const GpuInfo& gpu_info) const; 76 77 void ReleaseCPURepresentation(); 78 79 void GetActiveArguments(const std::string& code); 80 81 void SetStateValueForAllObjects(const std::string& key, 82 const std::string& value); 83 84 struct IntValue { 85 int value; 86 87 // many uniforms generated automatically and not used 88 // to reduce amount of data transferred we adding this optimization 89 bool active = false; 90 }; 91 struct FloatValue { 92 float value; 93 94 // many uniforms generated automatically and not used 95 // to reduce amount of data transferred we adding this optimization 96 bool active = false; 97 }; 98 struct HalfValue { 99 half value; 100 101 // many uniforms generated automatically and not used 102 // to reduce amount of data transferred we adding this optimization 103 bool active = false; 104 }; 105 GetIntValues()106 const std::map<std::string, IntValue>& GetIntValues() const { 107 return int_values_; 108 } GetFloatValues()109 const std::map<std::string, FloatValue>& GetFloatValues() const { 110 return float_values_; 111 } GetHalfValues()112 const std::map<std::string, HalfValue>& GetHalfValues() const { 113 return half_values_; 114 } 115 GetObjectRefs()116 const std::map<std::string, GPUObjectDescriptorPtr>& GetObjectRefs() const { 117 return object_refs_; 118 } GetObjects()119 const std::map<std::string, GPUObjectDescriptorPtr>& GetObjects() const { 120 return objects_; 121 } MoveObjectRefs(std::map<std::string,GPUObjectDescriptorPtr> * result)122 void MoveObjectRefs(std::map<std::string, GPUObjectDescriptorPtr>* result) { 123 *result = std::move(object_refs_); 124 } 125 126 absl::Status Compile(const GpuInfo& gpu_info, 127 const std::map<std::string, std::string>& linkables, 128 std::string* code); 129 130 absl::Status ResolveConstExprPass(const GpuInfo& gpu_info, 131 std::string* code) const; 132 133 absl::Status ResolveConstExpr(const GpuInfo& gpu_info, 134 const std::string& object_name, 135 const std::string& const_expr, 136 std::string* result) const; 137 138 absl::Status ResolveSelectorsPass( 139 const GpuInfo& gpu_info, 140 const std::map<std::string, std::string>& linkables, 141 std::string* code) const; 142 143 absl::Status ResolveSelector( 144 const GpuInfo& gpu_info, 145 const std::map<std::string, std::string>& linkables, 146 const std::string& object_name, const std::string& selector, 147 const std::vector<std::string>& function_args, 148 const std::vector<std::string>& template_args, std::string* result) const; 149 150 void ResolveObjectNames(const std::string& object_name, 151 const std::vector<std::string>& member_names, 152 std::string* code) const; 153 absl::Status AddObjectsScalarArgs(const GpuInfo& gpu_info); 154 void ResolveArgsPass(std::string* code) const; 155 156 private: 157 friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode( 158 const Arguments& args, flatbuffers::FlatBufferBuilder* builder); 159 friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args, 160 Arguments* args); 161 162 absl::Status ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info, 163 std::string* code); 164 165 friend class cl::CLArguments; 166 167 static constexpr char kArgsPrefix[] = "args."; 168 169 std::map<std::string, IntValue> int_values_; 170 std::map<std::string, FloatValue> float_values_; 171 std::map<std::string, HalfValue> half_values_; 172 173 std::map<std::string, GPUObjectDescriptorPtr> object_refs_; 174 std::map<std::string, GPUObjectDescriptorPtr> objects_; 175 }; 176 177 } // namespace gpu 178 } // namespace tflite 179 180 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 181