1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
16
17 #include <cstring>
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25 #include "tensorflow/lite/delegates/gpu/metal/buffer.h"
26 #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
27
28 namespace tflite {
29 namespace gpu {
30 namespace metal {
31 namespace {
// Returns true for characters that can be part of an identifier: ASCII
// letters, ASCII digits and '_'. Used to detect whole-word matches.
bool IsWordSymbol(char symbol) {
  if (symbol == '_') {
    return true;
  }
  return absl::ascii_isalnum(symbol);
}
35
// Replaces every whole-word occurrence of `old_word` in `*str` with
// `new_word`. An occurrence counts as a whole word when neither the character
// just before it nor the one just after it satisfies IsWordSymbol(). The start
// and end of the string count as non-word characters.
// Scanning resumes right after each inserted `new_word`, so replacement text
// is never re-scanned.
// Does nothing if `str` is null or `old_word` is empty. An empty needle would
// match at every position, and with an empty `new_word` the search position
// would never advance, so the loop would never terminate.
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                     std::string* str) {
  if (!str || old_word.empty()) {
    return;
  }
  size_t position = str->find(old_word);
  while (position != std::string::npos) {
    const size_t end = position + old_word.size();
    const char prev = position == 0 ? '.' : (*str)[position - 1];
    const char next = end < str->size() ? (*str)[end] : '.';
    if (IsWordSymbol(prev) || IsWordSymbol(next)) {
      // The match is part of a longer identifier; look further ahead.
      position = str->find(old_word, position + 1);
      continue;
    }
    str->replace(position, old_word.size(), new_word);
    position = str->find(old_word, position + new_word.size());
  }
}
55
// Appends `arg` to the argument list in `*args`. A ",\n" separator is placed
// in front of it unless the list is still empty.
void AppendArgument(const std::string& arg, std::string* args) {
  const bool first_argument = args->empty();
  if (first_argument) {
    absl::StrAppend(args, arg);
    return;
  }
  absl::StrAppend(args, ",\n", arg);
}
62
// Creates the Metal-side GPU object described by `desc` on `device`.
// Buffer and tensor descriptors are supported; any other descriptor yields
// InvalidArgumentError. On success `*result` owns the new object. On failure
// `*result` is left unchanged.
absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
                               GPUObjectPtr* result) {
  if (const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc)) {
    Buffer created_buffer;
    RETURN_IF_ERROR(
        created_buffer.CreateFromBufferDescriptor(*buffer_desc, device));
    *result = std::make_unique<Buffer>(std::move(created_buffer));
    return absl::OkStatus();
  }

  if (const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc)) {
    MetalSpatialTensor created_tensor;
    RETURN_IF_ERROR(created_tensor.CreateFromDescriptor(*tensor_desc, device));
    *result = std::make_unique<MetalSpatialTensor>(std::move(created_tensor));
    return absl::OkStatus();
  }

  return absl::InvalidArgumentError("Unknown GPU descriptor.");
}
84
// Maps an AccessType to the Metal texture access qualifier used in generated
// kernel signatures. Values other than READ, WRITE and READ_WRITE map to
// "access::unknown".
std::string AccessToMetalTextureAccess(AccessType access_type) {
  switch (access_type) {
    case AccessType::READ:
      return "access::read";
    case AccessType::READ_WRITE:
      return "access::read_write";
    case AccessType::WRITE:
      return "access::write";
    default:
      return "access::unknown";
  }
}
96 } // namespace
97
// Out-of-line definition of the static constexpr member kArgsPrefix. Its
// initializer lives in the class declaration, not in this file.
// Before C++17 this definition is required because the array is ODR-used in
// this file: it decays to a pointer in `kArgsPrefix + name` concatenations.
// Since C++17, static constexpr data members are implicitly inline, so this
// line is redundant but harmless.
constexpr char MetalArguments::kArgsPrefix[];
100
// Prepares `args` for a Metal kernel and turns `*code` into complete Metal
// source. In order, it:
//  * allocates the GPU objects owned by `args` and registers their resources;
//  * packs the scalar arguments into a `uniforms_buffer` struct and rewrites
//    their uses in `*code`;
//  * prepends the Metal header plus the generated struct definitions;
//  * substitutes the kernel parameter list for `$0` in `*code`.
// When `use_arguments_buffer` is true, every resource is reached through a
// single `device ArgBuffer& args` parameter. Otherwise each resource becomes
// its own kernel parameter.
absl::Status MetalArguments::Init(
    bool use_arguments_buffer, MetalDevice* device, Arguments* args,
    std::string* code) {
  RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
  RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
  args->MoveObjectRefs(&object_refs_);
  const std::string struct_desc = CopyScalarArgumentsToStructWithVec4Fields(
      *args, use_arguments_buffer ? "args." : "", code);
  RETURN_IF_ERROR(SetObjectsResources(*args));
  if (!use_arguments_buffer) {
    args->ResolveArgsPass(code);
  }

  std::string header = R"(
#include <metal_stdlib>
using namespace metal;

)";
  header += struct_desc + "\n";
  if (use_arguments_buffer) {
    header += GetArgumentBufferStructDefinition(!struct_desc.empty()) + "\n";
  }
  code->insert(0, header);

  std::string arguments =
      use_arguments_buffer
          ? std::string("device ArgBuffer& args[[buffer(0)]]")
          : GetListOfArgs(/*buffer_offset*/ 0);

  // Built-in thread-position parameters are appended only when the kernel
  // source mentions the corresponding marker.
  const auto mentions = [code](const char* marker) {
    return code->find(marker) != std::string::npos;
  };
  const bool use_global_id = mentions("GLOBAL_ID_");
  const bool use_local_id = mentions("LOCAL_ID_");
  const bool use_group_id = mentions("GROUP_ID_");
  const bool use_group_size = mentions("GROUP_SIZE_");
  const bool use_simd_id = mentions("SUB_GROUP_LOCAL_ID");
  if (use_global_id) {
    AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
  }
  if (use_local_id) {
    AppendArgument("uint3 reserved_lid[[thread_position_in_threadgroup]]",
                   &arguments);
  }
  if (use_group_id) {
    AppendArgument("uint3 reserved_group_id[[threadgroup_position_in_grid]]",
                   &arguments);
  }
  if (use_group_size) {
    AppendArgument("uint3 reserved_group_size[[threads_per_threadgroup]]",
                   &arguments);
  }
  if (use_simd_id) {
    AppendArgument("uint reserved_simd_id[[thread_index_in_simdgroup]]",
                   &arguments);
  }
  // NOTE(review): when none of the grid/threadgroup parameters was added, the
  // template presumably declares its own parameters after `$0`, which would
  // explain the trailing separator below. SUB_GROUP_LOCAL_ID deliberately
  // does not take part in this check. Confirm against the kernel templates.
  const bool uses_thread_position =
      use_global_id || use_local_id || use_group_id || use_group_size;
  if (!uses_thread_position && !arguments.empty()) {
    arguments += ",\n";
  }
  *code = absl::Substitute(*code, arguments);
  return absl::OkStatus();
}
164
// Variant of Init() for callers that need no source generation. It allocates
// and binds the objects owned by `args` and packs the scalar values into
// const_data_, but rewrites no kernel code. `use_arguments_buffer` is not
// consulted by this overload.
absl::Status MetalArguments::Init(bool use_arguments_buffer,
                                  MetalDevice* device, Arguments* args) {
  RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
  RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
  args->MoveObjectRefs(&object_refs_);
  // Only the side effects are needed here (recorded offsets and const_data_);
  // the generated struct description is discarded.
  static_cast<void>(CopyScalarArgumentsToStructWithVec4Fields(*args));
  return SetObjectsResources(*args);
}
174
// Packs every active scalar argument of `args` into a `uniforms_buffer`
// struct with one 4-byte scalar field per argument. Float arguments come
// first, then half arguments (stored as 32-bit float), then ints.
// Each argument's byte offset is recorded, and const_data_ is filled with the
// current values. Every whole-word `kArgsPrefix + name` in `*code` is
// rewritten to `<call_prefix>U.<name>`; half arguments are wrapped in
// static_cast<half>(...).
// The struct is padded with `int dummyN` fields up to a multiple of 4 fields.
// Returns the struct's Metal definition, or "" when there are no active
// scalar arguments (const_data_ is left untouched in that case).
std::string MetalArguments::CopyScalarArgumentsToStructWithScalarFields(
    const Arguments& args, const std::string& call_prefix, std::string* code) {
  std::string struct_desc = "struct uniforms_buffer {\n";
  int pos = 0;
  for (auto& fvalue : args.GetFloatValues()) {
    auto& new_val = float_values_[fvalue.first];
    new_val.value = fvalue.second.value;
    new_val.active = fvalue.second.active;
    if (fvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += "  float " + fvalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + fvalue.first,
                      call_prefix + "U." + fvalue.first, code);
    }
  }
  for (const auto& hfvalue : args.GetHalfValues()) {
    // Half arguments share float_values_ and are uploaded as 32-bit floats.
    auto& new_val = float_values_[hfvalue.first];
    new_val.value = hfvalue.second.value;
    new_val.active = hfvalue.second.active;
    if (hfvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += "  float " + hfvalue.first + ";\n";
      ReplaceAllWords(
          kArgsPrefix + hfvalue.first,
          "static_cast<half>(" + call_prefix + "U." + hfvalue.first + ")",
          code);
    }
  }
  for (auto& ivalue : args.GetIntValues()) {
    auto& new_val = int_values_[ivalue.first];
    new_val.value = ivalue.second.value;
    new_val.active = ivalue.second.active;
    if (ivalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += "  int " + ivalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + ivalue.first,
                      call_prefix + "U." + ivalue.first, code);
    }
  }
  if (pos != 0) {
    int aligned_pos = AlignByN(pos, 4);
    for (int i = pos; i < aligned_pos; i++) {
      struct_desc += "  int dummy" + std::to_string(i - pos) + ";\n";
    }
    struct_desc += "};";
    const_data_.resize(aligned_pos * 4);
    // Copy bytes instead of storing through reinterpret_cast'ed float*/int32_t*
    // pointers. const_data_ is a byte buffer holding no float/int32 objects,
    // so such stores would violate strict aliasing.
    for (auto& it : float_values_) {
      if (it.second.active) {
        const float packed = it.second.value;
        std::memcpy(&const_data_[it.second.bytes_offset], &packed,
                    sizeof(packed));
      }
    }
    for (auto& it : int_values_) {
      if (it.second.active) {
        const int32_t packed = it.second.value;
        std::memcpy(&const_data_[it.second.bytes_offset], &packed,
                    sizeof(packed));
      }
    }
  } else {
    struct_desc = "";
  }
  return struct_desc;
}
243
// Packs every active scalar argument of `args` into a `uniforms_buffer`
// struct made of 4-component vector fields, and fills const_data_ with the
// current values:
//  * float arguments, then half arguments (stored as 32-bit float), occupy
//    consecutive components of `float4 cmp_float4_<k>` fields;
//  * int arguments start at the next 16-byte boundary and occupy components
//    of `int4 cmp_int4_<k>` fields.
// Each argument's byte offset is recorded for later use by SetInt/SetFloat/
// SetHalf. Every whole-word `kArgsPrefix + name` in `*code` is rewritten to
// the matching `<call_prefix>U.cmp_..._<k>.<x|y|z|w>` component; half
// arguments are wrapped in static_cast<half>(...).
// Returns the struct's Metal definition, or "" when there are no active
// scalar arguments (const_data_ is left untouched in that case).
std::string MetalArguments::CopyScalarArgumentsToStructWithVec4Fields(
    const Arguments& args, const std::string& call_prefix, std::string* code) {
  std::string struct_desc = "struct uniforms_buffer {\n";
  int pos = 0;
  const char* const channels[4] = {".x", ".y", ".z", ".w"};
  for (auto& fvalue : args.GetFloatValues()) {
    auto& new_val = float_values_[fvalue.first];
    new_val.value = fvalue.second.value;
    new_val.active = fvalue.second.active;
    if (fvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      if (pos % 4 == 0) {
        struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
      }
      std::string new_name = call_prefix + "U.cmp_float4_" +
                             std::to_string(pos / 4) + channels[pos % 4];
      ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
      pos++;
    }
  }
  for (const auto& hfvalue : args.GetHalfValues()) {
    // Half arguments share float_values_ and are uploaded as 32-bit floats.
    auto& new_val = float_values_[hfvalue.first];
    new_val.value = hfvalue.second.value;
    new_val.active = hfvalue.second.active;
    if (hfvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      if (pos % 4 == 0) {
        struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
      }
      std::string new_name = "static_cast<half>(" + call_prefix +
                             "U.cmp_float4_" + std::to_string(pos / 4) +
                             channels[pos % 4] + ")";
      ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
      pos++;
    }
  }
  // Ints live in their own int4 fields, so skip to the next 4-slot boundary.
  pos = AlignByN(pos, 4);
  for (auto& ivalue : args.GetIntValues()) {
    auto& new_val = int_values_[ivalue.first];
    new_val.value = ivalue.second.value;
    new_val.active = ivalue.second.active;
    if (ivalue.second.active) {
      new_val.bytes_offset = pos * 4;
      if (pos % 4 == 0) {
        struct_desc += "  int4 cmp_int4_" + std::to_string(pos / 4) + ";\n";
      }
      std::string new_name = call_prefix + "U.cmp_int4_" +
                             std::to_string(pos / 4) + channels[pos % 4];
      ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
      pos++;
    }
  }
  if (pos != 0) {
    int aligned_pos = AlignByN(pos, 4);
    struct_desc += "};";
    const_data_.resize(aligned_pos * 4);
    // Copy bytes instead of storing through reinterpret_cast'ed float*/int32_t*
    // pointers. const_data_ is a byte buffer holding no float/int32 objects,
    // so such stores would violate strict aliasing.
    for (auto& it : float_values_) {
      if (it.second.active) {
        const float packed = it.second.value;
        std::memcpy(&const_data_[it.second.bytes_offset], &packed,
                    sizeof(packed));
      }
    }
    for (auto& it : int_values_) {
      if (it.second.active) {
        const int32_t packed = it.second.value;
        std::memcpy(&const_data_[it.second.bytes_offset], &packed,
                    sizeof(packed));
      }
    }
  } else {
    struct_desc = "";
  }
  return struct_desc;
}
319
// Returns the Metal definition of `struct ArgBuffer`, the argument-buffer
// layout used when Init() runs with use_arguments_buffer == true.
// The struct has one member per buffer and texture. Their [[id(n)]] indices
// follow the same order in which EncodeArguments() binds them: buffers, then
// 2D, 2D-array, 3D and buffer textures. A final `uniforms_buffer U` member is
// added when `add_constants_struct` is true.
std::string MetalArguments::GetArgumentBufferStructDefinition(
    bool add_constants_struct) {
  std::string result = "struct ArgBuffer {\n";
  int index = 0;
  for (auto& t : buffers_) {
    std::string mem_type = MemoryTypeToMetalType(t.second.desc.memory_type);
    std::string metal_type;
    if (t.second.desc.data_type == DataType::BOOL) {
      // BOOL buffers are exposed as UINT8, matching GetListOfArgs().
      metal_type = ToMetalDataType(DataType::UINT8, t.second.desc.element_size);
    } else {
      metal_type =
          ToMetalDataType(t.second.desc.data_type, t.second.desc.element_size);
    }
    result += absl::StrCat("  ", mem_type, " ", metal_type, "* ", t.first,
                           "[[id(", index, ")]];\n");
    index++;
  }
  for (auto& t : images2d_) {
    std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
    std::string data_type =
        ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
    // Keep the texel type in sync with GetListOfArgs(). Normalized 2D
    // textures use their normalized type; otherwise argument-buffer kernels
    // would declare a different texture type than direct-parameter kernels.
    if (t.second.desc.normalized) {
      data_type = ToMetalDataType(t.second.desc.normalized_type);
    }
    result += absl::StrCat("  texture2d<", data_type, ", ", access, "> ",
                           t.first, "[[id(", index, ")]];\n");
    index++;
  }
  for (auto& t : image2d_arrays_) {
    std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
    std::string data_type =
        ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
    result += absl::StrCat("  texture2d_array<", data_type, ", ", access, "> ",
                           t.first, "[[id(", index, ")]];\n");
    index++;
  }
  for (auto& t : images3d_) {
    std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
    std::string data_type =
        ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
    result += absl::StrCat("  texture3d<", data_type, ", ", access, "> ",
                           t.first, "[[id(", index, ")]];\n");
    index++;
  }
  for (auto& t : image_buffers_) {
    std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
    std::string data_type =
        ToMetalDataType(ToMetalTextureType(t.second.desc.data_type));
    result += absl::StrCat("  texture_buffer<", data_type, ", ", access, "> ",
                           t.first, "[[id(", index, ")]];\n");
    index++;
  }
  if (add_constants_struct) {
    result += "  uniforms_buffer U;\n";
  }
  result += "};";
  return result;
}
376
// Updates the int scalar argument `name` to `value`. If the argument is
// active (part of the packed constants), the value is also written into
// const_data_, so the next Encode()/EncodeArguments() uploads it.
// Returns NotFoundError if there is no int argument with that name.
absl::Status MetalArguments::SetInt(const std::string& name, int value) {
  auto it = int_values_.find(name);
  if (it == int_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No int argument with name - ", name));
  }
  it->second.value = value;
  if (it->second.active) {
    // memcpy instead of a store through a reinterpret_cast'ed int32_t*:
    // const_data_ is a byte buffer, so that store would break strict aliasing.
    const int32_t packed = value;
    std::memcpy(&const_data_[it->second.bytes_offset], &packed,
                sizeof(packed));
  }
  return absl::OkStatus();
}
// Updates the float scalar argument `name` to `value`. If the argument is
// active (part of the packed constants), the value is also written into
// const_data_, so the next Encode()/EncodeArguments() uploads it.
// Returns NotFoundError if there is no float argument with that name.
absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
  auto it = float_values_.find(name);
  if (it == float_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No float argument with name - ", name));
  }
  it->second.value = value;
  if (it->second.active) {
    // memcpy instead of a store through a reinterpret_cast'ed float*:
    // const_data_ is a byte buffer, so that store would break strict aliasing.
    std::memcpy(&const_data_[it->second.bytes_offset], &value, sizeof(value));
  }
  return absl::OkStatus();
}
405
// Updates the half scalar argument `name` to `value`. Half arguments are
// stored among the float arguments and uploaded as 32-bit floats (the kernel
// casts them back to half). If the argument is active, the widened value is
// also written into const_data_.
// Returns NotFoundError if there is no such argument.
absl::Status MetalArguments::SetHalf(const std::string& name, half value) {
  auto it = float_values_.find(name);
  if (it == float_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No half argument with name - ", name));
  }
  it->second.value = value;
  if (it->second.active) {
    // Widen to float, then copy bytes. A store through a reinterpret_cast'ed
    // float* into the byte buffer const_data_ would break strict aliasing.
    const float packed = value;
    std::memcpy(&const_data_[it->second.bytes_offset], &packed,
                sizeof(packed));
  }
  return absl::OkStatus();
}
420
// Rebinds the object reference `name` to `object`. The GPU resources that
// `object` exposes for the stored descriptor are written into the matching
// `<name>_<resource>` arguments.
// Returns NotFoundError for an unknown reference name.
absl::Status MetalArguments::SetObjectRef(const std::string& name,
                                          const GPUObject& object) {
  const auto ref = object_refs_.find(name);
  if (ref != object_refs_.end()) {
    GPUResourcesWithValue resources;
    RETURN_IF_ERROR(object.GetGPUResources(ref->second.get(), &resources));
    return SetGPUResources(name, resources);
  }
  return absl::NotFoundError(
      absl::StrCat("No object ref with name - ", name));
}
432
// Binds all resources directly to `encoder` as individual kernel parameters.
// Buffers take consecutive indices starting at `buffer_offset`, and the packed
// scalar constants (if any) take the next buffer index. Textures take
// consecutive indices starting at `texture_offset`, in the order 2D, 2D
// array, 3D, buffer textures. The order matches GetListOfArgs().
void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder,
                            int buffer_offset, int texture_offset) const {
  for (const auto& entry : buffers_) {
    [encoder setBuffer:entry.second.handle
                offset:entry.second.offset
               atIndex:buffer_offset++];
  }

  for (const auto& entry : images2d_) {
    [encoder setTexture:entry.second.handle atIndex:texture_offset++];
  }
  for (const auto& entry : image2d_arrays_) {
    [encoder setTexture:entry.second.handle atIndex:texture_offset++];
  }
  for (const auto& entry : images3d_) {
    [encoder setTexture:entry.second.handle atIndex:texture_offset++];
  }
  for (const auto& entry : image_buffers_) {
    [encoder setTexture:entry.second.handle atIndex:texture_offset++];
  }

  if (const_data_.empty()) {
    return;
  }
  [encoder setBytes:const_data_.data()
             length:const_data_.size()
            atIndex:buffer_offset];
}
464
API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
// Declares every buffer and texture as used by `encoder` with read+write
// usage via -useResource:usage:. This is presumably needed so that resources
// reached only through the argument buffer are resident during dispatch.
void MetalArguments::AddResourcesToEncoder(
    id<MTLComputeCommandEncoder> encoder) const {
  const auto usage = MTLResourceUsageRead | MTLResourceUsageWrite;
  for (const auto& entry : buffers_) {
    [encoder useResource:entry.second.handle usage:usage];
  }
  for (const auto& entry : images2d_) {
    [encoder useResource:entry.second.handle usage:usage];
  }
  for (const auto& entry : image2d_arrays_) {
    [encoder useResource:entry.second.handle usage:usage];
  }
  for (const auto& entry : images3d_) {
    [encoder useResource:entry.second.handle usage:usage];
  }
  for (const auto& entry : image_buffers_) {
    [encoder useResource:entry.second.handle usage:usage];
  }
}
489
API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
// Writes all resources into the argument buffer through `arguments_encoder`.
// Indices are assigned in the same order as the [[id(n)]] members produced by
// GetArgumentBufferStructDefinition(): buffers, then 2D, 2D-array, 3D and
// buffer textures. The packed scalar constants (if any) are copied into the
// constant data slot that follows.
void MetalArguments::EncodeArguments(id<MTLArgumentEncoder> arguments_encoder) {
  int slot = 0;
  for (const auto& entry : buffers_) {
    [arguments_encoder setBuffer:entry.second.handle
                          offset:entry.second.offset
                         atIndex:slot++];
  }
  for (const auto& entry : images2d_) {
    [arguments_encoder setTexture:entry.second.handle atIndex:slot++];
  }
  for (const auto& entry : image2d_arrays_) {
    [arguments_encoder setTexture:entry.second.handle atIndex:slot++];
  }
  for (const auto& entry : images3d_) {
    [arguments_encoder setTexture:entry.second.handle atIndex:slot++];
  }
  for (const auto& entry : image_buffers_) {
    [arguments_encoder setTexture:entry.second.handle atIndex:slot++];
  }
  if (const_data_.empty()) {
    return;
  }
  void* constants_destination = [arguments_encoder constantDataAtIndex:slot];
  std::memcpy(constants_destination, const_data_.data(), const_data_.size());
}
520
// Creates the Metal-side object for every object descriptor owned by `args`.
// Each object is stored in objects_ at the same position as its descriptor in
// args.GetObjects() iteration order; SetObjectsResources() relies on this
// pairing. Stops at the first creation failure and returns its status.
absl::Status MetalArguments::AllocateObjects(const Arguments& args,
                                             id<MTLDevice> device) {
  const auto& descriptors = args.GetObjects();
  objects_.resize(descriptors.size());
  auto slot = objects_.begin();
  for (const auto& named_desc : descriptors) {
    RETURN_IF_ERROR(CreateMetalObject(device, named_desc.second.get(), &*slot));
    ++slot;
  }
  return absl::OkStatus();
}
531
// Registers the GPU resource slots (buffers and textures) declared by every
// owned object and every object reference in `args`, named
// `<object>_<resource>`. Owned objects are registered before references.
// Always returns OkStatus.
absl::Status MetalArguments::AddObjectArgs(const GpuInfo& gpu_info,
                                           const Arguments& args) {
  for (const auto& owned : args.GetObjects()) {
    const std::string& object_name = owned.first;
    AddGPUResources(object_name, owned.second->GetGPUResources(gpu_info));
  }
  for (const auto& reference : args.GetObjectRefs()) {
    const std::string& object_name = reference.first;
    AddGPUResources(object_name, reference.second->GetGPUResources(gpu_info));
  }
  return absl::OkStatus();
}
542
// Builds the ",\n"-separated Metal kernel parameter list used when no
// argument buffer is in play. It contains:
//  * one `[[buffer(i)]]` pointer per buffer (BOOL buffers are exposed as
//    UINT8);
//  * one `[[texture(j)]]` parameter per texture, in the order 2D, 2D array,
//    3D, buffer textures;
//  * a final `constant uniforms_buffer& U` when there are packed scalar
//    constants.
// Buffer indices start at `buffer_offset` and texture indices at
// `textures_offset`. The order matches Encode().
std::string MetalArguments::GetListOfArgs(int buffer_offset,
                                          int textures_offset) {
  std::string result;
  int buffer_index = buffer_offset;
  int texture_index = textures_offset;

  for (const auto& entry : buffers_) {
    const auto& desc = entry.second.desc;
    const DataType element_type =
        desc.data_type == DataType::BOOL ? DataType::UINT8 : desc.data_type;
    const std::string metal_type =
        ToMetalDataType(element_type, desc.element_size);
    AppendArgument(absl::StrCat(MemoryTypeToMetalType(desc.memory_type), " ",
                                metal_type, "* ", entry.first, "[[buffer(",
                                buffer_index, ")]]"),
                   &result);
    ++buffer_index;
  }

  // Appends `<kind><data_type, access> name[[texture(j)]]` and advances j.
  const auto append_texture = [&result, &texture_index](
                                  const char* kind,
                                  const std::string& data_type,
                                  AccessType access_type,
                                  const std::string& name) {
    AppendArgument(absl::StrCat(kind, "<", data_type, ", ",
                                AccessToMetalTextureAccess(access_type), "> ",
                                name, "[[texture(", texture_index, ")]]"),
                   &result);
    ++texture_index;
  };

  for (const auto& entry : images2d_) {
    const auto& desc = entry.second.desc;
    std::string data_type;
    if (desc.normalized) {
      data_type = ToMetalDataType(desc.normalized_type);
    } else {
      data_type = ToMetalDataType(ToMetalTextureType(desc.data_type));
    }
    append_texture("texture2d", data_type, desc.access_type, entry.first);
  }
  for (const auto& entry : image2d_arrays_) {
    const auto& desc = entry.second.desc;
    append_texture("texture2d_array",
                   ToMetalDataType(ToMetalTextureType(desc.data_type)),
                   desc.access_type, entry.first);
  }
  for (const auto& entry : images3d_) {
    const auto& desc = entry.second.desc;
    append_texture("texture3d",
                   ToMetalDataType(ToMetalTextureType(desc.data_type)),
                   desc.access_type, entry.first);
  }
  for (const auto& entry : image_buffers_) {
    const auto& desc = entry.second.desc;
    append_texture("texture_buffer",
                   ToMetalDataType(ToMetalTextureType(desc.data_type)),
                   desc.access_type, entry.first);
  }

  if (!const_data_.empty()) {
    AppendArgument(absl::StrCat("constant uniforms_buffer& U[[buffer(",
                                buffer_index, ")]]"),
                   &result);
  }
  return result;
}
610
// Applies every value in `resources` to the argument named
// `<name>_<resource>`, in this order: ints, floats, buffers, 2D, 2D-array,
// 3D and buffer textures. Returns the first error from the individual
// setters, e.g. NotFoundError for an unregistered name.
absl::Status MetalArguments::SetGPUResources(
    const std::string& name, const GPUResourcesWithValue& resources) {
  const auto qualified = [&name](const auto& resource_name) {
    return absl::StrCat(name, "_", resource_name);
  };
  for (const auto& int_value : resources.generic.ints) {
    RETURN_IF_ERROR(SetInt(qualified(int_value.first), int_value.second));
  }
  for (const auto& float_value : resources.generic.floats) {
    RETURN_IF_ERROR(SetFloat(qualified(float_value.first), float_value.second));
  }
  for (const auto& buffer : resources.buffers) {
    RETURN_IF_ERROR(SetBuffer(qualified(buffer.first), buffer.second.handle,
                              buffer.second.offset));
  }
  for (const auto& image : resources.images2d) {
    RETURN_IF_ERROR(SetImage2D(qualified(image.first), image.second));
  }
  for (const auto& image : resources.image2d_arrays) {
    RETURN_IF_ERROR(SetImage2DArray(qualified(image.first), image.second));
  }
  for (const auto& image : resources.images3d) {
    RETURN_IF_ERROR(SetImage3D(qualified(image.first), image.second));
  }
  for (const auto& image : resources.image_buffers) {
    RETURN_IF_ERROR(SetImageBuffer(qualified(image.first), image.second));
  }
  return absl::OkStatus();
}
638
// Registers (or re-describes) the buffer argument `name`. Its handle is
// bound later by SetBuffer().
void MetalArguments::AddBuffer(const std::string& name,
                               const GPUBufferDescriptor& desc) {
  auto& slot = buffers_[name];
  slot.desc = desc;
}

// Registers (or re-describes) the 2D texture argument `name`.
void MetalArguments::AddImage2D(const std::string& name,
                                const GPUImage2DDescriptor& desc) {
  auto& slot = images2d_[name];
  slot.desc = desc;
}

// Registers (or re-describes) the 2D texture array argument `name`.
void MetalArguments::AddImage2DArray(const std::string& name,
                                     const GPUImage2DArrayDescriptor& desc) {
  auto& slot = image2d_arrays_[name];
  slot.desc = desc;
}

// Registers (or re-describes) the 3D texture argument `name`.
void MetalArguments::AddImage3D(const std::string& name,
                                const GPUImage3DDescriptor& desc) {
  auto& slot = images3d_[name];
  slot.desc = desc;
}

// Registers (or re-describes) the buffer texture argument `name`.
void MetalArguments::AddImageBuffer(const std::string& name,
                                    const GPUImageBufferDescriptor& desc) {
  auto& slot = image_buffers_[name];
  slot.desc = desc;
}
663
// Registers every buffer and texture descriptor of `resources` under the name
// `<name>_<resource>`. Handles are bound afterwards via SetGPUResources().
void MetalArguments::AddGPUResources(const std::string& name,
                                     const GPUResources& resources) {
  const auto qualified = [&name](const auto& resource_name) {
    return absl::StrCat(name, "_", resource_name);
  };
  for (const auto& buffer : resources.buffers) {
    AddBuffer(qualified(buffer.first), buffer.second);
  }
  for (const auto& image : resources.images2d) {
    AddImage2D(qualified(image.first), image.second);
  }
  for (const auto& image : resources.image2d_arrays) {
    AddImage2DArray(qualified(image.first), image.second);
  }
  for (const auto& image : resources.images3d) {
    AddImage3D(qualified(image.first), image.second);
  }
  for (const auto& image : resources.image_buffers) {
    AddImageBuffer(qualified(image.first), image.second);
  }
}
682
// Binds `handle` (at byte `offset`) to the buffer argument `name`.
// Returns NotFoundError if no buffer argument with that name was registered.
absl::Status MetalArguments::SetBuffer(const std::string& name,
                                       id<MTLBuffer> handle, uint64_t offset) {
  const auto entry = buffers_.find(name);
  if (entry != buffers_.end()) {
    entry->second.handle = handle;
    entry->second.offset = offset;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No buffer argument with name - ", name));
}

// Binds `handle` to the 2D texture argument `name`.
// Returns NotFoundError if no such argument was registered.
absl::Status MetalArguments::SetImage2D(const std::string& name,
                                        id<MTLTexture> handle) {
  const auto entry = images2d_.find(name);
  if (entry != images2d_.end()) {
    entry->second.handle = handle;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image2d argument with name - ", name));
}

// Binds `handle` to the 2D texture array argument `name`.
// Returns NotFoundError if no such argument was registered.
absl::Status MetalArguments::SetImage2DArray(const std::string& name,
                                             id<MTLTexture> handle) {
  const auto entry = image2d_arrays_.find(name);
  if (entry != image2d_arrays_.end()) {
    entry->second.handle = handle;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image2d array argument with name - ", name));
}

// Binds `handle` to the 3D texture argument `name`.
// Returns NotFoundError if no such argument was registered.
absl::Status MetalArguments::SetImage3D(const std::string& name,
                                        id<MTLTexture> handle) {
  const auto entry = images3d_.find(name);
  if (entry != images3d_.end()) {
    entry->second.handle = handle;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image3d argument with name - ", name));
}

// Binds `handle` to the buffer texture argument `name`.
// Returns NotFoundError if no such argument was registered.
absl::Status MetalArguments::SetImageBuffer(const std::string& name,
                                            id<MTLTexture> handle) {
  const auto entry = image_buffers_.find(name);
  if (entry != image_buffers_.end()) {
    entry->second.handle = handle;
    return absl::OkStatus();
  }
  return absl::NotFoundError(
      absl::StrCat("No image buffer argument with name - ", name));
}
738
// For each object owned by `args`, asks the matching allocated object in
// objects_ (paired by iteration order, as set up by AllocateObjects()) for
// its GPU resources and binds them under the object's name.
// Returns the first error encountered.
absl::Status MetalArguments::SetObjectsResources(const Arguments& args) {
  auto object = objects_.begin();
  for (const auto& named_desc : args.GetObjects()) {
    GPUResourcesWithValue resources;
    RETURN_IF_ERROR(
        (*object)->GetGPUResources(named_desc.second.get(), &resources));
    RETURN_IF_ERROR(SetGPUResources(named_desc.first, resources));
    ++object;
  }
  return absl::OkStatus();
}
749
750 } // namespace metal
751 } // namespace gpu
752 } // namespace tflite
753