/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/gpu_model.h"

#include <algorithm>
#include <any>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"

namespace tflite {
namespace gpu {

namespace {
bool IsReady(const absl::flat_hash_set<ValueId>& ready_tensors,
             const GpuNode& node) {
  for (const ValueId in_id : node.inputs) {
    if (ready_tensors.find(in_id) == ready_tensors.end()) {
      return false;
    }
  }
  return true;
}

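// Folds `src` into `dst`. `src` is assumed to consume `dst`'s output as its
// first input (see MergeNodes below), so that input is dropped while the
// remaining `src` inputs and the `src` output are transferred to `dst`.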
absl::Status MergeGpuNodes(const GpuInfo& gpu_info, GpuNode* src,
                           GpuNode* dst) {
  for (int j = 1; j < src->inputs.size(); ++j) {
    dst->inputs.push_back(src->inputs[j]);
  }
  dst->outputs[0] = src->outputs[0];
  dst->name += " -> " + src->name;
  return dst->gpu_operation->AddOperation(gpu_info, src->gpu_operation.get());
}

flatbuffers::Offset<data::TensorDescWithId> Encode(
    const TensorDescriptor& desc, const ValueId& id,
    flatbuffers::FlatBufferBuilder* builder) {
  auto desc_fb = Encode(desc, builder);
  data::TensorDescWithIdBuilder desc_builder(*builder);
  desc_builder.add_desc(desc_fb);
  desc_builder.add_id(id);
  return desc_builder.Finish();
}

flatbuffers::Offset<data::GpuNode> Encode(
    const GpuNode& node, flatbuffers::FlatBufferBuilder* builder) {
  auto op_fb = Encode(*node.gpu_operation, builder);
  std::vector<int32_t> in_ids(node.inputs.size());
  for (int i = 0; i < in_ids.size(); ++i) {
    in_ids[i] = node.inputs[i];
  }
  std::vector<int32_t> out_ids(node.outputs.size());
  for (int i = 0; i < out_ids.size(); ++i) {
    out_ids[i] = node.outputs[i];
  }
  auto in_ids_fb = builder->CreateVector(in_ids);
  auto out_ids_fb = builder->CreateVector(out_ids);
  auto name_fb = builder->CreateString(node.name);
  data::GpuNodeBuilder node_builder(*builder);
  node_builder.add_gpu_op(op_fb);
  node_builder.add_input_ids(in_ids_fb);
  node_builder.add_output_ids(out_ids_fb);
  node_builder.add_name(name_fb);
  return node_builder.Finish();
}

absl::Status Decode(const data::GpuNode* fb_node, GpuNode* node) {
  GPUOperation op;
  RETURN_IF_ERROR(Decode(fb_node->gpu_op(), &op));
  node->gpu_operation = std::make_unique<GPUOperation>(std::move(op));
  for (auto in_fb : *fb_node->input_ids()) {
    node->inputs.push_back(in_fb);
  }
  for (auto out_fb : *fb_node->output_ids()) {
    node->outputs.push_back(out_fb);
  }
  node->name = std::string(fb_node->name()->c_str(), fb_node->name()->size());

  return absl::OkStatus();
}

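// Returns true if an elementwise ADD/MUL node with multiple inputs can be
// linked associatively, i.e. none of its inputs needs to be broadcast to the
// output shape. For example (illustrative shapes only), ADD over two
// 1x8x8x32 inputs qualifies, while ADD that broadcasts a 1x1x1x32 input
// does not.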
bool IsAssociativeLinkableOp(const Node& node,
                             const std::vector<Value*>& inputs,
                             const std::vector<Value*>& outputs) {
  if (inputs.size() == 1) {
    return false;
  }
  const OperationType op_type = OperationTypeFromString(node.operation.type);
  if (op_type != OperationType::ADD && op_type != OperationType::MUL) {
    return false;
  }

  const auto dst_shape = outputs[0]->tensor.shape;
  for (int i = 0; i < inputs.size(); ++i) {
    const auto src_shape = inputs[i]->tensor.shape;
    if (dst_shape.b != src_shape.b && src_shape.b == 1) {
      return false;
    }
    if (dst_shape.h != src_shape.h && src_shape.h == 1) {
      return false;
    }
    if (dst_shape.w != src_shape.w && src_shape.w == 1) {
      return false;
    }
    if (dst_shape.c != src_shape.c && src_shape.c == 1) {
      return false;
    }
  }
  return true;
}

absl::Status CheckExternalTensorDescription(const GpuInfo& gpu_info,
                                            const TensorDescriptor& tensor_desc,
                                            const BHWC& shape,
                                            DataType data_type) {
  if (tensor_desc.GetDataType() != data_type) {
    return absl::InvalidArgumentError(
        "Global precision and precision of predefined/external tensors must be "
        "synchronized.");
  }
  if (tensor_desc.HasAxis(Axis::DEPTH)) {
    return absl::InvalidArgumentError(
        "Currently no support of Depth dimension in predefined/external "
        "tensors.");
  }
  if (tensor_desc.HasAxis(Axis::BATCH) && shape.b == 1) {
    return absl::InvalidArgumentError("Wrong layout, batch mismatch.");
  }
  if (!tensor_desc.HasAxis(Axis::BATCH) && shape.b != 1) {
    return absl::InvalidArgumentError("Wrong layout, batch mismatch.");
  }
  if (!tensor_desc.CanCreateTensorWithShape(gpu_info, shape).ok()) {
    return absl::UnavailableError(
        "Current device can not allocate tensor with this shape for "
        "predefined/external descriptor.");
  }
  return absl::OkStatus();
}

// Helper class for creating descriptors for the appropriate tensors from a
// GraphFloat32. Also allows creating descriptors for new tensors (not present
// in the GraphFloat32).
class TensorReserver {
 public:
  TensorReserver() : next_(0) {}
  ValueId Add(const TensorDescriptor& dummy) {
    reservations_[next_] = dummy;
    return next_++;
  }
  void Add(ValueId id, const TensorDescriptor& dummy) {
    reservations_[id] = dummy;
  }
  ValueId GetNewId() { return next_++; }
  void SetNext(ValueId id) { next_ = id; }
  TensorDescriptor Get(ValueId id) { return reservations_[id]; }

 public:
  absl::flat_hash_map<ValueId, TensorDescriptor> reservations_;
  ValueId next_;
};

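// Picks a TensorDescriptor for every value of the graph and registers it in
// the TensorReserver under the value's original id. Predefined/external
// tensors must bring their own descriptors; for all other tensors a storage
// type and layout are chosen from the shape and device limits. The reserver's
// counter is then advanced past the largest graph id so that new intermediate
// tensors get non-conflicting ids.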
absl::Status ReserveGraphTensors(const CreateGpuModelInfo& create_info,
                                 const GpuInfo& gpu_info,
                                 const GraphFloat32& graph,
                                 TensorReserver* tensor_reserver) {
  ValueId max_id = 0;
  auto tensors = graph.values();
  for (auto& t : tensors) {
    auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
    if (t->tensor.type != DataType::FLOAT32 &&
        t->tensor.type != DataType::FLOAT16) {
      data_type = t->tensor.type;
    }
    const auto shape = graph.GetValue(t->id)->tensor.shape;
    auto it_predefined = create_info.predefined.find(t->id);
    auto it_immutable_external =
        create_info.external_immutable_tensors.find(t->id);
    auto it_mutable_external = create_info.external_mutable_tensors.find(t->id);
    int external_categories_count = 0;
    TensorDescriptor tensor_desc;
    if (it_predefined != create_info.predefined.end()) {
      external_categories_count++;
      tensor_desc = it_predefined->second;
    }
    if (it_immutable_external != create_info.external_immutable_tensors.end()) {
      external_categories_count++;
      tensor_desc = it_immutable_external->second->GetDescriptor();
    }
    if (it_mutable_external != create_info.external_mutable_tensors.end()) {
      external_categories_count++;
      tensor_desc = it_mutable_external->second;
    }
    if (external_categories_count > 1) {
      return absl::InvalidArgumentError(
          "Tensors ids from predefined / external_immutable_tensors / "
          "external_mutable_tensors should not intersect.");
    }
    if (external_categories_count == 1) {
      if (!(graph.IsGraphInput(t->id) || graph.IsGraphOutput(t->id))) {
        return absl::InvalidArgumentError(
            "Currently external can be used only for graph inputs/outputs");
      }
      RETURN_IF_ERROR(CheckExternalTensorDescription(gpu_info, tensor_desc,
                                                     shape, data_type));
    } else {
      TensorStorageType storage_type = create_info.storage_type;
      Layout layout = shape.b == 1 ? Layout::HWC : Layout::BHWC;
      const bool can_use_single_texture =
          storage_type == TensorStorageType::TEXTURE_2D ||
          storage_type == TensorStorageType::TEXTURE_3D ||
          storage_type == TensorStorageType::TEXTURE_ARRAY;
      if (shape.c < 4 && can_use_single_texture &&
          TensorDescriptor{data_type, TensorStorageType::SINGLE_TEXTURE_2D,
                           layout}
              .CanCreateTensorWithShape(gpu_info, shape)
              .ok()) {
        storage_type = TensorStorageType::SINGLE_TEXTURE_2D;
      }
      tensor_desc = TensorDescriptor{data_type, storage_type, layout};
      RETURN_IF_ERROR(
          tensor_desc.UpdateToSupportedStorageType(gpu_info, shape));
      if (gpu_info.IsApiMetal() &&
          storage_type == TensorStorageType::TEXTURE_2D) {
        tensor_desc.SetUseBufferForWriteOnlyTexture2d(true);
      }
    }
    tensor_desc.SetBHWCShape(shape);
    tensor_reserver->Add(t->id, tensor_desc);
    max_id = std::max(max_id, t->id);
  }
  tensor_reserver->SetNext(max_id + 1);
  return absl::OkStatus();
}

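// Converts every graph node (or, when special kernels are allowed, a fused
// subgraph of nodes) into one or more GPU operations. Constant nodes are
// uploaded into gpu_model->const_tensors instead of becoming operations.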
absl::Status ConvertOperations(const GpuInfo& gpu_info,
                               const GraphFloat32& graph,
                               const CreateGpuModelInfo& create_info,
                               TensorReserver* tensor_reserver,
                               GpuModel* gpu_model) {
  std::map<ValueId, TensorDescriptor> tensor_descriptors;
  const auto values = graph.values();
  for (auto value : values) {
    tensor_descriptors[value->id] = tensor_reserver->Get(value->id);
  }
  std::set<NodeId> consumed_nodes;
  std::vector<Node*> graph_nodes = graph.nodes();
  std::map<ValueId, int>
      tensor_usages;  // keeps the latest index of the operation that updated
                      // the tensor
  for (const auto& input : gpu_model->input_ids_and_refs) {
    tensor_usages[input.first] = -1;  // mark graph inputs with -1, as if they
                                      // were "updated" before operation 0
  }
  std::vector<SharedWeightsConvDesc> shared_conv_weights;
  std::vector<SharedWeightsConvDesc>* shared_conv_weights_ptr =
      create_info.hints.Check(ModelHints::kReuseConvWeights)
          ? &shared_conv_weights
          : nullptr;
  for (int i = 0; i < graph_nodes.size(); ++i) {
    const Node& node = *graph_nodes[i];
    if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
      continue;
    }
    auto op_type = OperationTypeFromString(node.operation.type);
    if (op_type == OperationType::CONSTANT) {
      auto attr =
          absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
      auto outputs = graph.FindOutputs(node.id);
      gpu_model->const_tensors[outputs[0]->id] =
          tensor_reserver->Get(outputs[0]->id);
      gpu_model->const_tensors[outputs[0]->id].UploadData(attr.tensor);
      continue;
    }
    GPUOperationsSubgraph gpu_subgraph;
    if (create_info.hints.Check(ModelHints::kAllowSpecialKernels) &&
        GPUSubgraphFromGraph(gpu_info, create_info.precision, graph, node.id,
                             tensor_descriptors, &consumed_nodes, &gpu_subgraph)
            .ok()) {
      // Mapping of a subgraph (set of nodes) to GPU operations. Should happen
      // before the straightforward mapping.
    } else {
      // Straightforward mapping of one graph node to GPU operations.
      auto inputs = graph.FindInputs(node.id);
      auto outputs = graph.FindOutputs(node.id);
      // Reorder the input ids and update the temporary tensor_usages map. For
      // better linking, the linking tensor (the one written latest during
      // linear execution) must be in the first position.
      if (IsAssociativeLinkableOp(node, inputs, outputs)) {
        int latest_written_tensor_index = 0;
        int last_usage = tensor_usages[inputs[0]->id];
        for (int j = 1; j < inputs.size(); ++j) {
          if (tensor_usages[inputs[j]->id] > last_usage) {
            last_usage = tensor_usages[inputs[j]->id];
            latest_written_tensor_index = j;
          }
        }
        std::swap(inputs[0], inputs[latest_written_tensor_index]);
      }
      consumed_nodes.insert(node.id);
      OperationDef op_def;
      op_def.precision = create_info.precision;
      for (int j = 0; j < inputs.size(); ++j) {
        op_def.src_tensors.push_back(tensor_reserver->Get(inputs[j]->id));
      }
      for (int j = 0; j < outputs.size(); ++j) {
        op_def.dst_tensors.push_back(tensor_reserver->Get(outputs[j]->id));
      }
      RETURN_IF_ERROR(GPUOperationFromNode(
          gpu_info, op_def, create_info.hints, inputs, outputs, node,
          shared_conv_weights_ptr, &gpu_subgraph));
    }
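    // Subgraph operations may reference tensors created inside the subgraph
    // with negative ids: id -1 means new_tensors[0], -2 means new_tensors[1],
    // and so on. Give those tensors global ids here and keep the remapping.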
    absl::flat_hash_map<int, ValueId> mapping_to_global_ids;
    for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
      const auto& t = gpu_subgraph.new_tensors[j];
      if (!t.GetData().empty()) {  // constant tensor
        auto global_id = tensor_reserver->GetNewId();
        gpu_model->const_tensors[global_id] =
            std::move(gpu_subgraph.new_tensors[j]);
        mapping_to_global_ids[j] = global_id;
      } else {
        auto global_id = tensor_reserver->Add(t);
        mapping_to_global_ids[j] = global_id;
      }
    }
    if (!shared_conv_weights.empty() && !mapping_to_global_ids.empty()) {
      shared_conv_weights.back().RemapIds(mapping_to_global_ids);
    }
    for (auto& gpu_op : gpu_subgraph.operations) {
      GpuNode gpu_node;
      gpu_node.gpu_operation = std::move(gpu_op.operation);
      gpu_node.inputs.resize(gpu_op.input_ids.size());
      for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
        int id = gpu_op.input_ids[j];
        if (id >= 0) {
          gpu_node.inputs[j] = id;
        } else {
          gpu_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      gpu_node.outputs.resize(gpu_op.output_ids.size());
      for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
        int id = gpu_op.output_ids[j];
        if (id >= 0) {
          gpu_node.outputs[j] = id;
          tensor_usages[id] = i;
        } else {
          gpu_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      gpu_node.name = gpu_op.name;
      gpu_model->nodes.push_back(std::move(gpu_node));
    }
  }

  return absl::OkStatus();
}

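// Fuses chains and small diamonds of simple ("linkable") elementwise nodes
// into a single GPUOperation. The four supported topologies (TYPE_0..TYPE_3)
// are diagrammed inline below. A fusion is applied only when the intermediate
// result has exactly one consumer among the remaining nodes.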
absl::Status MergeElementwiseNodes(const GpuInfo& gpu_info,
                                   GpuModel* gpu_model) {
  auto& nodes = gpu_model->nodes;
  for (int elem_root_index = 1; elem_root_index < nodes.size();
       ++elem_root_index) {
    auto& elem_root = nodes[elem_root_index];
    if (!(elem_root.inputs.size() == 1 || elem_root.inputs.size() == 2) ||
        elem_root.outputs.size() != 1 ||
        !elem_root.gpu_operation->IsLinkable()) {
      continue;
    }
    // Key is the elem_root input index, value is the index of the node that
    // produces that input.
    std::map<int, int> prev_nodes;
    for (int j = elem_root_index - 1; j >= 0; --j) {
      for (int k = 0; k < elem_root.inputs.size(); ++k) {
        if (elem_root.inputs[k] == nodes[j].outputs[0]) {
          prev_nodes[k] = j;
          break;
        }
      }
    }
    // TYPE_0
    //    input       input
    //      |           |
    //    elem0         |
    //      |    -->  elem
    //  elem_root       |
    //      |           |
    //    output      output
    if (prev_nodes.size() == 1) {
      if (elem_root.inputs.size() != 1) {
        continue;
      }
      const int prev_first_node_index = prev_nodes[0];
      auto& prev_node = nodes[prev_first_node_index];
      if (prev_node.inputs.size() != 1 || prev_node.outputs.size() != 1 ||
          !prev_node.gpu_operation->IsLinkable()) {
        continue;
      }
      int consumers_count = 0;
      for (const auto& node : nodes) {
        for (const auto& input : node.inputs) {
          if (input == elem_root.inputs[0]) {
            consumers_count++;
          }
        }
      }
      if (consumers_count != 1) {
        continue;
      }
      prev_node.outputs[0] = elem_root.outputs[0];
      prev_node.name += " -> " + elem_root.name;
      RETURN_IF_ERROR(prev_node.gpu_operation->FuseSimpleElemWithSimpleElem(
          gpu_info, elem_root.gpu_operation.get()));
      nodes.erase(nodes.begin() + elem_root_index);
      elem_root_index = prev_first_node_index;
      continue;
    }

    // check TYPE_1/2/3
    if (prev_nodes.size() == 2) {
      if (elem_root.inputs.size() != 2) {
        continue;
      }
      const int prev_first_node_index = prev_nodes[0];
      const int prev_second_node_index = prev_nodes[1];
      auto& prev_first_node = nodes[prev_first_node_index];
      auto& prev_second_node = nodes[prev_second_node_index];

      // check TYPE_1
      // TYPE_1
      //      input           input
      //     /    \             |
      //   elem0   |            |
      //     \    /      -->  elem
      //   elem_root            |
      //       |                |
      //     output           output
      if (prev_first_node.gpu_operation->IsLinkable() &&
          !prev_second_node.gpu_operation->IsLinkable() &&
          prev_second_node.outputs.size() == 1 &&
          prev_first_node.inputs.size() == 1 &&
          prev_first_node.outputs.size() == 1) {
        int first_node_parent_index = -1;
        for (int j = prev_first_node_index - 1; j >= 0; --j) {
          if (nodes[j].outputs[0] == prev_first_node.inputs[0]) {
            first_node_parent_index = j;
            break;
          }
        }
        if (first_node_parent_index == -1 ||
            first_node_parent_index != prev_second_node_index) {
          continue;
        }
        int consumers_count = 0;
        for (const auto& node : nodes) {
          for (const auto& input : node.inputs) {
            if (input == elem_root.inputs[0]) {
              consumers_count++;
            }
          }
        }
        if (consumers_count != 1) {
          continue;
        }

        prev_first_node.outputs[0] = elem_root.outputs[0];
        prev_first_node.name += " -> " + elem_root.name;
        RETURN_IF_ERROR(prev_first_node.gpu_operation
                            ->Fuse2InputElemWithSimpleElemAsFirstInput(
                                gpu_info, elem_root.gpu_operation.get()));
        nodes.erase(nodes.begin() + elem_root_index);
        elem_root_index = prev_first_node_index;
        continue;
      }

      // check TYPE_2
      // TYPE_2
      //      input           input
      //     /    \             |
      //    |    elem0          |
      //     \    /      -->  elem
      //   elem_root            |
      //       |                |
      //     output           output
      if (!prev_first_node.gpu_operation->IsLinkable() &&
          prev_second_node.gpu_operation->IsLinkable() &&
          prev_first_node.outputs.size() == 1 &&
          prev_second_node.inputs.size() == 1 &&
          prev_second_node.outputs.size() == 1) {
        int second_node_parent_index = -1;
        for (int j = prev_second_node_index - 1; j >= 0; --j) {
          if (nodes[j].outputs[0] == prev_second_node.inputs[0]) {
            second_node_parent_index = j;
            break;
          }
        }
        if (second_node_parent_index == -1 ||
            second_node_parent_index != prev_first_node_index) {
          continue;
        }
        int consumers_count = 0;
        for (const auto& node : nodes) {
          for (const auto& input : node.inputs) {
            if (input == elem_root.inputs[1]) {
              consumers_count++;
            }
          }
        }
        if (consumers_count != 1) {
          continue;
        }

        prev_second_node.outputs[0] = elem_root.outputs[0];
        prev_second_node.name += " -> " + elem_root.name;
        RETURN_IF_ERROR(prev_second_node.gpu_operation
                            ->Fuse2InputElemWithSimpleElemAsSecondInput(
                                gpu_info, elem_root.gpu_operation.get()));
        nodes.erase(nodes.begin() + elem_root_index);
        elem_root_index = prev_second_node_index;
        continue;
      }

      // check TYPE_3
      // TYPE_3
      //      input           input
      //     /    \             |
      //  elem0  elem1          |
      //     \    /      -->  elem
      //   elem_root            |
      //       |                |
      //     output           output
      if (prev_first_node.gpu_operation->IsLinkable() &&
          prev_second_node.gpu_operation->IsLinkable() &&
          prev_first_node.inputs.size() == 1 &&
          prev_first_node.outputs.size() == 1 &&
          prev_second_node.inputs.size() == 1 &&
          prev_second_node.outputs.size() == 1) {
        int first_node_parent_index = -1;
        for (int j = prev_first_node_index - 1; j >= 0; --j) {
          if (nodes[j].outputs[0] == prev_first_node.inputs[0]) {
            first_node_parent_index = j;
            break;
          }
        }
        int second_node_parent_index = -1;
        for (int j = prev_second_node_index - 1; j >= 0; --j) {
          if (nodes[j].outputs[0] == prev_second_node.inputs[0]) {
            second_node_parent_index = j;
            break;
          }
        }
        if (first_node_parent_index == -1 || second_node_parent_index == -1 ||
            first_node_parent_index != second_node_parent_index) {
          continue;
        }

        int consumers_count = 0;
        for (const auto& node : nodes) {
          for (const auto& input : node.inputs) {
            if (input == elem_root.inputs[1]) {
              consumers_count++;
            }
          }
        }
        if (consumers_count != 1) {
          continue;
        }

        consumers_count = 0;
        for (const auto& node : nodes) {
          for (const auto& input : node.inputs) {
            if (input == elem_root.inputs[0]) {
              consumers_count++;
            }
          }
        }
        if (consumers_count != 1) {
          continue;
        }

        GPUOperation new_operation;
        RETURN_IF_ERROR(Fuse2InputElemWith2SimpleElem(
            gpu_info, std::move(*prev_first_node.gpu_operation.get()),
            std::move(*prev_second_node.gpu_operation.get()),
            std::move(*elem_root.gpu_operation.get()), &new_operation));
        GpuNode new_node;
        new_node.inputs.push_back(prev_first_node.inputs[0]);
        new_node.outputs.push_back(elem_root.outputs[0]);
        new_node.name = prev_first_node.name + " -> " + prev_second_node.name +
                        " -> " + elem_root.name;
        new_node.gpu_operation =
            std::make_unique<GPUOperation>(std::move(new_operation));

        // prev_first_node_index and prev_second_node_index ordered relative to
        // elem_root inputs.
        // first_prev_node_index and second_prev_node_index ordered relative to
        // nodes.
        int first_prev_node_index =
            std::min(prev_first_node_index, prev_second_node_index);
        int second_prev_node_index =
            std::max(prev_first_node_index, prev_second_node_index);
        nodes.erase(nodes.begin() + elem_root_index);
        nodes.erase(nodes.begin() + second_prev_node_index);
        nodes[first_prev_node_index] = std::move(new_node);
        elem_root_index = first_prev_node_index - 1;
        continue;
      }
    }
  }
  return absl::OkStatus();
}

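// Fuses a linkable elementwise node into the node that produces its first
// input, provided that producer has a single output which is consumed only by
// the elementwise node, and all of the elementwise node's other inputs are
// already computed at that point.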
absl::Status MergeNodes(const GpuInfo& gpu_info, GpuModel* gpu_model) {
  absl::flat_hash_set<ValueId> ready_tensors;
  for (const auto& input : gpu_model->input_ids_and_refs) {
    ready_tensors.insert(input.first);
  }
  auto& nodes = gpu_model->nodes;
  for (int i = 0; i < nodes.size(); ++i) {
    auto& node = nodes[i];
    for (const auto& out_id : node.outputs) {
      ready_tensors.insert(out_id);
    }
    if (node.outputs.size() != 1) {
      continue;
    }
    std::vector<int> next_nodes;
    int link_index = 0;
    for (int j = i + 1; j < nodes.size(); ++j) {
      for (int k = 0; k < nodes[j].inputs.size(); ++k) {
        if (nodes[j].inputs[k] == node.outputs[0]) {
          next_nodes.push_back(j);
          link_index = k;
        }
      }
    }
    if (next_nodes.size() != 1 || link_index != 0) {
      continue;
    }
    auto& linkable_node = nodes[next_nodes[0]];
    if (!linkable_node.gpu_operation->IsLinkable() ||
        linkable_node.outputs.size() != 1 ||
        !IsReady(ready_tensors, linkable_node)) {
      continue;
    }
    RETURN_IF_ERROR(MergeGpuNodes(gpu_info, &linkable_node, &node));
    nodes.erase(nodes.begin() + next_nodes[0]);
    i -= 1;
  }
  return absl::OkStatus();
}

void CopyExternals(const GraphFloat32& graph, GpuModel* gpu_model) {
  const auto inputs = graph.inputs();
  for (const auto& value : inputs) {
    gpu_model->input_ids_and_refs.push_back({value->id, value->tensor.ref});
  }

  const auto variable_inputs = graph.variable_inputs();
  for (const auto& value : variable_inputs) {
    gpu_model->variable_ids_and_refs.push_back({value->id, value->tensor.ref});
  }

  const auto outputs = graph.outputs();
  for (const auto& value : outputs) {
    gpu_model->output_ids_and_refs.push_back({value->id, value->tensor.ref});
  }
}

// Removes tensors that were fused into complex operations.
void RemoveUnusedTensors(GpuModel* gpu_model) {
  absl::flat_hash_set<ValueId> used_tensors;
  for (const auto& node : gpu_model->nodes) {
    for (const auto& id : node.inputs) {
      used_tensors.insert(id);
    }
    for (const auto& id : node.outputs) {
      used_tensors.insert(id);
    }
  }
  for (const auto& inputs : gpu_model->input_ids_and_refs) {
    used_tensors.insert(inputs.first);
  }
  for (const auto& outputs : gpu_model->output_ids_and_refs) {
    used_tensors.insert(outputs.first);
  }
  for (auto it = gpu_model->tensors.begin(); it != gpu_model->tensors.end();) {
    if (used_tensors.find(it->first) == used_tensors.end()) {
      gpu_model->tensors.erase(it++);
    } else {
      ++it;
    }
  }
}

// A serialized model loses the polymorphic properties of its GPUOperations.
// Here we retrieve the information needed for generic execution of
// GPUOperations; specifically, BindArguments and RecalculateGridSize must be
// executed.
absl::Status ResolvePolymorphicArgs(GpuModel* gpu_model) {
  class DummySpatialTensor : public GpuSpatialTensor {
   public:
    DummySpatialTensor() = default;
    explicit DummySpatialTensor(const BHWDC& shape,
                                const TensorDescriptor& tensor_desc)
        : shape_(shape), tensor_desc_(tensor_desc) {}
    ~DummySpatialTensor() override = default;

    int Width() const override { return shape_.w; }
    int Height() const override { return shape_.h; }
    int Depth() const override { return shape_.d; }
    int Channels() const override { return shape_.c; }
    int Slices() const override { return DivideRoundUp(shape_.c, 4); }
    int Batch() const override { return shape_.b; }

    TensorDescriptor GetDescriptor() const override { return tensor_desc_; }

   private:
    BHWDC shape_;
    TensorDescriptor tensor_desc_;
  };

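  // Bind placeholder tensors that carry only shape information; this is
  // sufficient for BindArguments and RecalculateGridSize to do their work
  // without allocating real device memory.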
  for (auto& node : gpu_model->nodes) {
    std::vector<DummySpatialTensor> src_tensors(node.inputs.size());
    for (int i = 0; i < node.inputs.size(); ++i) {
      const auto& tensor_desc = gpu_model->tensors[node.inputs[i]];
      src_tensors[i] =
          DummySpatialTensor(tensor_desc.GetBHWDCShape(), tensor_desc);
      node.gpu_operation->SetSrc(&src_tensors[i], i);
    }
    std::vector<DummySpatialTensor> dst_tensors(node.outputs.size());
    for (int i = 0; i < node.outputs.size(); ++i) {
      const auto& tensor_desc = gpu_model->tensors[node.outputs[i]];
      dst_tensors[i] =
          DummySpatialTensor(tensor_desc.GetBHWDCShape(), tensor_desc);
      node.gpu_operation->SetDst(&dst_tensors[i], i);
    }
    RETURN_IF_ERROR(
        node.gpu_operation->BindArguments(&node.gpu_operation->args_));
    node.gpu_operation->RecalculateGridSize();
  }
  return absl::OkStatus();
}

}  // namespace

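// GraphToGpuModel runs the whole pipeline: reserve tensor descriptors, copy
// the graph's external ids/refs, convert nodes to GPU operations, fuse
// elementwise chains, drop tensors made unused by fusion, assemble the kernel
// code and resolve polymorphic arguments.
//
// A minimal usage sketch (assumes the graph and gpu_info were obtained
// elsewhere, e.g. from the model importer and the device query):
//
//   GraphFloat32 graph = ...;  // already built and verified
//   RETURN_IF_ERROR(RunGraphTransformsForGpuModel(&graph));
//   CreateGpuModelInfo create_info;
//   create_info.precision = CalculationsPrecision::F16;
//   create_info.storage_type = TensorStorageType::TEXTURE_2D;
//   GpuModel gpu_model;
//   RETURN_IF_ERROR(GraphToGpuModel(graph, create_info, gpu_info, &gpu_model));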
absl::Status GraphToGpuModel(const GraphFloat32& graph,
                             const CreateGpuModelInfo& create_info,
                             const GpuInfo& gpu_info, GpuModel* gpu_model) {
  TensorReserver tensor_reserver;
  RETURN_IF_ERROR(
      ReserveGraphTensors(create_info, gpu_info, graph, &tensor_reserver));
  CopyExternals(graph, gpu_model);
  RETURN_IF_ERROR(ConvertOperations(gpu_info, graph, create_info,
                                    &tensor_reserver, gpu_model));
  // MergeElementwiseNodes fuses only elementwise nodes; MergeNodes fuses
  // elementwise nodes into regular nodes.
  RETURN_IF_ERROR(MergeElementwiseNodes(gpu_info, gpu_model));
  RETURN_IF_ERROR(MergeNodes(gpu_info, gpu_model));
  gpu_model->tensors = std::move(tensor_reserver.reservations_);
  RemoveUnusedTensors(gpu_model);

  for (auto& node : gpu_model->nodes) {
    RETURN_IF_ERROR(node.gpu_operation->AssembleCode(gpu_info));
  }

  return ResolvePolymorphicArgs(gpu_model);
}

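// Serializes a GpuModel into the data::GpuModel flatbuffer table. Input and
// output ids and their original tensor refs are stored as parallel vectors.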
flatbuffers::Offset<data::GpuModel> Encode(
    const GpuModel& gpu_model, flatbuffers::FlatBufferBuilder* builder) {
  std::vector<int32_t> in_ids(gpu_model.input_ids_and_refs.size());
  std::vector<int64_t> in_refs(gpu_model.input_ids_and_refs.size());
  for (int i = 0; i < in_ids.size(); ++i) {
    in_ids[i] = gpu_model.input_ids_and_refs[i].first;
    in_refs[i] = gpu_model.input_ids_and_refs[i].second;
  }
  auto in_ids_fb = builder->CreateVector(in_ids);
  auto in_refs_fb = builder->CreateVector(in_refs);

  std::vector<int32_t> out_ids(gpu_model.output_ids_and_refs.size());
  std::vector<int64_t> out_refs(gpu_model.output_ids_and_refs.size());
  for (int i = 0; i < out_ids.size(); ++i) {
    out_ids[i] = gpu_model.output_ids_and_refs[i].first;
    out_refs[i] = gpu_model.output_ids_and_refs[i].second;
  }
  auto out_ids_fb = builder->CreateVector(out_ids);
  auto out_refs_fb = builder->CreateVector(out_refs);

  std::vector<flatbuffers::Offset<data::GpuNode>> nodes_fb;
  for (int i = 0; i < gpu_model.nodes.size(); ++i) {
    auto node_fb = Encode(gpu_model.nodes[i], builder);
    nodes_fb.push_back(node_fb);
  }
  auto nodes_fb_vec = builder->CreateVector(nodes_fb);

  std::vector<flatbuffers::Offset<data::TensorDescWithId>> tensors_fb;
  for (const auto& tensor : gpu_model.tensors) {
    auto tensor_fb = Encode(tensor.second, tensor.first, builder);
    tensors_fb.push_back(tensor_fb);
  }
  auto tensors_fb_vec = builder->CreateVector(tensors_fb);

  std::vector<flatbuffers::Offset<data::TensorDescWithId>> const_tensors_fb;
  for (const auto& tensor : gpu_model.const_tensors) {
    auto tensor_fb = Encode(tensor.second, tensor.first, builder);
    const_tensors_fb.push_back(tensor_fb);
  }
  auto const_tensors_fb_vec = builder->CreateVector(const_tensors_fb);

  std::vector<flatbuffers::Offset<data::PairOfValueIds>>
      variable_ids_and_refs_fb;
  for (auto& pair : gpu_model.variable_ids_and_refs) {
    data::PairOfValueIdsBuilder pair_builder(*builder);
    pair_builder.add_first(pair.first);
    pair_builder.add_second(pair.second);
    variable_ids_and_refs_fb.push_back(pair_builder.Finish());
  }
  auto variable_ids_and_refs_fb_vec =
      builder->CreateVector(variable_ids_and_refs_fb);

  data::GpuModelBuilder gpu_model_builder(*builder);
  gpu_model_builder.add_nodes(nodes_fb_vec);
  gpu_model_builder.add_tensors(tensors_fb_vec);
  gpu_model_builder.add_const_tensors(const_tensors_fb_vec);
  gpu_model_builder.add_input_ids(in_ids_fb);
  gpu_model_builder.add_output_ids(out_ids_fb);
  gpu_model_builder.add_variable_ids_and_refs(variable_ids_and_refs_fb_vec);
  gpu_model_builder.add_input_refs(in_refs_fb);
  gpu_model_builder.add_output_refs(out_refs_fb);
  return gpu_model_builder.Finish();
}

absl::Status Decode(const data::GpuModel* fb_gpu_model, GpuModel* gpu_model) {
  gpu_model->nodes.resize(fb_gpu_model->nodes()->size());
  int counter = 0;
  for (auto node_fb : *fb_gpu_model->nodes()) {
    RETURN_IF_ERROR(Decode(node_fb, &gpu_model->nodes[counter]));
    counter++;
  }

  for (const auto& tensor_fb : *fb_gpu_model->tensors()) {
    TensorDescriptor desc;
    Decode(tensor_fb->desc(), &desc);
    gpu_model->tensors[tensor_fb->id()] = std::move(desc);
  }
  for (const auto& tensor_fb : *fb_gpu_model->const_tensors()) {
    TensorDescriptor desc;
    Decode(tensor_fb->desc(), &desc);
    gpu_model->const_tensors[tensor_fb->id()] = std::move(desc);
  }
  for (int i = 0; i < fb_gpu_model->input_ids()->size(); ++i) {
    gpu_model->input_ids_and_refs.push_back(
        {(*fb_gpu_model->input_ids())[i], (*fb_gpu_model->input_refs())[i]});
  }
  for (int i = 0; i < fb_gpu_model->output_ids()->size(); ++i) {
    gpu_model->output_ids_and_refs.push_back(
        {(*fb_gpu_model->output_ids())[i], (*fb_gpu_model->output_refs())[i]});
  }

  for (auto variable_id : *fb_gpu_model->variable_ids_and_refs()) {
    gpu_model->variable_ids_and_refs.push_back(
        {variable_id->first(), variable_id->second()});
  }
  return absl::OkStatus();
}

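// Applies the graph-level transformations (add_bias, merge_padding, global
// pooling to reduce op) that prepare a GraphFloat32 for GPU conversion.
// Typically run on the graph before GraphToGpuModel converts it.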
absl::Status RunGraphTransformsForGpuModel(GraphFloat32* graph) {
  auto merge_padding_transform = NewMergePaddingWithAdd();
  auto add_bias_transform = NewAddBias();
  auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
  ModelTransformer transformer(graph);
  if (!transformer.Apply("add_bias", add_bias_transform.get())) {
    return absl::InternalError("Invalid add_bias transform");
  }
  if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
    return absl::InternalError("Invalid merge_padding transform");
  }
  if (!transformer.Apply("global pooling to mean",
                         pooling_to_reduce_op.get())) {
    return absl::InternalError("Invalid global pooling to mean transform");
  }
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite