xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/metal/metal_arguments.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
16 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
17 
18 #import <Metal/Metal.h>
19 
20 #include <map>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
27 #include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
28 #include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
29 
30 namespace tflite {
31 namespace gpu {
32 namespace metal {
33 
34 class MetalArguments : public ArgumentsBinder {
35  public:
36   MetalArguments() = default;
37 
38   absl::Status Init(bool use_arguments_buffer, MetalDevice* device,
39                     Arguments* args, std::string* code);
40 
41   absl::Status Init(bool use_arguments_buffer, MetalDevice* device,
42                     Arguments* args);
43 
44   // Move only
45   MetalArguments(MetalArguments&& args) = default;
46   MetalArguments& operator=(MetalArguments&& args) = default;
47   MetalArguments(const MetalArguments&) = delete;
48   MetalArguments& operator=(const MetalArguments&) = delete;
49 
50   absl::Status SetInt(const std::string& name, int value) override;
51   absl::Status SetFloat(const std::string& name, float value) override;
52   absl::Status SetHalf(const std::string& name, half value) override;
53   absl::Status SetObjectRef(const std::string& name, const GPUObject& object);
54 
55   void Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset,
56               int texture_offset = 0) const;
57 
58   // For usage with Argument Buffers
59   API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
60   void AddResourcesToEncoder(id<MTLComputeCommandEncoder> encoder) const;
61   API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
62   void EncodeArguments(id<MTLArgumentEncoder> arguments_encoder);
63 
64  private:
65   // creates structure with layout:
66   // struct uniforms_buffer {
67   //   int val_0;
68   //   int val_1;
69   //   float val_2;
70   //   int dummy;  // for alignment
71   // };
72   std::string CopyScalarArgumentsToStructWithScalarFields(
73       const Arguments& args, const std::string& call_prefix = "",
74       std::string* code = nullptr);
75 
76   // creates structure with layout:
77   // struct uniforms_buffer {
78   //   int4 val_0_val_1_dummy_dummy;
79   //   float4 val_2_dummy_dummy_dummy;
80   // };
81   std::string CopyScalarArgumentsToStructWithVec4Fields(
82       const Arguments& args, const std::string& call_prefix = "",
83       std::string* code = nullptr);
84 
85   absl::Status AllocateObjects(const Arguments& args, id<MTLDevice> device);
86   absl::Status AddObjectArgs(const GpuInfo& gpu_info, const Arguments& args);
87 
88   void AddGPUResources(const std::string& name, const GPUResources& resources);
89 
90   std::string GetListOfArgs(int buffer_offset, int textures_offset = 0);
91 
92   std::string GetArgumentBufferStructDefinition(bool add_constants_struct);
93 
94   absl::Status SetGPUResources(const std::string& name,
95                                const GPUResourcesWithValue& resources);
96 
97   void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
98   void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
99   void AddImage2DArray(const std::string& name,
100                        const GPUImage2DArrayDescriptor& desc);
101   void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
102   void AddImageBuffer(const std::string& name,
103                       const GPUImageBufferDescriptor& desc);
104 
105   absl::Status SetBuffer(const std::string& name, id<MTLBuffer> handle,
106                          uint64_t offset);
107   absl::Status SetImage2D(const std::string& name, id<MTLTexture> handle);
108   absl::Status SetImage2DArray(const std::string& name, id<MTLTexture> handle);
109   absl::Status SetImage3D(const std::string& name, id<MTLTexture> handle);
110   absl::Status SetImageBuffer(const std::string& name, id<MTLTexture> handle);
111 
112   absl::Status SetObjectsResources(const Arguments& args);
113 
114   static constexpr char kArgsPrefix[] = "args.";
115   struct IntValue {
116     int value;
117 
118     // many arguments generated automatically and not used
119     // to reduce amount of data transferred we adding this optimization
120     bool active = false;
121 
122     // offset to shared storage.
123     uint32_t bytes_offset = -1;
124   };
125   std::map<std::string, IntValue> int_values_;
126 
127   struct FloatValue {
128     float value;
129 
130     // many arguments generated automatically and not used
131     // to reduce amount of data transferred we adding this optimization
132     bool active = false;
133 
134     // offset to shared storage.
135     uint32_t bytes_offset = -1;
136   };
137   std::map<std::string, FloatValue> float_values_;
138   std::vector<uint8_t> const_data_;
139 
140   struct MetalBufferDescriptor {
141     GPUBufferDescriptor desc;
142     id<MTLBuffer> handle;
143     uint64_t offset;
144   };
145   struct MetalImage2DDescriptor {
146     GPUImage2DDescriptor desc;
147     id<MTLTexture> handle;
148   };
149   struct MetalImage2DArrayDescriptor {
150     GPUImage2DArrayDescriptor desc;
151     id<MTLTexture> handle;
152   };
153   struct MetalImage3DDescriptor {
154     GPUImage3DDescriptor desc;
155     id<MTLTexture> handle;
156   };
157   struct MetalImageBufferDescriptor {
158     GPUImageBufferDescriptor desc;
159     id<MTLTexture> handle;
160   };
161 
162   std::map<std::string, MetalBufferDescriptor> buffers_;
163   std::map<std::string, MetalImage2DDescriptor> images2d_;
164   std::map<std::string, MetalImage2DArrayDescriptor> image2d_arrays_;
165   std::map<std::string, MetalImage3DDescriptor> images3d_;
166   std::map<std::string, MetalImageBufferDescriptor> image_buffers_;
167 
168   std::map<std::string, GPUObjectDescriptorPtr> object_refs_;
169   std::vector<GPUObjectPtr> objects_;
170 };
171 
172 }  // namespace metal
173 }  // namespace gpu
174 }  // namespace tflite
175 
176 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
177