xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
16 
17 #include <cstring>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25 #include "tensorflow/lite/delegates/gpu/metal/buffer.h"
26 #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
27 
28 namespace tflite {
29 namespace gpu {
30 namespace metal {
31 namespace {
IsWordSymbol(char symbol)32 bool IsWordSymbol(char symbol) {
33   return absl::ascii_isalnum(symbol) || symbol == '_';
34 }
35 
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)36 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
37                      std::string* str) {
38   if (!str) {
39     return;
40   }
41   size_t position = str->find(old_word);
42   while (position != std::string::npos) {
43     char prev = position == 0 ? '.' : (*str)[position - 1];
44     char next = position + old_word.size() < str->size()
45                     ? (*str)[position + old_word.size()]
46                     : '.';
47     if (IsWordSymbol(prev) || IsWordSymbol(next)) {
48       position = str->find(old_word, position + 1);
49       continue;
50     }
51     str->replace(position, old_word.size(), new_word);
52     position = str->find(old_word, position + new_word.size());
53   }
54 }
55 
// Appends `arg` to the argument list accumulated in `*args`. Entries are
// separated by ",\n"; no separator is emitted before the first entry.
void AppendArgument(const std::string& arg, std::string* args) {
  const bool needs_separator = !args->empty();
  if (needs_separator) {
    args->append(",\n");
  }
  args->append(arg);
}
62 
CreateMetalObject(id<MTLDevice> device,GPUObjectDescriptor * desc,GPUObjectPtr * result)63 absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
64                             GPUObjectPtr* result) {
65   const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc);
66   if (buffer_desc) {
67     Buffer gpu_buffer;
68     RETURN_IF_ERROR(
69         gpu_buffer.CreateFromBufferDescriptor(*buffer_desc, device));
70     *result = std::make_unique<Buffer>(std::move(gpu_buffer));
71     return absl::OkStatus();
72   }
73 
74   const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
75   if (tensor_desc) {
76     MetalSpatialTensor gpu_tensor;
77     RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
78     *result = std::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
79     return absl::OkStatus();
80   }
81 
82   return absl::InvalidArgumentError("Unknown GPU descriptor.");
83 }
84 
AccessToMetalTextureAccess(AccessType access_type)85 std::string AccessToMetalTextureAccess(AccessType access_type) {
86   if (access_type == AccessType::READ) {
87     return "access::read";
88   } else if (access_type == AccessType::READ_WRITE) {
89     return "access::read_write";
90   } else if (access_type == AccessType::WRITE) {
91     return "access::write";
92   } else {
93     return "access::unknown";
94   }
95 }
96 }  // namespace
97 
98 // Static
99 constexpr char MetalArguments::kArgsPrefix[];
100 
Init(bool use_arguments_buffer,MetalDevice * device,Arguments * args,std::string * code)101 absl::Status MetalArguments::Init(
102     bool use_arguments_buffer, MetalDevice* device, Arguments* args,
103     std::string* code) {
104   RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
105   RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
106   args->MoveObjectRefs(&object_refs_);
107   std::string call_prefix = use_arguments_buffer ? "args." : "";
108   std::string struct_desc =
109       CopyScalarArgumentsToStructWithVec4Fields(*args, call_prefix, code);
110   RETURN_IF_ERROR(SetObjectsResources(*args));
111   if (!use_arguments_buffer) {
112     args->ResolveArgsPass(code);
113   }
114   std::string header = R"(
115 #include <metal_stdlib>
116 using namespace metal;
117 
118 )";
119   header += struct_desc + "\n";
120   if (use_arguments_buffer) {
121     const std::string arg_buf_struct =
122         GetArgumentBufferStructDefinition(!struct_desc.empty());
123     header += arg_buf_struct + "\n";
124   }
125   *code = header + *code;
126   std::string arguments;
127   if (use_arguments_buffer) {
128     arguments = "device ArgBuffer& args[[buffer(0)]]";
129   } else {
130     arguments = GetListOfArgs(/*buffer_offset*/ 0);
131   }
132   const bool use_global_id = code->find("GLOBAL_ID_") != std::string::npos;
133   const bool use_local_id = code->find("LOCAL_ID_") != std::string::npos;
134   const bool use_group_id = code->find("GROUP_ID_") != std::string::npos;
135   const bool use_group_size = code->find("GROUP_SIZE_") != std::string::npos;
136   const bool use_simd_id =
137       code->find("SUB_GROUP_LOCAL_ID") != std::string::npos;
138   if (use_global_id) {
139     AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
140   }
141   if (use_local_id) {
142     AppendArgument("uint3 reserved_lid[[thread_position_in_threadgroup]]",
143                    &arguments);
144   }
145   if (use_group_id) {
146     AppendArgument("uint3 reserved_group_id[[threadgroup_position_in_grid]]",
147                    &arguments);
148   }
149   if (use_group_size) {
150     AppendArgument("uint3 reserved_group_size[[threads_per_threadgroup]]",
151                    &arguments);
152   }
153   if (use_simd_id) {
154     AppendArgument("uint reserved_simd_id[[thread_index_in_simdgroup]]",
155                    &arguments);
156   }
157   if (!use_global_id && !use_local_id && !use_group_id && !use_group_size &&
158       !arguments.empty()) {
159     arguments += ",\n";
160   }
161   *code = absl::Substitute(*code, arguments);
162   return absl::OkStatus();
163 }
164 
Init(bool use_arguments_buffer,MetalDevice * device,Arguments * args)165 absl::Status MetalArguments::Init(bool use_arguments_buffer,
166                                   MetalDevice* device, Arguments* args) {
167   RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
168   RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
169   args->MoveObjectRefs(&object_refs_);
170   CopyScalarArgumentsToStructWithVec4Fields(*args);
171   RETURN_IF_ERROR(SetObjectsResources(*args));
172   return absl::OkStatus();
173 }
174 
CopyScalarArgumentsToStructWithScalarFields(const Arguments & args,const std::string & call_prefix,std::string * code)175 std::string MetalArguments::CopyScalarArgumentsToStructWithScalarFields(
176     const Arguments& args, const std::string& call_prefix, std::string* code) {
177   std::string struct_desc = "struct uniforms_buffer {\n";
178   int pos = 0;
179   for (auto& fvalue : args.GetFloatValues()) {
180     auto& new_val = float_values_[fvalue.first];
181     new_val.value = fvalue.second.value;
182     new_val.active = fvalue.second.active;
183     if (fvalue.second.active) {
184       new_val.bytes_offset = pos * 4;
185       pos++;
186       struct_desc += "  float " + fvalue.first + ";\n";
187       ReplaceAllWords(kArgsPrefix + fvalue.first,
188                       call_prefix + "U." + fvalue.first, code);
189     }
190   }
191   for (const auto& hfvalue : args.GetHalfValues()) {
192     auto& new_val = float_values_[hfvalue.first];
193     new_val.value = hfvalue.second.value;
194     new_val.active = hfvalue.second.active;
195     if (hfvalue.second.active) {
196       new_val.bytes_offset = pos * 4;
197       pos++;
198       struct_desc += "  float " + hfvalue.first + ";\n";
199       ReplaceAllWords(
200           kArgsPrefix + hfvalue.first,
201           "static_cast<half>(" + call_prefix + "U." + hfvalue.first + ")",
202           code);
203     }
204   }
205   for (auto& ivalue : args.GetIntValues()) {
206     auto& new_val = int_values_[ivalue.first];
207     new_val.value = ivalue.second.value;
208     new_val.active = ivalue.second.active;
209     if (ivalue.second.active) {
210       new_val.bytes_offset = pos * 4;
211       pos++;
212       struct_desc += "  int " + ivalue.first + ";\n";
213       ReplaceAllWords(kArgsPrefix + ivalue.first,
214                       call_prefix + "U." + ivalue.first, code);
215     }
216   }
217   if (pos != 0) {
218     int aligned_pos = AlignByN(pos, 4);
219     for (int i = pos; i < aligned_pos; i++) {
220       struct_desc += "  int dummy" + std::to_string(i - pos) + ";\n";
221     }
222     struct_desc += "};";
223     const_data_.resize(aligned_pos * 4);
224     for (auto& it : float_values_) {
225       if (it.second.active) {
226         float* ptr =
227             reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
228         *ptr = it.second.value;
229       }
230     }
231     for (auto& it : int_values_) {
232       if (it.second.active) {
233         int32_t* ptr =
234             reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
235         *ptr = it.second.value;
236       }
237     }
238   } else {
239     struct_desc = "";
240   }
241   return struct_desc;
242 }
243 
CopyScalarArgumentsToStructWithVec4Fields(const Arguments & args,const std::string & call_prefix,std::string * code)244 std::string MetalArguments::CopyScalarArgumentsToStructWithVec4Fields(
245     const Arguments& args, const std::string& call_prefix, std::string* code) {
246   std::string struct_desc = "struct uniforms_buffer {\n";
247   int pos = 0;
248   std::string channels[4] = {".x", ".y", ".z", ".w"};
249   for (auto& fvalue : args.GetFloatValues()) {
250     auto& new_val = float_values_[fvalue.first];
251     new_val.value = fvalue.second.value;
252     new_val.active = fvalue.second.active;
253     if (fvalue.second.active) {
254       new_val.bytes_offset = pos * 4;
255       if (pos % 4 == 0) {
256         struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
257       }
258       std::string new_name = call_prefix + "U.cmp_float4_" +
259                              std::to_string(pos / 4) + channels[pos % 4];
260       ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
261       pos++;
262     }
263   }
264   for (const auto& hfvalue : args.GetHalfValues()) {
265     auto& new_val = float_values_[hfvalue.first];
266     new_val.value = hfvalue.second.value;
267     new_val.active = hfvalue.second.active;
268     if (hfvalue.second.active) {
269       new_val.bytes_offset = pos * 4;
270       if (pos % 4 == 0) {
271         struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
272       }
273       std::string new_name = "static_cast<half>(" + call_prefix +
274                              "U.cmp_float4_" + std::to_string(pos / 4) +
275                              channels[pos % 4] + ")";
276       ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
277       pos++;
278     }
279   }
280   pos = AlignByN(pos, 4);
281   for (auto& ivalue : args.GetIntValues()) {
282     auto& new_val = int_values_[ivalue.first];
283     new_val.value = ivalue.second.value;
284     new_val.active = ivalue.second.active;
285     if (ivalue.second.active) {
286       new_val.bytes_offset = pos * 4;
287       if (pos % 4 == 0) {
288         struct_desc += "  int4 cmp_int4_" + std::to_string(pos / 4) + ";\n";
289       }
290       std::string new_name = call_prefix + "U.cmp_int4_" +
291                              std::to_string(pos / 4) + channels[pos % 4];
292       ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
293       pos++;
294     }
295   }
296   if (pos != 0) {
297     int aligned_pos = AlignByN(pos, 4);
298     struct_desc += "};";
299     const_data_.resize(aligned_pos * 4);
300     for (auto& it : float_values_) {
301       if (it.second.active) {
302         float* ptr =
303             reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
304         *ptr = it.second.value;
305       }
306     }
307     for (auto& it : int_values_) {
308       if (it.second.active) {
309         int32_t* ptr =
310             reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
311         *ptr = it.second.value;
312       }
313     }
314   } else {
315     struct_desc = "";
316   }
317   return struct_desc;
318 }
319 
GetArgumentBufferStructDefinition(bool add_constants_struct)320 std::string MetalArguments::GetArgumentBufferStructDefinition(
321     bool add_constants_struct) {
322   std::string result;
323   result = "struct ArgBuffer {\n";
324   int index = 0;
325   for (auto& t : buffers_) {
326     std::string mem_type = MemoryTypeToMetalType(t.second.desc.memory_type);
327     std::string metal_type;
328     if (t.second.desc.data_type == DataType::BOOL) {
329       metal_type = ToMetalDataType(DataType::UINT8, t.second.desc.element_size);
330     } else {
331       metal_type =
332           ToMetalDataType(t.second.desc.data_type, t.second.desc.element_size);
333     }
334     result += absl::StrCat("  ", mem_type, " ", metal_type, "* ", t.first,
335                            "[[id(", index, ")]];\n");
336     index++;
337   }
338   for (auto& t : images2d_) {
339     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
340     std::string data_type =
341         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
342     result += absl::StrCat("  texture2d<", data_type, ", ", access, "> ",
343                            t.first, "[[id(", index, ")]];\n");
344     index++;
345   }
346   for (auto& t : image2d_arrays_) {
347     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
348     std::string data_type =
349         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
350     result += absl::StrCat("  texture2d_array<", data_type, ", ", access, "> ",
351                            t.first, "[[id(", index, ")]];\n");
352     index++;
353   }
354   for (auto& t : images3d_) {
355     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
356     std::string data_type =
357         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
358     result += absl::StrCat("  texture3d<", data_type, ", ", access, "> ",
359                            t.first, "[[id(", index, ")]];\n");
360     index++;
361   }
362   for (auto& t : image_buffers_) {
363     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
364     std::string data_type =
365         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
366     result += absl::StrCat("  texture_buffer<", data_type, ", ", access, "> ",
367                            t.first, "[[id(", index, ")]];\n");
368     index++;
369   }
370   if (add_constants_struct) {
371     result += "  uniforms_buffer U;\n";
372   }
373   result += "};";
374   return result;
375 }
376 
SetInt(const std::string & name,int value)377 absl::Status MetalArguments::SetInt(const std::string& name, int value) {
378   auto it = int_values_.find(name);
379   if (it == int_values_.end()) {
380     return absl::NotFoundError(
381         absl::StrCat("No int argument with name - ", name));
382   }
383   it->second.value = value;
384   if (it->second.active) {
385     int32_t* ptr =
386         reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
387     *ptr = value;
388   }
389   return absl::OkStatus();
390 }
SetFloat(const std::string & name,float value)391 absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
392   auto it = float_values_.find(name);
393   if (it == float_values_.end()) {
394     return absl::NotFoundError(
395         absl::StrCat("No float argument with name - ", name));
396   }
397   it->second.value = value;
398   if (it->second.active) {
399     float* ptr =
400         reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
401     *ptr = value;
402   }
403   return absl::OkStatus();
404 }
405 
SetHalf(const std::string & name,half value)406 absl::Status MetalArguments::SetHalf(const std::string& name, half value) {
407   auto it = float_values_.find(name);
408   if (it == float_values_.end()) {
409     return absl::NotFoundError(
410         absl::StrCat("No half argument with name - ", name));
411   }
412   it->second.value = value;
413   if (it->second.active) {
414     float* ptr =
415         reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
416     *ptr = value;
417   }
418   return absl::OkStatus();
419 }
420 
SetObjectRef(const std::string & name,const GPUObject & object)421 absl::Status MetalArguments::SetObjectRef(const std::string& name,
422                                           const GPUObject& object) {
423   auto it = object_refs_.find(name);
424   if (it == object_refs_.end()) {
425     return absl::NotFoundError(
426         absl::StrCat("No object ref with name - ", name));
427   }
428   GPUResourcesWithValue resources;
429   RETURN_IF_ERROR(object.GetGPUResources(it->second.get(), &resources));
430   return SetGPUResources(name, resources);
431 }
432 
Encode(id<MTLComputeCommandEncoder> encoder,int buffer_offset,int texture_offset) const433 void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder,
434                             int buffer_offset, int texture_offset) const {
435   for (auto& b : buffers_) {
436     [encoder setBuffer:b.second.handle
437                 offset:b.second.offset
438                atIndex:buffer_offset];
439     buffer_offset++;
440   }
441   for (auto& image : images2d_) {
442     [encoder setTexture:image.second.handle atIndex:texture_offset];
443     texture_offset++;
444   }
445   for (auto& image : image2d_arrays_) {
446     [encoder setTexture:image.second.handle atIndex:texture_offset];
447     texture_offset++;
448   }
449   for (auto& image : images3d_) {
450     [encoder setTexture:image.second.handle atIndex:texture_offset];
451     texture_offset++;
452   }
453   for (auto& image : image_buffers_) {
454     [encoder setTexture:image.second.handle atIndex:texture_offset];
455     texture_offset++;
456   }
457 
458   if (!const_data_.empty()) {
459     [encoder setBytes:const_data_.data()
460                length:const_data_.size()
461               atIndex:buffer_offset];
462   }
463 }
464 
465 API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
AddResourcesToEncoder(id<MTLComputeCommandEncoder> encoder) const466 void MetalArguments::AddResourcesToEncoder(
467     id<MTLComputeCommandEncoder> encoder) const {
468   for (auto& b : buffers_) {
469     [encoder useResource:b.second.handle
470                    usage:MTLResourceUsageRead | MTLResourceUsageWrite];
471   }
472   for (auto& image : images2d_) {
473     [encoder useResource:image.second.handle
474                    usage:MTLResourceUsageRead | MTLResourceUsageWrite];
475   }
476   for (auto& image : image2d_arrays_) {
477     [encoder useResource:image.second.handle
478                    usage:MTLResourceUsageRead | MTLResourceUsageWrite];
479   }
480   for (auto& image : images3d_) {
481     [encoder useResource:image.second.handle
482                    usage:MTLResourceUsageRead | MTLResourceUsageWrite];
483   }
484   for (auto& image : image_buffers_) {
485     [encoder useResource:image.second.handle
486                    usage:MTLResourceUsageRead | MTLResourceUsageWrite];
487   }
488 }
489 
490 API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
EncodeArguments(id<MTLArgumentEncoder> arguments_encoder)491 void MetalArguments::EncodeArguments(id<MTLArgumentEncoder> arguments_encoder) {
492   int index = 0;
493   for (auto& b : buffers_) {
494     [arguments_encoder setBuffer:b.second.handle
495                           offset:b.second.offset
496                          atIndex:index];
497     index++;
498   }
499   for (auto& image : images2d_) {
500     [arguments_encoder setTexture:image.second.handle atIndex:index];
501     index++;
502   }
503   for (auto& image : image2d_arrays_) {
504     [arguments_encoder setTexture:image.second.handle atIndex:index];
505     index++;
506   }
507   for (auto& image : images3d_) {
508     [arguments_encoder setTexture:image.second.handle atIndex:index];
509     index++;
510   }
511   for (auto& image : image_buffers_) {
512     [arguments_encoder setTexture:image.second.handle atIndex:index];
513     index++;
514   }
515   if (!const_data_.empty()) {
516     std::memcpy([arguments_encoder constantDataAtIndex:index],
517                 const_data_.data(), const_data_.size());
518   }
519 }
520 
AllocateObjects(const Arguments & args,id<MTLDevice> device)521 absl::Status MetalArguments::AllocateObjects(const Arguments& args,
522                                           id<MTLDevice> device) {
523   objects_.resize(args.GetObjects().size());
524   int i = 0;
525   for (auto& t : args.GetObjects()) {
526     RETURN_IF_ERROR(CreateMetalObject(device, t.second.get(), &objects_[i]));
527     i++;
528   }
529   return absl::OkStatus();
530 }
531 
AddObjectArgs(const GpuInfo & gpu_info,const Arguments & args)532 absl::Status MetalArguments::AddObjectArgs(const GpuInfo& gpu_info,
533                                            const Arguments& args) {
534   for (const auto& t : args.GetObjects()) {
535     AddGPUResources(t.first, t.second->GetGPUResources(gpu_info));
536   }
537   for (const auto& t : args.GetObjectRefs()) {
538     AddGPUResources(t.first, t.second->GetGPUResources(gpu_info));
539   }
540   return absl::OkStatus();
541 }
542 
GetListOfArgs(int buffer_offset,int textures_offset)543 std::string MetalArguments::GetListOfArgs(int buffer_offset,
544                                           int textures_offset) {
545   std::string result;
546   for (auto& t : buffers_) {
547     std::string metal_type;
548     if (t.second.desc.data_type == DataType::BOOL) {
549       metal_type = ToMetalDataType(DataType::UINT8, t.second.desc.element_size);
550     } else {
551       metal_type =
552           ToMetalDataType(t.second.desc.data_type, t.second.desc.element_size);
553     }
554     AppendArgument(
555         absl::StrCat(MemoryTypeToMetalType(t.second.desc.memory_type), " ",
556                      metal_type, "* ", t.first, "[[buffer(", buffer_offset,
557                      ")]]"),
558         &result);
559     buffer_offset++;
560   }
561   for (auto& t : images2d_) {
562     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
563     std::string data_type =
564         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
565     if (t.second.desc.normalized) {
566       data_type = ToMetalDataType(t.second.desc.normalized_type);
567     }
568     AppendArgument(absl::StrCat("texture2d<", data_type, ", ", access, "> ",
569                                 t.first, "[[texture(", textures_offset, ")]]"),
570                    &result);
571     textures_offset++;
572   }
573   for (auto& t : image2d_arrays_) {
574     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
575     std::string data_type =
576         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
577     AppendArgument(
578         absl::StrCat("texture2d_array<", data_type, ", ", access, "> ", t.first,
579                      "[[texture(", textures_offset, ")]]"),
580         &result);
581     textures_offset++;
582   }
583   for (auto& t : images3d_) {
584     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
585     std::string data_type =
586         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
587     AppendArgument(absl::StrCat("texture3d<", data_type, ", ", access, "> ",
588                                 t.first, "[[texture(", textures_offset, ")]]"),
589                    &result);
590     textures_offset++;
591   }
592   for (auto& t : image_buffers_) {
593     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
594     std::string data_type =
595         ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
596     AppendArgument(
597         absl::StrCat("texture_buffer<", data_type, ", ", access, "> ", t.first,
598                      "[[texture(", textures_offset, ")]]"),
599         &result);
600     textures_offset++;
601   }
602   if (!const_data_.empty()) {
603     AppendArgument(absl::StrCat("constant uniforms_buffer& U[[buffer(",
604                                 buffer_offset, ")]]"),
605                    &result);
606     buffer_offset++;
607   }
608   return result;
609 }
610 
SetGPUResources(const std::string & name,const GPUResourcesWithValue & resources)611 absl::Status MetalArguments::SetGPUResources(
612     const std::string& name, const GPUResourcesWithValue& resources) {
613   for (const auto& r : resources.generic.ints) {
614     RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
615   }
616   for (const auto& r : resources.generic.floats) {
617     RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
618   }
619   for (const auto& r : resources.buffers) {
620     RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second.handle,
621                               r.second.offset));
622   }
623   for (const auto& r : resources.images2d) {
624     RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
625   }
626   for (const auto& r : resources.image2d_arrays) {
627     RETURN_IF_ERROR(
628         SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
629   }
630   for (const auto& r : resources.images3d) {
631     RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
632   }
633   for (const auto& r : resources.image_buffers) {
634     RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
635   }
636   return absl::OkStatus();
637 }
638 
AddBuffer(const std::string & name,const GPUBufferDescriptor & desc)639 void MetalArguments::AddBuffer(const std::string& name,
640                                const GPUBufferDescriptor& desc) {
641   buffers_[name].desc = desc;
642 }
643 
AddImage2D(const std::string & name,const GPUImage2DDescriptor & desc)644 void MetalArguments::AddImage2D(const std::string& name,
645                                 const GPUImage2DDescriptor& desc) {
646   images2d_[name].desc = desc;
647 }
648 
AddImage2DArray(const std::string & name,const GPUImage2DArrayDescriptor & desc)649 void MetalArguments::AddImage2DArray(const std::string& name,
650                                      const GPUImage2DArrayDescriptor& desc) {
651   image2d_arrays_[name].desc = desc;
652 }
653 
AddImage3D(const std::string & name,const GPUImage3DDescriptor & desc)654 void MetalArguments::AddImage3D(const std::string& name,
655                                 const GPUImage3DDescriptor& desc) {
656   images3d_[name].desc = desc;
657 }
658 
AddImageBuffer(const std::string & name,const GPUImageBufferDescriptor & desc)659 void MetalArguments::AddImageBuffer(const std::string& name,
660                                     const GPUImageBufferDescriptor& desc) {
661   image_buffers_[name].desc = desc;
662 }
663 
AddGPUResources(const std::string & name,const GPUResources & resources)664 void MetalArguments::AddGPUResources(const std::string& name,
665                                      const GPUResources& resources) {
666   for (const auto& r : resources.buffers) {
667     AddBuffer(absl::StrCat(name, "_", r.first), r.second);
668   }
669   for (const auto& r : resources.images2d) {
670     AddImage2D(absl::StrCat(name, "_", r.first), r.second);
671   }
672   for (const auto& r : resources.image2d_arrays) {
673     AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
674   }
675   for (const auto& r : resources.images3d) {
676     AddImage3D(absl::StrCat(name, "_", r.first), r.second);
677   }
678   for (const auto& r : resources.image_buffers) {
679     AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
680   }
681 }
682 
SetBuffer(const std::string & name,id<MTLBuffer> handle,uint64_t offset)683 absl::Status MetalArguments::SetBuffer(const std::string& name,
684                                        id<MTLBuffer> handle, uint64_t offset) {
685   auto it = buffers_.find(name);
686   if (it == buffers_.end()) {
687     return absl::NotFoundError(
688         absl::StrCat("No buffer argument with name - ", name));
689   }
690   it->second.handle = handle;
691   it->second.offset = offset;
692   return absl::OkStatus();
693 }
694 
SetImage2D(const std::string & name,id<MTLTexture> handle)695 absl::Status MetalArguments::SetImage2D(const std::string& name,
696                                         id<MTLTexture> handle) {
697   auto it = images2d_.find(name);
698   if (it == images2d_.end()) {
699     return absl::NotFoundError(
700         absl::StrCat("No image2d argument with name - ", name));
701   }
702   it->second.handle = handle;
703   return absl::OkStatus();
704 }
705 
SetImage2DArray(const std::string & name,id<MTLTexture> handle)706 absl::Status MetalArguments::SetImage2DArray(const std::string& name,
707                                              id<MTLTexture> handle) {
708   auto it = image2d_arrays_.find(name);
709   if (it == image2d_arrays_.end()) {
710     return absl::NotFoundError(
711         absl::StrCat("No image2d array argument with name - ", name));
712   }
713   it->second.handle = handle;
714   return absl::OkStatus();
715 }
716 
SetImage3D(const std::string & name,id<MTLTexture> handle)717 absl::Status MetalArguments::SetImage3D(const std::string& name,
718                                         id<MTLTexture> handle) {
719   auto it = images3d_.find(name);
720   if (it == images3d_.end()) {
721     return absl::NotFoundError(
722         absl::StrCat("No image3d argument with name - ", name));
723   }
724   it->second.handle = handle;
725   return absl::OkStatus();
726 }
727 
SetImageBuffer(const std::string & name,id<MTLTexture> handle)728 absl::Status MetalArguments::SetImageBuffer(const std::string& name,
729                                             id<MTLTexture> handle) {
730   auto it = image_buffers_.find(name);
731   if (it == image_buffers_.end()) {
732     return absl::NotFoundError(
733         absl::StrCat("No image buffer argument with name - ", name));
734   }
735   it->second.handle = handle;
736   return absl::OkStatus();
737 }
738 
SetObjectsResources(const Arguments & args)739 absl::Status MetalArguments::SetObjectsResources(const Arguments& args) {
740   int i = 0;
741   for (const auto& t : args.GetObjects()) {
742     GPUResourcesWithValue resources;
743     RETURN_IF_ERROR(objects_[i]->GetGPUResources(t.second.get(), &resources));
744     RETURN_IF_ERROR(SetGPUResources(t.first, resources));
745     i++;
746   }
747   return absl::OkStatus();
748 }
749 
750 }  // namespace metal
751 }  // namespace gpu
752 }  // namespace tflite
753