/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_internal.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace gpu {
namespace {

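// Reads fully-connected weights stored as a 2D [out, in] tensor and relabels
// them as OHWI with h = w = 1, which is the layout the GPU backend expects
// for dense weights.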
absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
                                         int bias_tensor_id,
                                         ObjectReader* reader,
                                         FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.i = weights.shape.w;
  reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

// Creates a simple node that holds a tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
  ConstTensorAttributes attr;
  attr.tensor = std::move(t);
  Node* node = graph->NewNode();
  node->operation.attributes = attr;
  node->operation.type = ToString(OperationType::CONSTANT);
  *value = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
  // Keep data inside this tensor.
  (*value)->tensor.ref = attr.tensor.id;
  (*value)->tensor.type = attr.tensor.kType;
  (*value)->tensor.shape = attr.tensor.shape;
  return absl::OkStatus();
}

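// Wires up the runtime input(s) of a binary elementwise op and loads the
// constant operand, if any, as a scalar, a per-channel Linear tensor, or a
// full HWC tensor, depending on its shape.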
absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;

  // Determine runtime/constant tensors.
  const TfLiteTensor* input0 = reader->GetInputTensor(0);
  if (!input0) {
    return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
                                      opname);
  }
  const TfLiteTensor* input1 = reader->GetInputTensor(1);
  if (!input1) {
    return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
                                      opname);
  }
  const bool constant_tensor0 = IsConstantTensor(input0);
  const bool constant_tensor1 = IsConstantTensor(input1);
  if (constant_tensor0 && constant_tensor1) {
    return absl::InvalidArgumentError("No runtime input tensors for " + opname);
  }
  const bool runtime_tensor0 = !constant_tensor0;
  const bool runtime_tensor1 = !constant_tensor1;

  if (runtime_tensor0 && runtime_tensor1) {
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
  } else {
    int runtime_tensor = 0;
    int constant_tensor = 1;
    TfLiteIntArray* constant_dims = input1->dims;
    if (constant_tensor0 && runtime_tensor1) {
      runtime_tensor = 1;
      constant_tensor = 0;
      constant_dims = input0->dims;
    }
    RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
    if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
      Tensor<Scalar, DataType::FLOAT32> tensor;
      RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
      *tensor_or_scalar = tensor.data[0];
    } else {
      if (CheckIfLinearConvertible(constant_dims).ok()) {
        Tensor<Linear, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      } else if (constant_dims->size == 2) {
        Tensor<HW, DataType::FLOAT32> tensor_hw;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor_hw));
        Tensor<HWC, DataType::FLOAT32> tensor;
        tensor.id = tensor_hw.id;
        tensor.shape = HWC(1, tensor_hw.shape.h, tensor_hw.shape.w);
        tensor.data = tensor_hw.data;
        *tensor_or_scalar = std::move(tensor);
      } else {
        Tensor<HWC, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      }
    }
  }
  return absl::OkStatus();
}

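// Extracts the fused activation from the op's builtin params (MUL, ADD, SUB,
// or DIV) and, when one is set, fuses it into the graph after `node`.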
absl::Status MaybeFuseActivationForElementwiseNode(
    OperationType operation_type, const TfLiteNode* tflite_node,
    GraphFloat32* graph, Node* node) {
  TfLiteFusedActivation activation = kTfLiteActNone;
  switch (operation_type) {
    case OperationType::MUL: {
      const TfLiteMulParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::ADD: {
      const TfLiteAddParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::SUB: {
      const TfLiteSubParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::DIV: {
      const TfLiteDivParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    default:
      // No activation expected.
      activation = kTfLiteActNone;
  }

  if (activation) {
    return MaybeFuseActivation(activation, graph, node);
  }
  return absl::OkStatus();
}

struct TensorInfo {
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> producers;
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> consumers;
};

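// Scans the execution plan and records every node that produces or consumes
// the tensor with the given id.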
absl::Status GetTensorInfo(const TfLiteContext* context, int tensor_id,
                           TensorInfo* result) {
  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(const_cast<TfLiteContext*>(context),
                                &execution_plan) != kTfLiteOk) {
    return absl::UnavailableError("Unable to get graph execution plan.");
  }
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(const_cast<TfLiteContext*>(context),
                                        node_index, &node,
                                        &registration) != kTfLiteOk) {
      return absl::UnavailableError(
          "Unable to get node and registration for node.");
    }
    for (int j = 0; j < node->inputs->size; ++j) {
      if (tensor_id == node->inputs->data[j]) {
        result->consumers.push_back({node, registration});
      }
    }
    for (int j = 0; j < node->outputs->size; ++j) {
      if (tensor_id == node->outputs->data[j]) {
        result->producers.push_back({node, registration});
      }
    }
  }
  return absl::OkStatus();
}

bool IsLogicalCode(int32_t builtin_code) {
  return builtin_code == kTfLiteBuiltinGreater ||
         builtin_code == kTfLiteBuiltinGreaterEqual ||
         builtin_code == kTfLiteBuiltinLess ||
         builtin_code == kTfLiteBuiltinLessEqual ||
         builtin_code == kTfLiteBuiltinEqual ||
         builtin_code == kTfLiteBuiltinNotEqual;
}

bool IsLogicalOp(tflite::gpu::OperationType op_type) {
  return op_type == tflite::gpu::OperationType::GREATER ||
         op_type == tflite::gpu::OperationType::GREATER_EQUAL ||
         op_type == tflite::gpu::OperationType::LESS ||
         op_type == tflite::gpu::OperationType::LESS_EQUAL ||
         op_type == tflite::gpu::OperationType::EQUAL ||
         op_type == tflite::gpu::OperationType::NOT_EQUAL;
}

class BatchedMatMulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    if (reader->GetNumberOfRuntimeInputs() == 2) {
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::BATCHED_MATMUL);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      return absl::OkStatus();
    } else if (reader->GetNumberOfRuntimeInputs() == 1) {
      // The second input is constant; replace with Convolution2D.
      const TfLiteTensor* second_input = reader->GetInputTensor(1);
      if (!IsConstantTensor(second_input) || second_input->dims->size != 2) {
        // The first input must be a runtime tensor and the second a 2D
        // constant tensor.
        return absl::UnavailableError("Not supported batched mat mul case");
      }
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddOutputs(node));

      Tensor<HW, DataType::FLOAT32> weights;
      RETURN_IF_ERROR(reader->ReadTensor(1, &weights));
      Convolution2DAttributes attr;
      attr.weights.data.resize(weights.shape.w * weights.shape.h);
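      // The constant [K, N] matrix is transposed into [N, K] below so it can
      // serve as 1x1 convolution weights with O = N and I = K.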
      for (int i = 0; i < weights.shape.w; ++i) {
        for (int j = 0; j < weights.shape.h; ++j) {
          attr.weights.data[i * weights.shape.h + j] =
              weights.data[j * weights.shape.w + i];
        }
      }
      attr.weights.id = weights.id;
      attr.weights.shape.h = 1;
      attr.weights.shape.w = 1;
      attr.weights.shape.o = weights.shape.w;
      attr.weights.shape.i = weights.shape.h;
      attr.strides = HW(1, 1);
      attr.dilations = HW(1, 1);
      attr.padding.appended = HW(0, 0);
      attr.padding.prepended = HW(0, 0);
      node->operation.attributes = std::move(attr);
      return absl::OkStatus();
    } else {
      return absl::UnavailableError("Not supported batched mat mul case");
    }
    return absl::OkStatus();
  }
};

class CastOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    TfLiteType src_type = context->tensors[tflite_node->inputs->data[0]].type;
    TfLiteType dst_type = context->tensors[tflite_node->outputs->data[0]].type;
    if (src_type == kTfLiteBool &&
        (dst_type == kTfLiteFloat16 || dst_type == kTfLiteFloat32)) {
      // Check that we have the following sequence:
      //   logical_op->bool_tensor->CAST->float_tensor.
      TensorInfo input_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->inputs->data[0],
                                    &input_tensor_info));
      if (input_tensor_info.producers.size() != 1 ||
          input_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported cast case");
      }
      // If the cast is an output, do the cast to float on CPU.
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError(
            "Cast from bool not supported for outputs");
      }
      if (IsLogicalCode(input_tensor_info.producers[0].second->builtin_code)) {
        return absl::OkStatus();
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CAST);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

class ClampOperationsParser : public TFLiteOperationParser {
 public:
  explicit ClampOperationsParser(float clamp_a, float clamp_b)
      : clamp_a_(clamp_a), clamp_b_(clamp_b) {}
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // clamp(v, a, b) = clamp(v - a, 0.0, b - a) + a;
    // We replace clamp(...) with a sequence of elementwise ops:
    // subtraction -> regular ReLU with alpha = 0.0 -> addition.
    // node_sub = v0 = v - a // add op (add -a)
    // node_relu = v1 = clamp(v0, 0.0, clip); // relu op alpha = 0.0,
    // clip = b - a;
    // node_add = v2 = v1 + a // add op (add a)
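    // E.g. for clamp(v, -1.0, 1.0): v0 = v + 1, v1 = clamp(v0, 0.0, 2.0),
    // v2 = v1 - 1, which equals clamp(v, -1.0, 1.0) for every v.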
    Node* node_sub = graph->NewNode();
    Node* node_relu = graph->NewNode();
    Node* node_add = graph->NewNode();

    ElementwiseAttributes sub_attr;
    sub_attr.param = -clamp_a_;
    node_sub->operation.type = ToString(OperationType::ADD);
    node_sub->operation.attributes = std::move(sub_attr);

    ReLUAttributes relu_attr;
    relu_attr.alpha = 0.0f;
    relu_attr.clip = clamp_b_ - clamp_a_;
    node_relu->operation.type = ToString(OperationType::RELU);
    node_relu->operation.attributes = relu_attr;

    ElementwiseAttributes add_attr;
    add_attr.param = clamp_a_;
    node_add->operation.type = ToString(OperationType::ADD);
    node_add->operation.attributes = std::move(add_attr);

    RETURN_IF_ERROR(reader->AddInput(node_sub, 0));
    auto input = graph->FindInputs(node_sub->id)[0];

    Value* v0 = graph->NewValue();
    Value* v1 = graph->NewValue();
    v0->tensor.type = input->tensor.type;
    v0->tensor.shape = input->tensor.shape;
    v1->tensor.type = input->tensor.type;
    v1->tensor.shape = input->tensor.shape;

    RETURN_IF_ERROR(graph->SetProducer(node_sub->id, v0->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_relu->id, v0->id));
    RETURN_IF_ERROR(graph->SetProducer(node_relu->id, v1->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_add->id, v1->id));

    RETURN_IF_ERROR(reader->AddOutputs(node_add));
    return absl::OkStatus();
  }

 private:
  const float clamp_a_, clamp_b_;
};

class ConcatenationOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));

    // TODO(eignasheva): add proper tensor availability checking
    // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
    //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
    // }
    // TODO(eignasheva): add axis checking.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    ConcatAttributes attr;
    // Read inputs first to make sure const nodes are added to the graph
    // before the concat node, to ensure topological order.
    std::vector<const Value*> inputs;
    for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* value;
      const auto status = reader->ReadValue(idx, &value);
      if (status.ok()) {
        inputs.push_back(value);
      } else {
        TensorFloat32 tensor;
        RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
        Value* value;
        RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
        inputs.push_back(value);
      }
    }

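    // If the same value appears more than once among the inputs, route each
    // repeat through an explicit COPY node so every consumed value is
    // distinct.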
    for (int i = 0; i < inputs.size(); ++i) {
      for (int j = 0; j < i; ++j) {
        if (inputs[i] == inputs[j]) {
          Node* node_copy = graph->NewNode();
          node_copy->operation.type = ToString(OperationType::COPY);
          RETURN_IF_ERROR(graph->AddConsumer(node_copy->id, inputs[j]->id));
          Value* copy_value = graph->NewValue();
          copy_value->tensor.type = inputs[j]->tensor.type;
          copy_value->tensor.shape = inputs[j]->tensor.shape;
          RETURN_IF_ERROR(graph->SetProducer(node_copy->id, copy_value->id));
          inputs[i] = copy_value;
          break;
        }
      }
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (int i = 0; i < inputs.size(); ++i) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, inputs[i]->id));
    }

    std::vector<BHWC> input_shapes;
    for (auto input : graph->FindInputs(node->id)) {
      input_shapes.push_back(input->tensor.shape);
    }
    RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));

    // Guess axis.
    BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    for (auto input : graph->FindInputs(node->id)) {
      if (input->tensor.shape.h != output_shape.h) {
        attr.axis = Axis::HEIGHT;
        break;
      }
      if (input->tensor.shape.w != output_shape.w) {
        attr.axis = Axis::WIDTH;
        break;
      }
      if (input->tensor.shape.c != output_shape.c) {
        attr.axis = Axis::CHANNELS;
        break;
      }
    }
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
    *axis = Axis::BATCH;
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::HEIGHT;
        break;
      }
    }
    if (*axis == Axis::BATCH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::WIDTH;
        break;
      }
    }
    if (*axis == Axis::HEIGHT) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::CHANNELS;
        break;
      }
    }
    if (*axis == Axis::WIDTH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].h != input_shapes[i].h) {
        return absl::UnimplementedError(
            "Can concatenate tensors only by batch, height, width, or "
            "channels.");
      }
    }
    return absl::OkStatus();
  }
};

class Conv2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    Convolution2DAttributes attr;
    RETURN_IF_ERROR(ReadAttributes(tflite_node, tf_options, reader, &attr));

    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      // Weights are the second runtime input.
      const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
      const TfLiteTensor* weights_tensor = reader->GetInputTensor(1);
      BHWC src_shape, weights_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
      RETURN_IF_ERROR(ExtractTensorShape(*weights_tensor, &weights_shape));
      if (src_shape.c != weights_shape.c) {
        return absl::InternalError(
            "No support of CONVOLUTION_2D with runtime grouped weights.");
      }

      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      node->operation.attributes = std::move(attr);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
      return absl::OkStatus();
    } else {
      // Weights are constants.
      const int src_group_size = attr.weights.shape.i;
      const int dst_group_size = attr.weights.shape.o / attr.groups;
      const bool supported_grouped_conv =
          src_group_size % 4 == 0 && dst_group_size % 4 == 0;
      if (attr.groups != 1 && !supported_grouped_conv) {
        // Not a supported case; replace with a sequence of regular
        // convolutions:
        return ResolveGroupedConvolution(attr, tf_options, reader, graph);
      } else {
        Node* node = graph->NewNode();
        node->operation.type = ToString(OperationType::CONVOLUTION_2D);
        node->operation.attributes = std::move(attr);
        RETURN_IF_ERROR(reader->AddInput(node, 0));
        RETURN_IF_ERROR(reader->AddOutputs(node));
        RETURN_IF_ERROR(
            MaybeFuseActivation(tf_options->activation, graph, node));
        return absl::OkStatus();
      }
    }
  }

 private:
  absl::Status ReadAttributes(const TfLiteNode* tflite_node,
                              const TfLiteConvParams* tf_options,
                              ObjectReader* reader,
                              Convolution2DAttributes* attr) {
    const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
    BHWC src_shape;
    RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 1) {
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr->weights));
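      // With constant weights, the group count is inferred from the ratio of
      // input channels to the weights' input depth (1 for non-grouped convs).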
      attr->groups = src_shape.c / attr->weights.shape.i;
    } else {
      const TfLiteTensor* weights_tensor = reader->GetInputTensor(1);
      if (!weights_tensor) {
        return absl::InternalError("Expected second runtime tensor.");
      }
      BHWC weights_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*weights_tensor, &weights_shape));
      attr->weights.shape = OHWI(weights_shape.b, weights_shape.h,
                                 weights_shape.w, weights_shape.c);
      attr->groups = 1;
    }
    reader->ReadTensor(2, &attr->bias).IgnoreError();  // bias is optional
    attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr->dilations = HW(tf_options->dilation_height_factor,
                         tf_options->dilation_width_factor);
    UpdatePadding(tf_options->padding, src_shape, attr);
    return absl::OkStatus();
  }

  // Replaces a single grouped convolution (N = group count) with this
  // sequence:
  //  split the input into N tensors along the channels dim
  //  N regular convs
  //  concat the N tensors into 1 output along the channels dim
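  // E.g. with groups = 2 and OHWI weights [8, kh, kw, ci], this produces one
  // SPLIT, two convolutions that each take a [4, kh, kw, ci] weight slice,
  // and one CONCAT along channels.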
  absl::Status ResolveGroupedConvolution(const Convolution2DAttributes& attr,
                                         const TfLiteConvParams* tf_options,
                                         ObjectReader* reader,
                                         GraphFloat32* graph) {
    const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
    const TfLiteTensor* dst_tensor = reader->GetOutputTensor(0);
    BHWC src_shape, dst_shape;
    RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
    RETURN_IF_ERROR(ExtractTensorShape(*dst_tensor, &dst_shape));

    DataType src_type = DataType::FLOAT32;
    if (src_tensor->type == kTfLiteFloat16) {
      src_type = DataType::FLOAT16;
    }
    DataType dst_type = DataType::FLOAT32;
    if (dst_tensor->type == kTfLiteFloat16) {
      dst_type = DataType::FLOAT16;
    }

    const int src_group_size = attr.weights.shape.i;
    const int dst_group_size = attr.weights.shape.o / attr.groups;

    Node* split_node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(split_node, 0));
    {
      SplitAttributes split_attr;
      split_attr.axis = Axis::CHANNELS;
      split_node->operation.type = ToString(OperationType::SPLIT);
      split_node->operation.attributes = split_attr;
    }

    std::vector<Node*> conv_nodes(attr.groups);
    std::vector<Value*> conv_src(attr.groups);
    std::vector<Value*> conv_dst(attr.groups);
    for (int i = 0; i < attr.groups; ++i) {
      conv_nodes[i] = graph->NewNode();
      conv_src[i] = graph->NewValue();
      conv_dst[i] = graph->NewValue();
      conv_src[i]->tensor.shape = src_shape;
      conv_src[i]->tensor.type = src_type;
      conv_src[i]->tensor.shape.c = src_group_size;
      conv_dst[i]->tensor.shape = dst_shape;
      conv_dst[i]->tensor.type = dst_type;
      conv_dst[i]->tensor.shape.c = dst_group_size;
      Convolution2DAttributes conv_attr;
      conv_attr = attr;
      conv_attr.groups = 1;
      conv_attr.weights.id = -1;
      conv_attr.weights.shape.o = dst_group_size;
      conv_attr.weights.data.resize(
          conv_attr.weights.shape.DimensionsProduct());
      for (int out_i = 0; out_i < dst_group_size; ++out_i) {
        for (int in_i = 0; in_i < src_group_size; ++in_i) {
          for (int ky = 0; ky < attr.weights.shape.h; ++ky) {
            for (int kx = 0; kx < attr.weights.shape.w; ++kx) {
              const int src_index = attr.weights.shape.LinearIndex(
                  {{i * dst_group_size + out_i, ky, kx, in_i}});
              const int dst_index =
                  conv_attr.weights.shape.LinearIndex({{out_i, ky, kx, in_i}});
              conv_attr.weights.data[dst_index] = attr.weights.data[src_index];
            }
          }
        }
      }
      conv_attr.bias.shape.v = dst_group_size;
      conv_attr.bias.data.resize(conv_attr.bias.shape.DimensionsProduct());
      for (int out_i = 0; out_i < dst_group_size; ++out_i) {
        if (i * dst_group_size + out_i < attr.bias.data.size()) {
          conv_attr.bias.data[out_i] =
              attr.bias.data[i * dst_group_size + out_i];
        } else {
          conv_attr.bias.data[out_i] = 0.0f;
        }
      }
      conv_nodes[i]->operation.type = ToString(OperationType::CONVOLUTION_2D);
      conv_nodes[i]->operation.attributes = conv_attr;

      RETURN_IF_ERROR(graph->SetProducer(split_node->id, conv_src[i]->id));
      RETURN_IF_ERROR(graph->AddConsumer(conv_nodes[i]->id, conv_src[i]->id));
      RETURN_IF_ERROR(graph->SetProducer(conv_nodes[i]->id, conv_dst[i]->id));
    }

    Node* concat_node = graph->NewNode();
    {
      ConcatAttributes concat_attr;
      concat_attr.axis = Axis::CHANNELS;
      concat_node->operation.type = ToString(OperationType::CONCAT);
      concat_node->operation.attributes = concat_attr;
    }
    for (int i = 0; i < attr.groups; ++i) {
      RETURN_IF_ERROR(graph->AddConsumer(concat_node->id, conv_dst[i]->id));
    }
    RETURN_IF_ERROR(reader->AddOutputs(concat_node));
    RETURN_IF_ERROR(
        MaybeFuseActivation(tf_options->activation, graph, concat_node));
    return absl::OkStatus();
  }
};

class CumsumOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    CumsumAttributes attr;
    const TfLiteTensor* input_tensor = reader->GetInputTensor(0);
    const TfLiteTensor* axis_tensor = reader->GetInputTensor(1);
    const TfLiteIntArray* shape = input_tensor->dims;
    const int tflite_axis = GetTensorData<int32_t>(axis_tensor)[0];
    const Axis axes[4] = {Axis::BATCH, Axis::WIDTH, Axis::HEIGHT,
                          Axis::CHANNELS};
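    // Map the TFLite axis index onto the 4D representation, offsetting for
    // any leading dimensions a rank < 4 tensor is missing.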
    attr.axis = axes[tflite_axis + 4 - shape->size];
    node->operation.type = ToString(OperationType::CUMSUM);
    Tensor<BHWC, DataType::FLOAT32> inputs;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

// Doesn't have a kernel implementation.
class DensifyOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DENSIFY);
    const TfLiteTensor* const_tensor = reader->GetInputTensor(0);
    if (!const_tensor->sparsity) {
      return absl::InvalidArgumentError("Input tensor must be sparse.");
    }
    TensorFloat32 sparse_tensor;
    RETURN_IF_ERROR(reader->ReadTensor(0, &sparse_tensor));
    DensifyAttributes attributes;
    attributes.tensor = std::move(sparse_tensor);
    node->operation.attributes = attributes;
    return reader->AddOutputs(node);
  }
};

class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    DepthwiseConvolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
      attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
                                weights_shape.w, weights_shape.c);
    } else {  // runtime_inputs == 1;
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
                        std::max(1, tf_options->dilation_width_factor));
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    const int depth_multiplier = tf_options->depth_multiplier;
    if (depth_multiplier != 1) {
      const TfLiteTensor* input = reader->GetInputTensor(0);
      const TfLiteTensor* filter = reader->GetInputTensor(1);
      const TfLiteTensor* output = reader->GetOutputTensor(0);
      TransposeWeights(input, filter, output, depth_multiplier, &attr);
    }
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }

 private:
  // TFLite CPU stores weights as:
  //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
  // TFLite GPU stores weights as:
  //   [depth_multiplier, kernel_height, kernel_width, input_depth]
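  // The copy loop below re-indexes the flat [H, W, O]-ordered data so that
  // each of the O output channels becomes a contiguous H * W plane.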
  static void TransposeWeights(const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* output, int depth_multiplier,
                               DepthwiseConvolution2DAttributes* attr) {
    const int input_depth = input->dims->data[3];
    const int filter_height = filter->dims->data[1];
    const int filter_width = filter->dims->data[2];
    const int output_depth = output->dims->data[3];
    Tensor<OHWI, DataType::FLOAT32> weights;
    weights.id = attr->weights.id;
    weights.shape =
        OHWI(output_depth, filter_height, filter_width, input_depth);
    weights.data.resize(weights.shape.DimensionsProduct());
    float* dst = &weights.data[0];
    for (int j = 0; j < output_depth; ++j) {
      const float* src = attr->weights.data.data() + j;
      for (int i = 0; i < filter_height * filter_width; ++i) {
        *dst = *src;
        dst++;
        src += output_depth;
      }
    }
    attr->weights = std::move(weights);
  }
};

class DepthToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const TfLiteDepthToSpaceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    SpaceToDepthAttributes attr;
    attr.block_size = tf_options->block_size;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class DequantizeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
    // with floating-point versions of the original tensors.
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 0) {
      // Constant input; can be dequantized here.
      ConstTensorAttributes attr;
      RETURN_IF_ERROR(reader->ReadTensor(0, &attr.tensor));
      Node* node = graph->NewNode();
      node->operation.attributes = attr;
      node->operation.type = ToString(OperationType::CONSTANT);
      return reader->AddOutputs(node);
    }
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    // Non-constant dequantization.
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Quantization attributes should already be present in the input tensor.
    auto input_value = graph->FindInputs(node->id)[0];
    if (!input_value->quant_params) {
      if (runtime_inputs == 1) {
        // This DEQUANTIZE op is preceded by a DENSIFY op and doesn't have any
        // quantization params. The DEQUANTIZE op will be removed from the
        // graph later, in the `MergeDensify` graph transformation.
        return absl::OkStatus();
      }
      return absl::InvalidArgumentError(
          "Encountered Dequantize input with no quant params");
    }
    QuantizeAndDequantizeAttributes attr;
    attr.min = input_value->quant_params.value().min;
    attr.max = input_value->quant_params.value().max;
    attr.scale = input_value->quant_params.value().scale;

    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ElementwiseOperationParser : public TFLiteOperationParser {
 public:
  explicit ElementwiseOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    const int kMaxSupportedOpVersion =
        operation_type_ == OperationType::MUL ? 3 : 2;
    RETURN_IF_ERROR(
        CheckMaxSupportedOpVersion(registration, kMaxSupportedOpVersion));
    if (IsLogicalOp(operation_type_)) {
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.producers.size() != 1 ||
          output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported logical op case");
      }
      const auto& next_node = output_tensor_info.consumers[0];
      TfLiteType dst_type =
          context->tensors[next_node.first->outputs->data[0]].type;
      if (next_node.second->builtin_code == kTfLiteBuiltinCast &&
          (dst_type == kTfLiteFloat16 || dst_type == kTfLiteFloat32)) {
        return absl::OkStatus();
      } else {
        return absl::UnimplementedError("Not supported logical op case.");
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);
    if (operation_type_ == OperationType::ADD) {
      ElementwiseAttributes attr;
      node->operation.attributes = std::move(attr);
    }

    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/0,
                                                        /*outputs=*/1));

      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else if (IsTwoArgumentOperation() &&
               reader
                   ->VerifyInputsConstsOutputs(tflite_node,
                                               /*runtime_inputs=*/2,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1)
                   .ok()) {
      if (tflite_node->inputs->size != 2) {
        return absl::InvalidArgumentError("Applies only two input tensors");
      }
      const TfLiteTensor* input0 = reader->GetInputTensor(0);
      const TfLiteTensor* input1 = reader->GetInputTensor(1);

      // TODO(b/166831113): Support the same inputs for operations.
      if (input0 == input1) {
        if (operation_type_ == OperationType::MUL) {
          // Replace MUL(A, A) with SQUARE(A).
          node->operation.type = ToString(OperationType::SQUARE);
          RETURN_IF_ERROR(reader->AddInput(node, 0));
        } else if (operation_type_ == OperationType::ADD) {
          // Replace ADD(A, A) with MUL(A, 2.0).
          node->operation.type = ToString(OperationType::MUL);
          ElementwiseAttributes attr;
          attr.param = 2.0f;
          node->operation.attributes = std::move(attr);
          RETURN_IF_ERROR(reader->AddInput(node, 0));
        } else {
          return absl::UnimplementedError(
              "No support of few identical inputs in the same operation.");
        }
      } else {
        int input_tensor0 = 0;
        int input_tensor1 = 1;
        if (operation_type_ == OperationType::MUL ||
            operation_type_ == OperationType::ADD) {
          // The "larger" input tensor must be bound to the 1st input and the
          // "smaller" input tensor must be bound to the 2nd input.
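          // E.g. for MUL(bias[1x1xC], tensor[HxWxC]) the full-size tensor is
          // moved to input 0 and the broadcastable one to input 1.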
1065           BHWC shape0;
1066           RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
1067           BHWC shape1;
1068           RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
1069           if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
1070               shape0.c == shape1.c) {
1071             input_tensor0 = 1;
1072             input_tensor1 = 0;
1073           }
1074         }
1075 
1076         RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
1077         RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
1078       }
1079     } else if (IsTwoArgumentOperationWithConst()) {
1080       RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
1081                                                         /*runtime_inputs=*/1,
1082                                                         /*const_inputs=*/1,
1083                                                         /*outputs=*/1));
1084       ElementwiseAttributes attr;
1085       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
1086       attr.runtime_tensor_is_second =
1087           IsConstantTensor(reader->GetInputTensor(0));
1088       node->operation.attributes = std::move(attr);
1089     } else {
1090       return absl::InvalidArgumentError("Incorrect operation type passed");
1091     }
1092 
1093     RETURN_IF_ERROR(reader->AddOutputs(node));
1094     return MaybeFuseActivationForElementwiseNode(operation_type_, tflite_node,
1095                                                  graph, node);
1096   }
1097 
1098  private:
GetActivation(const TfLiteNode * tflite_node,TfLiteFusedActivation * activation) const1099   absl::Status GetActivation(const TfLiteNode* tflite_node,
1100                              TfLiteFusedActivation* activation) const {
1101     if (operation_type_ == OperationType::DIV) {
1102       const TfLiteDivParams* tf_options;
1103       auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1104       *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
1105       return absl::OkStatus();
1106     }
1107     if (operation_type_ == OperationType::SUB) {
1108       const TfLiteSubParams* tf_options;
1109       auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1110       *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
1111       return absl::OkStatus();
1112     }
1113 
1114     // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
1115     // TfLiteXxxParams.activation.
1116     *activation = kTfLiteActNone;
1117     return absl::OkStatus();
1118   }
1119 
IsOneArgumentOperation() const1120   bool IsOneArgumentOperation() const {
1121     switch (operation_type_) {
1122       case OperationType::ABS:
1123       case OperationType::COPY:
1124       case OperationType::COS:
1125       case OperationType::ELU:
1126       case OperationType::EXP:
1127       case OperationType::FLOOR:
1128       case OperationType::LOG:
1129       case OperationType::NEG:
1130       case OperationType::RSQRT:
1131       case OperationType::SIGMOID:
1132       case OperationType::SIN:
1133       case OperationType::SQRT:
1134       case OperationType::SQUARE:
1135       case OperationType::TANH:
1136         return true;
1137       default:
1138         return false;
1139     }
1140   }
1141 
IsTwoArgumentOperation() const1142   bool IsTwoArgumentOperation() const {
1143     switch (operation_type_) {
1144       case OperationType::ADD:
1145       case OperationType::DIV:
1146       case OperationType::EQUAL:
1147       case OperationType::FLOOR_DIV:
1148       case OperationType::FLOOR_MOD:
1149       case OperationType::GREATER:
1150       case OperationType::GREATER_EQUAL:
1151       case OperationType::LESS:
1152       case OperationType::LESS_EQUAL:
1153       case OperationType::MAXIMUM:
1154       case OperationType::MINIMUM:
1155       case OperationType::MUL:
1156       case OperationType::NOT_EQUAL:
1157       case OperationType::POW:
1158       case OperationType::SQUARED_DIFF:
1159       case OperationType::SUB:
1160         return true;
1161       default:
1162         return false;
1163     }
1164   }
1165 
IsTwoArgumentOperationWithConst() const1166   bool IsTwoArgumentOperationWithConst() const {
1167     switch (operation_type_) {
1168       case OperationType::ADD:
1169       case OperationType::DIV:
1170       case OperationType::EQUAL:
1171       case OperationType::FLOOR_DIV:
1172       case OperationType::FLOOR_MOD:
1173       case OperationType::GREATER:
1174       case OperationType::GREATER_EQUAL:
1175       case OperationType::LESS:
1176       case OperationType::LESS_EQUAL:
1177       case OperationType::MAXIMUM:
1178       case OperationType::MINIMUM:
1179       case OperationType::MUL:
1180       case OperationType::NOT_EQUAL:
1181       case OperationType::POW:
1182       case OperationType::SQUARED_DIFF:
1183       case OperationType::SUB:
1184         return true;
1185       default:
1186         return false;
1187     }
1188   }
1189 
1190   OperationType operation_type_;
1191 };
1192 
1193 class FullyConnectedOperationParser : public TFLiteOperationParser {
1194  public:
1195   absl::Status IsSupported(const TfLiteContext* context,
1196                            const TfLiteNode* tflite_node,
1197                            const TfLiteRegistration* registration) final {
1198     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
1199     // TODO(eignasheva): check input shape
1200     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1201   }
1202 
1203   absl::Status Parse(const TfLiteNode* tflite_node,
1204                      const TfLiteRegistration* registration,
1205                      GraphFloat32* graph, ObjectReader* reader) final {
1206     const TfLiteFullyConnectedParams* tf_options;
1207     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1208 
1209     if (reader->GetNumberOfRuntimeInputs() == 2) {
1210       // Create Convolution2D, as it supports runtime weights.
1211       Node* node = graph->NewNode();
1212       node->operation.type = ToString(OperationType::CONVOLUTION_2D);
1213       RETURN_IF_ERROR(reader->AddInput(node, 0));
1214       RETURN_IF_ERROR(reader->AddInput(node, 1));
1215 
1216       const TfLiteTensor* input_tensor = reader->GetInputTensor(0);
1217       BHWC input_shape;
1218       RETURN_IF_ERROR(ExtractTensorShape(*input_tensor, &input_shape));
1219       const TfLiteTensor* input2_tensor = reader->GetInputTensor(1);
1220       BHWC input2_shape;
1221       RETURN_IF_ERROR(ExtractTensorShape(*input2_tensor, &input2_shape));
1222       const TfLiteTensor* output_tensor = reader->GetOutputTensor(0);
1223       BHWC output_shape;
1224       RETURN_IF_ERROR(ExtractTensorShape(*output_tensor, &output_shape));
1225       BHWC output_ref_shape = input_shape;
1226       output_ref_shape.c = input2_shape.b;
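           // Illustration (assumed shapes): with input BHWC(1, 1, 1, 256) and
           // runtime weights read as BHWC(128, 1, 1, 256), the reference
           // output shape is BHWC(1, 1, 1, 128); a Reshape is inserted below
           // only when the model's declared output shape differs from it.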
1227       if (output_ref_shape != output_shape) {
1228         Value* copy_value = graph->NewValue();
1229         auto input_value = graph->FindInputs(node->id)[0];
1230         copy_value->tensor.type = input_value->tensor.type;
1231         copy_value->tensor.shape = output_ref_shape;
1232         Node* node_reshape = graph->NewNode();
1233         node_reshape->operation.type = ToString(OperationType::RESHAPE);
1234         ReshapeAttributes reshape_attr;
1235         reshape_attr.new_shape = output_shape;
1236         node_reshape->operation.attributes = reshape_attr;
1237         RETURN_IF_ERROR(graph->SetProducer(node->id, copy_value->id));
1238         RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, copy_value->id));
1239         RETURN_IF_ERROR(reader->AddOutputs(node_reshape));
1240       } else {
1241         RETURN_IF_ERROR(reader->AddOutputs(node));
1242       }
1243 
1244       Convolution2DAttributes attr;
1245       reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
1246 
1247       attr.strides = HW(1, 1);
1248       attr.dilations = HW(1, 1);
1249       attr.padding.appended = HW(0, 0);
1250       attr.padding.prepended = HW(0, 0);
1251       RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1252       node->operation.attributes = std::move(attr);
1253       return absl::OkStatus();
1254     }
1255     Node* node = graph->NewNode();
1256     RETURN_IF_ERROR(reader->AddInput(node, 0));
1257 
1258     if (tf_options->weights_format !=
1259         kTfLiteFullyConnectedWeightsFormatDefault) {
1260       return absl::UnimplementedError(
1261           "Unsupported FullyConnected weights format.");
1262     }
1263 
1264     FullyConnectedAttributes attr;
1265     RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
1266 
1267     auto input = graph->FindInputs(node->id)[0];
1268     if (input->tensor.shape.c != attr.weights.shape.i) {
1269       return absl::UnimplementedError(
1270           "Number of input channels should match weights width");
1271     }
1272 
1273     Node* conv = node;
1274     if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
1275       // GPU delegates assume height == 1 and width == 1 for FullyConnected.
1276       // Use a regular Convolution2D when height or width != 1.
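           // Illustration (assumed shapes): for input BHWC(1, 4, 4, 64) and
           // weights with i == 64, a 1x1 CONVOLUTION_2D applies the same
           // weights at every spatial position.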
1277       Convolution2DAttributes conv_attr;
1278       conv_attr.strides = HW(1, 1);
1279       conv_attr.dilations = HW(1, 1);
1280       conv_attr.padding.appended = HW(0, 0);
1281       conv_attr.padding.prepended = HW(0, 0);
1282       conv_attr.weights = attr.weights;
1283       conv_attr.bias = attr.bias;
1284       conv->operation.type = ToString(OperationType::CONVOLUTION_2D);
1285       conv->operation.attributes = std::move(conv_attr);
1286     } else {
1287       conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
1288       conv->operation.attributes = std::move(attr);
1289     }
1290     RETURN_IF_ERROR(reader->AddOutputs(conv));
1291     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));
1292     return absl::OkStatus();
1293   }
1294 };
1295 
1296 class HardSwishOperationParser : public TFLiteOperationParser {
1297  public:
1298   absl::Status IsSupported(const TfLiteContext* context,
1299                            const TfLiteNode* tflite_node,
1300                            const TfLiteRegistration* registration) final {
1301     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1302   }
1303 
1304   absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
1305                      GraphFloat32* graph, ObjectReader* reader) final {
1306     Node* node = graph->NewNode();
1307     node->operation.type = ToString(OperationType::HARD_SWISH);
1308     RETURN_IF_ERROR(reader->AddInput(node, 0));
1309     return reader->AddOutputs(node);
1310   }
1311 };
1312 
1313 // Basic LSTM Cell:
1314 //
1315 //  1name = name is at input  index 1
1316 //  name1 = name is at output index 1
1317 //
1318 //    0input     1prev_activ
1319 //       \        /
1320 //        [[concat]]
1321 //             \
1322 //       concat_temp2  2weights  3biases
1323 //              \      /        /
1324 //             [[fully-connected]]
1325 //               \
1326 //         activ_temp3    4prev_state
1327 //                 \      /
1328 //                 [[LSTM]]
1329 //                 /      \
1330 //           new_state1    activation0
1331 //
1332 // For full LSTM cells, see this blog post:
1333 // https://colah.github.io/posts/2015-08-Understanding-LSTMs/
1334 // In addition to Peephole connections and Combined Input Forget Gates (CIFG)
1335 // described in that post, this code also adds the following optional features:
1336 // - Configurable activations (sigmoid or TANH)
1337 // - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
1338 // - Output projection:
1339 //     https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
1340 // - Configurable clipping of cell state and output state.
1341 class LSTMOperationParser : public TFLiteOperationParser {
1342  public:
1343   absl::Status IsSupported(const TfLiteContext* context,
1344                            const TfLiteNode* tflite_node,
1345                            const TfLiteRegistration* registration) final {
1346     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
1347     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1348   }
1349 
1350   absl::Status Parse(const TfLiteNode* tflite_node,
1351                      const TfLiteRegistration* registration,
1352                      GraphFloat32* graph, ObjectReader* reader) final {
1353     const TfLiteLSTMParams* tf_options;
1354     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1355     switch (tf_options->kernel_type) {
1356       case kTfLiteLSTMFullKernel:
1357         return ParseFull(tflite_node, registration, graph, reader, tf_options);
1358       case kTfLiteLSTMBasicKernel:
1359         return ParseBasic(tflite_node, registration, graph, reader, tf_options);
1360     }
1361   }
1362 
1363   absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
1364       final {
1365     return new_variable_input_value_map_;
1366   }
1367 
1368  private:
1369   absl::Status ParseBasic(const TfLiteNode* tflite_node,
1370                           const TfLiteRegistration* registration,
1371                           GraphFloat32* graph, ObjectReader* reader,
1372                           const TfLiteLSTMParams* tf_options) {
1373     if (tflite_node->inputs->size != 5) {
1374       return absl::InvalidArgumentError("LSTM should have 5 input tensors");
1375     }
1376     if (tflite_node->outputs->size != 4) {
1377       return absl::InvalidArgumentError("LSTM should have 4 output tensors");
1378     }
1379     RETURN_IF_ERROR(CheckBasicParameters(tf_options));
1380 
1381     Node* concat_node = graph->NewNode();
1382     concat_node->operation.type = ToString(OperationType::CONCAT);
1383     ConcatAttributes concat_attr;
1384     concat_attr.axis = Axis::CHANNELS;
1385     concat_node->operation.attributes = concat_attr;
1386 
1387     Node* fc_node = graph->NewNode();
1388     fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
1389     FullyConnectedAttributes fc_attr;
1390     RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
1391     fc_node->operation.attributes = std::move(fc_attr);
1392 
1393     Node* lstm_node = graph->NewNode();
1394     lstm_node->operation.type = ToString(OperationType::LSTM);
1395     LstmAttributes lstm_attr;
1396     lstm_attr.kernel_type = LstmKernelType::BASIC;
1397     lstm_node->operation.attributes = lstm_attr;
1398 
1399     Value* concat_temp;
1400     int concat_tensor_idx = tflite_node->outputs->data[2];
1401     RETURN_IF_ERROR(
1402         reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
1403     Value* activ_temp;
1404     int activ_tensor_idx = tflite_node->outputs->data[3];
1405     RETURN_IF_ERROR(
1406         reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
1407 
1408     RETURN_IF_ERROR(reader->AddInput(concat_node, 0));  // input
1409     RETURN_IF_ERROR(reader->AddInput(concat_node, 1));  // prev_activ
1410     RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
1411 
1412     RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
1413     RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
1414 
1415     RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
1416     RETURN_IF_ERROR(reader->AddInput(lstm_node, 4));   // prev_state
1417     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1));  // new_state
1418     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0));  // activation
1419 
1420     return absl::OkStatus();
1421   }
1422 
1423   absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
1424     if (tf_options->activation != kTfLiteActTanh) {
1425       return absl::UnimplementedError("Only TANH activation is supported.");
1426     }
1427     if (tf_options->cell_clip != 0.0f) {
1428       return absl::UnimplementedError("cell_clip is not supported.");
1429     }
1430     if (tf_options->proj_clip != 0.0f) {
1431       return absl::UnimplementedError("proj_clip is not supported.");
1432     }
1433     return absl::OkStatus();
1434   }
1435 
1436   absl::Status ParseFull(const TfLiteNode* tflite_node,
1437                          const TfLiteRegistration* registration,
1438                          GraphFloat32* graph, ObjectReader* reader,
1439                          const TfLiteLSTMParams* tf_options) {
1440     // Invoke full LSTM parser
1441     RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
1442                                         reader, tf_options,
1443                                         &new_variable_input_value_map_));
1444     return absl::OkStatus();
1445   }
1446 
1447   absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
1448     if (tf_options->activation != kTfLiteActSigmoid &&
1449         tf_options->activation != kTfLiteActTanh) {
1450       return absl::UnimplementedError(
1451           "Only sigmoid or tanh activation is supported.");
1452     }
1453 
1454     return absl::OkStatus();
1455   }
1456 
1457   absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
1458 };
1459 
1460 class OneHotOperationParser : public TFLiteOperationParser {
1461  public:
1462   absl::Status IsSupported(const TfLiteContext* context,
1463                            const TfLiteNode* tflite_node,
1464                            const TfLiteRegistration* registration) final {
1465     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1466     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1467   }
1468 
1469   absl::Status Parse(const TfLiteNode* tflite_node,
1470                      const TfLiteRegistration* registration,
1471                      GraphFloat32* graph, ObjectReader* reader) final {
1472     Node* node = graph->NewNode();
1473     OneHotAttributes attr;
1474     const TfLiteTensor* on_tensor = reader->GetInputTensor(2);
1475     const TfLiteTensor* off_tensor = reader->GetInputTensor(3);
1476     attr.on_value = GetTensorData<float>(on_tensor)[0];
1477     attr.off_value = GetTensorData<float>(off_tensor)[0];
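         // For a standard one-hot encoding these would typically be
         // on_value == 1.0f and off_value == 0.0f.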
1478     node->operation.type = ToString(OperationType::ONE_HOT);
1479     node->operation.attributes = std::move(attr);
1480     RETURN_IF_ERROR(reader->AddInput(node, 0));
1481     RETURN_IF_ERROR(reader->AddOutputs(node));
1482     return absl::OkStatus();
1483   }
1484 };
1485 
1486 class PackOperationParser : public TFLiteOperationParser {
1487  public:
1488   absl::Status IsSupported(const TfLiteContext* context,
1489                            const TfLiteNode* tflite_node,
1490                            const TfLiteRegistration* registration) final {
1491     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1492   }
1493 
1494   absl::Status Parse(const TfLiteNode* tflite_node,
1495                      const TfLiteRegistration* registration,
1496                      GraphFloat32* graph, ObjectReader* reader) final {
1497     if (tflite_node->inputs->size == 1) {
1498       // Pack with a single input can be replaced with Reshape.
1499       Node* node = graph->NewNode();
1500       node->operation.type = ToString(OperationType::RESHAPE);
1501       RETURN_IF_ERROR(reader->AddInput(node, 0));
1502       RETURN_IF_ERROR(reader->AddOutputs(node));
1503       // New shape comes from output shape.
1504       ReshapeAttributes attr;
1505       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1506       node->operation.attributes = attr;
1507       return absl::OkStatus();
1508     } else {
1509       // Pack with multiple inputs can be replaced with Concat.
1510       const TfLitePackParams* tf_options;
1511       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1512 
1513       // Read inputs first to make sure const nodes are added to the graph
1514       // before the concat node, preserving topological order.
1515       std::vector<const Value*> inputs;
1516       for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1517         Value* value;
1518         const auto status = reader->ReadValue(idx, &value);
1519         if (status.ok()) {
1520           inputs.push_back(value);
1521         } else {
1522           TensorFloat32 tensor;
1523           RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1524           Value* value;
1525           RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1526           inputs.push_back(value);
1527         }
1528       }
1529 
1530       const TfLiteTensor* output = reader->GetOutputTensor(0);
1531       ConcatAttributes attr;
1532       RETURN_IF_ERROR(
1533           ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1534       BHWC output_shape;
1535       RETURN_IF_ERROR(ExtractTensorShape(*output, &output_shape));
1536       BHWC input_required_shape = output_shape;
1537       input_required_shape.set(attr.axis, 1);
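           // Illustration (assumed shapes): packing 4 tensors along CHANNELS
           // into an output of BHWC(1, 8, 8, 4) requires every input to be
           // BHWC(1, 8, 8, 1); inputs stored with a different (equal-sized)
           // shape get an explicit Reshape below.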
1538       for (int i = 0; i < inputs.size(); ++i) {
1539         BHWC input_shape = inputs[i]->tensor.shape;
1540         if (input_shape != input_required_shape) {
1541           // GPU delegates do not support implicit shape transformations,
1542           // so add an explicit Reshape.
1543           Node* node_reshape = graph->NewNode();
1544           node_reshape->operation.type = ToString(OperationType::RESHAPE);
1545           ReshapeAttributes reshape_attr;
1546           reshape_attr.new_shape = input_required_shape;
1547           node_reshape->operation.attributes = reshape_attr;
1548           RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, inputs[i]->id));
1549           Value* copy_value = graph->NewValue();
1550           copy_value->tensor.type = inputs[i]->tensor.type;
1551           copy_value->tensor.shape = input_required_shape;
1552           RETURN_IF_ERROR(graph->SetProducer(node_reshape->id, copy_value->id));
1553           inputs[i] = copy_value;
1554         }
1555       }
1556 
1557       Node* node = graph->NewNode();
1558       node->operation.type = ToString(OperationType::CONCAT);
1559       RETURN_IF_ERROR(reader->AddOutputs(node));
1560       for (const Value* input : inputs) {
1561         RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1562       }
1563       node->operation.attributes = attr;
1564       return absl::OkStatus();
1565     }
1566   }
1567 };
1568 
1569 class PReLUOperationParser : public TFLiteOperationParser {
1570  public:
1571   absl::Status IsSupported(const TfLiteContext* context,
1572                            const TfLiteNode* tflite_node,
1573                            const TfLiteRegistration* registration) final {
1574     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1575     // TODO(eignasheva): add params check
1576     return absl::OkStatus();
1577   }
1578   absl::Status Parse(const TfLiteNode* tflite_node,
1579                      const TfLiteRegistration* registration,
1580                      GraphFloat32* graph, ObjectReader* reader) final {
1581     Node* node = graph->NewNode();
1582     node->operation.type = ToString(OperationType::PRELU);
1583     RETURN_IF_ERROR(reader->AddInput(node, 0));
1584     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1585 
1586     PReLUAttributes attr;
1587     Tensor<Linear, DataType::FLOAT32> linear_alpha;
1588     absl::Status status = reader->ReadTensor(1, &linear_alpha);
1589     if (status.ok()) {
1590       if (linear_alpha.shape.v != input_shape.c) {
1591         return absl::InvalidArgumentError(
1592             "Linear alpha shape does not match the number of input channels.");
1593       }
1594       attr.alpha = std::move(linear_alpha);
1595     } else {
1596       Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1597       RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1598       if (hwc_alpha.shape.h != input_shape.h ||
1599           hwc_alpha.shape.w != input_shape.w ||
1600           hwc_alpha.shape.c != input_shape.c) {
1601         return absl::InvalidArgumentError(
1602             "Alpha shape does not match input shape.");
1603       }
1604       attr.alpha = std::move(hwc_alpha);
1605     }
1606     node->operation.attributes = std::move(attr);
1607     return reader->AddOutputs(node);
1608   }
1609 };
1610 
1611 class PadOperationParser : public TFLiteOperationParser {
1612  public:
1613   explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1614 
1615   absl::Status IsSupported(const TfLiteContext* context,
1616                            const TfLiteNode* tflite_node,
1617                            const TfLiteRegistration* registration) final {
1618     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1619     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1620   }
1621 
1622   absl::Status Parse(const TfLiteNode* tflite_node,
1623                      const TfLiteRegistration* registration,
1624                      GraphFloat32* graph, ObjectReader* reader) final {
1625     Node* node = graph->NewNode();
1626     node->operation.type = ToString(OperationType::PAD);
1627     RETURN_IF_ERROR(reader->AddInput(node, 0));
1628     RETURN_IF_ERROR(reader->AddOutputs(node));
1629 
1630     PadAttributes attr;
1631     if (mirror_pad_) {
1632       attr.type = PaddingContentType::REFLECT;
1633     } else /*zero pad*/ {
1634       attr.type = PaddingContentType::ZEROS;
1635     }
1636 
1637     Tensor<HW, DataType::INT32> paddings;
1638     RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1639 
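         // The paddings tensor holds one (before, after) pair per axis, row by
         // row. Illustration (assumed values): a 4x2 tensor
         // {{0,0},{1,1},{2,2},{0,0}} yields prepended = BHWC(0, 1, 2, 0) and
         // appended = BHWC(0, 1, 2, 0).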
1640     if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1641       // 4x2 tensor with paddings.
1642       attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1643                             paddings.data[4], paddings.data[6]);
1644       attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1645                            paddings.data[7]);
1646     } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1647       // 3x2 tensor with paddings.
1648       attr.prepended =
1649           BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1650       attr.appended =
1651           BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1652     } else {
1653       // This shouldn't fail here since the shape is checked in IsSupported().
1654       return absl::InvalidArgumentError(
1655           "Paddings tensor has unexpected shape.");
1656     }
1657     node->operation.attributes = attr;
1658     return absl::OkStatus();
1659   }
1660 
1661  private:
1662   bool mirror_pad_ = false;
1663 };
1664 
1665 class Pooling2DOperationParser : public TFLiteOperationParser {
1666  public:
1667   absl::Status IsSupported(const TfLiteContext* context,
1668                            const TfLiteNode* tflite_node,
1669                            const TfLiteRegistration* registration) final {
1670     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1671     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1672   }
1673 
1674  public:
1675   explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1676 
1677   absl::Status Parse(const TfLiteNode* tflite_node,
1678                      const TfLiteRegistration* registration,
1679                      GraphFloat32* graph, ObjectReader* reader) final {
1680     Node* node = graph->NewNode();
1681     node->operation.type = ToString(OperationType::POOLING_2D);
1682     RETURN_IF_ERROR(reader->AddInput(node, 0));
1683     RETURN_IF_ERROR(reader->AddOutput(node, 0));
1684 
1685     Pooling2DAttributes attr;
1686     attr.type = type_;
1687 
1688     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1689 
1690     // Check whether there are custom options encoded. This happens if the
1691     // operation is MaxPoolingWithArgmax2D. There is no way to read
1692     // tflite_node->builtin_code, so simply check whether custom data is
1693     // available.
1694     const TfLitePoolParams* tf_options;
1695     if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
1696       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1697     }
1698 
1699     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1700     // The second output is optional; if present, it must be added after
1701     // MaybeFuseActivation is called.
1702     reader->AddOutput(node, 1).IgnoreError();
1703 
1704     // The first output is the result of the pooling operation; the second
1705     // output holds the indices used for pooling.
1706     auto outputs = graph->FindOutputs(node->id);
1707     attr.output_indices = outputs.size() == 2;
1708     if (attr.output_indices) {
1709       // Fix data type for output indices. In the model it is set as float32.
1710       outputs[1]->tensor.type = DataType::INT32;
1711     }
1712     RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1713     node->operation.attributes = attr;
1714     return absl::OkStatus();
1715   }
1716 
1717  private:
1718   const PoolingType type_;
1719 };
1720 
1721 class ReduceOperationParser : public TFLiteOperationParser {
1722  public:
1723   explicit ReduceOperationParser(OperationType operation_type)
1724       : operation_type_(operation_type) {}
1725 
1726   absl::Status IsSupported(const TfLiteContext* context,
1727                            const TfLiteNode* tflite_node,
1728                            const TfLiteRegistration* registration) final {
1729     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1730   }
1731 
1732   absl::Status Parse(const TfLiteNode* tflite_node,
1733                      const TfLiteRegistration* registration,
1734                      GraphFloat32* graph, ObjectReader* reader) final {
1735     Node* node = graph->NewNode();
1736     node->operation.type = ToString(operation_type_);
1737     RETURN_IF_ERROR(reader->AddInput(node, 0));
1738 
1739     const TfLiteReducerParams* tf_options;
1740     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1741 
1742     ReduceAttributes attr;
1743     const TfLiteTensor* input = reader->GetInputTensor(0);
1744     const TfLiteTensor* axes = reader->GetInputTensor(1);
1745     for (int i = 0; i < NumElements(axes->dims); i++) {
1746       Axis axis;
1747       RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1748       attr.dims.insert(axis);
1749     }
1750     node->operation.attributes = attr;
1751 
1752       // GPU delegates do not support implicit shape transformations,
1753       // so add an explicit Reshape.
1754       // adding explicit Reshape
1755       const auto& input_tensor = graph->FindInputs(node->id)[0]->tensor;
1756       auto reduce_output_shape = input_tensor.shape;
1757       for (auto axis : attr.dims) {
1758         reduce_output_shape.set(axis, 1);
1759       }
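           // Illustration (assumed shapes): reducing over {H, W} on
           // BHWC(1, 8, 8, 32) gives reduce_output_shape = BHWC(1, 1, 1, 32);
           // the Reshape below then produces the rank-reduced shape that
           // TFLite expects.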
1760       Node* node_reshape = graph->NewNode();
1761       node_reshape->operation.type = ToString(OperationType::RESHAPE);
1762       ReshapeAttributes reshape_attr;
1763       const TfLiteTensor* output = reader->GetOutputTensor(0);
1764       RETURN_IF_ERROR(ExtractTensorShape(*output, &reshape_attr.new_shape));
1765       node_reshape->operation.attributes = reshape_attr;
1766       Value* reduce_result = graph->NewValue();
1767       reduce_result->tensor.type = input_tensor.type;
1768       reduce_result->tensor.shape = reduce_output_shape;
1769 
1770       RETURN_IF_ERROR(graph->SetProducer(node->id, reduce_result->id));
1771       RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, reduce_result->id));
1772       RETURN_IF_ERROR(reader->AddOutputs(node_reshape));
1773     } else {
1774       RETURN_IF_ERROR(reader->AddOutputs(node));
1775     }
1776     return absl::OkStatus();
1777   }
1778 
1779  private:
1780   const OperationType operation_type_;
1781 };
1782 
1783 class QuantizeOperationParser : public TFLiteOperationParser {
1784  public:
1785   absl::Status IsSupported(const TfLiteContext* context,
1786                            const TfLiteNode* tflite_node,
1787                            const TfLiteRegistration* registration) final {
1788     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1789     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1790   }
1791 
1792   absl::Status Parse(const TfLiteNode* tflite_node,
1793                      const TfLiteRegistration* registration,
1794                      GraphFloat32* graph, ObjectReader* reader) final {
1795     // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1796     // with floating-point versions of the original tensors.
1797     Node* node = graph->NewNode();
1798     node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1799     RETURN_IF_ERROR(reader->AddInput(node, 0));
1800     RETURN_IF_ERROR(reader->AddOutputs(node));
1801 
1802     // Quantization attributes should already be present in the output tensor.
1803     auto output_value = graph->FindOutputs(node->id)[0];
1804     if (!output_value->quant_params) {
1805       return absl::InvalidArgumentError(
1806           "Encountered Quantize output with no quant params");
1807     }
1808     QuantizeAndDequantizeAttributes attr;
1809     attr.min = output_value->quant_params.value().min;
1810     attr.max = output_value->quant_params.value().max;
1811     attr.scale = output_value->quant_params.value().scale;
1812 
1813     node->operation.attributes = attr;
1814     return absl::OkStatus();
1815   }
1816 };
1817 
1818 class ReLUOperationParser : public TFLiteOperationParser {
1819  public:
1820   explicit ReLUOperationParser(int clip) : clip_(clip) {}
1821 
1822   absl::Status IsSupported(const TfLiteContext* context,
1823                            const TfLiteNode* tflite_node,
1824                            const TfLiteRegistration* registration) final {
1825     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1826     return absl::OkStatus();
1827   }
1828 
1829   absl::Status Parse(const TfLiteNode* tflite_node,
1830                      const TfLiteRegistration* registration,
1831                      GraphFloat32* graph, ObjectReader* reader) final {
1832     Node* node = graph->NewNode();
1833     node->operation.type = ToString(OperationType::RELU);
1834     RETURN_IF_ERROR(reader->AddInput(node, 0));
1835 
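         // clip_ selects the variant: 0 is assumed to mean no upper bound
         // (plain ReLU), while e.g. ReLU6 would use clip_ == 6; a nonzero
         // alpha from TfLiteLeakyReluParams yields LeakyReLU behavior.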
1836     ReLUAttributes attr;
1837     const TfLiteLeakyReluParams* tf_options;
1838     auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1839     attr.alpha = status.ok() ? tf_options->alpha : 0;
1840     attr.clip = clip_;
1841     node->operation.attributes = attr;
1842     return reader->AddOutputs(node);
1843   }
1844 
1845  private:
1846   const int clip_;
1847 };
1848 
1849 class ResamplerOperationParser : public TFLiteOperationParser {
1850  public:
1851   absl::Status IsSupported(const TfLiteContext* context,
1852                            const TfLiteNode* tflite_node,
1853                            const TfLiteRegistration* registration) final {
1854     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1855   }
1856 
1857   absl::Status Parse(const TfLiteNode* tflite_node,
1858                      const TfLiteRegistration* registration,
1859                      GraphFloat32* graph, ObjectReader* reader) final {
1860     Node* node = graph->NewNode();
1861     RETURN_IF_ERROR(reader->AddInput(node, 0));  // src
1862     RETURN_IF_ERROR(reader->AddInput(node, 1));  // warp
1863     RETURN_IF_ERROR(reader->AddOutputs(node));
1864 
1865     node->operation.type = ToString(OperationType::RESAMPLER);
1866 
1867     auto src_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1868     auto warp_shape = graph->FindInputs(node->id)[1]->tensor.shape;
1869 
1870     auto output_value = graph->FindOutputs(node->id)[0];
1871     output_value->tensor.shape =
1872         BHWC(src_shape.b, warp_shape.h, warp_shape.w, src_shape.c);
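         // Illustration: src BHWC(1, 16, 16, 3) with warp BHWC(1, 8, 8, 2)
         // yields an output of BHWC(1, 8, 8, 3).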
1873     return absl::OkStatus();
1874   }
1875 };
1876 
1877 class ReshapeOperationParser : public TFLiteOperationParser {
1878  public:
1879   absl::Status IsSupported(const TfLiteContext* context,
1880                            const TfLiteNode* tflite_node,
1881                            const TfLiteRegistration* registration) final {
1882     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1883     // TODO(eignasheva): add shape checking
1884     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1885   }
1886 
1887   absl::Status Parse(const TfLiteNode* tflite_node,
1888                      const TfLiteRegistration* registration,
1889                      GraphFloat32* graph, ObjectReader* reader) final {
1890     Node* node = graph->NewNode();
1891     node->operation.type = ToString(OperationType::RESHAPE);
1892     RETURN_IF_ERROR(reader->AddInput(node, 0));
1893     RETURN_IF_ERROR(reader->AddOutputs(node));
1894     // Here we may have extra inputs. The other tensors were supposed to
1895     // define the new shape, but in TFLite they are ignored.
1896     // TODO(akulik): check that shapes match?
1897 
1898     // New shape comes from output shape.
1899     ReshapeAttributes attr;
1900     attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1901     node->operation.attributes = attr;
1902     return absl::OkStatus();
1903   }
1904 };
1905 
1906 class Resize2DOperationParser : public TFLiteOperationParser {
1907  public:
1908   explicit Resize2DOperationParser(SamplingType sampling_type)
1909       : sampling_type_(sampling_type) {}
1910 
1911   absl::Status IsSupported(const TfLiteContext* context,
1912                            const TfLiteNode* tflite_node,
1913                            const TfLiteRegistration* registration) final {
1914     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1915     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1916   }
1917 
1918   absl::Status Parse(const TfLiteNode* tflite_node,
1919                      const TfLiteRegistration* registration,
1920                      GraphFloat32* graph, ObjectReader* reader) final {
1921     Node* node = graph->NewNode();
1922     node->operation.type = ToString(OperationType::RESIZE);
1923     RETURN_IF_ERROR(reader->AddInput(node, 0));
1924     RETURN_IF_ERROR(reader->AddOutputs(node));
1925     // Here we may have extra inputs. The other tensors were supposed to
1926     // define the new shape, but in TFLite they are ignored.
1927 
1928     Resize2DAttributes attr;
1929     RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1930     RETURN_IF_ERROR(
1931         GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1932     attr.type = sampling_type_;
1933     attr.new_shape.CopyAllDefinedAxis(
1934         graph->FindOutputs(node->id)[0]->tensor.shape);
1935     node->operation.attributes = attr;
1936     return absl::OkStatus();
1937   }
1938 
1939  private:
1940   absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1941                                     bool* align_corners) {
1942     switch (sampling_type_) {
1943       case SamplingType::BILINEAR:
1944         return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1945             tflite_node, align_corners);
1946       case SamplingType::NEAREST:
1947         return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1948             tflite_node, align_corners);
1949       case SamplingType::UNKNOWN:
1950         return absl::InternalError("Sampling type is not specified");
1951     }
1952     return absl::OkStatus();
1953   }
1954 
1955   template <class T>
1956   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1957                                            bool* align_corners) {
1958     const T* tf_options;
1959     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1960     *align_corners = tf_options->align_corners;
1961     return absl::OkStatus();
1962   }
1963 
1964   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1965                                         bool* half_pixel_centers) {
1966     if (sampling_type_ == SamplingType::BILINEAR) {
1967       const TfLiteResizeBilinearParams* tf_options;
1968       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1969       if (tf_options->align_corners && tf_options->half_pixel_centers) {
1970         return absl::InternalError(
1971             "If half_pixel_centers is True, align_corners must be False.");
1972       }
1973       *half_pixel_centers = tf_options->half_pixel_centers;
1974     } else {
1975       const TfLiteResizeNearestNeighborParams* tf_options;
1976       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1977       *half_pixel_centers = tf_options->half_pixel_centers;
1978     }
1979     return absl::OkStatus();
1980   }
1981 
1982   SamplingType sampling_type_ = SamplingType::UNKNOWN;
1983 };
1984 
1985 class SelectV2OperationParser : public TFLiteOperationParser {
1986  public:
1987   absl::Status IsSupported(const TfLiteContext* context,
1988                            const TfLiteNode* tflite_node,
1989                            const TfLiteRegistration* registration) final {
1990     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1991     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1992   }
1993 
1994   absl::Status Parse(const TfLiteNode* tflite_node,
1995                      const TfLiteRegistration* registration,
1996                      GraphFloat32* graph, ObjectReader* reader) final {
1997     Node* node = graph->NewNode();
1998     SelectV2Attributes attr;
1999     const TfLiteTensor* cond_tensor = reader->GetInputTensor(0);
2000     const TfLiteTensor* true_tensor = reader->GetInputTensor(1);
2001     const TfLiteTensor* false_tensor = reader->GetInputTensor(2);
2002     const bool is_if_constant = true_tensor->allocation_type == kTfLiteMmapRo;
2003     const bool is_else_constant =
2004         false_tensor->allocation_type == kTfLiteMmapRo;
2005     BHWC cond_shape, true_shape, false_shape;
2006     RETURN_IF_ERROR(ExtractTensorShape(*cond_tensor, &cond_shape));
2007     if (true_tensor->dims->size == 0) {
2008       attr.broadcast_true = true;
2009     } else {
2010       RETURN_IF_ERROR(ExtractTensorShape(*true_tensor, &true_shape));
2011       attr.broadcast_true = true_shape.DimensionsProduct() == 1;
2012     }
2013     if (false_tensor->dims->size == 0) {
2014       attr.broadcast_false = true;
2015     } else {
2016       RETURN_IF_ERROR(ExtractTensorShape(*false_tensor, &false_shape));
2017       attr.broadcast_false = false_shape.DimensionsProduct() == 1;
2018     }
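         // broadcast_true / broadcast_false mark a branch that holds a single
         // element and is broadcast against the condition shape, e.g.
         // (assumed) select_v2(cond, x, 0.0) with scalar 0.0 sets
         // broadcast_false.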
2019     node->operation.type = ToString(OperationType::SELECT_V2);
2020     Value* if_value;
2021     Value* else_value;
2022     Tensor<BHWC, DataType::FLOAT32> if_tensor;
2023     Tensor<BHWC, DataType::FLOAT32> else_tensor;
2024     if (!attr.broadcast_true) {
2025       if (is_if_constant) {
2026         RETURN_IF_ERROR(reader->ReadTensor(1, &if_tensor));
2027       }
2028     } else {
2029       Tensor<Scalar, DataType::FLOAT32> if_scalar_tensor;
2030       RETURN_IF_ERROR(reader->ReadTensor(1, &if_scalar_tensor));
2031       if_tensor.shape = BHWC(1, 1, 1, 1);
2032       if_tensor.data.push_back(if_scalar_tensor.data[0]);
2033     }
2034     if (!attr.broadcast_false) {
2035       if (is_else_constant) {
2036         RETURN_IF_ERROR(reader->ReadTensor(2, &else_tensor));
2037       }
2038     } else {
2039       Tensor<Scalar, DataType::FLOAT32> else_scalar_tensor;
2040       RETURN_IF_ERROR(reader->ReadTensor(2, &else_scalar_tensor));
2041       else_tensor.shape = BHWC(1, 1, 1, 1);
2042       else_tensor.data.push_back(else_scalar_tensor.data[0]);
2043     }
2044     node->operation.attributes = std::move(attr);
2045     RETURN_IF_ERROR(reader->AddInput(node, 0));
2046     if (is_if_constant) {
2047       RETURN_IF_ERROR(NewConstNode(if_tensor, graph, &if_value));
2048       RETURN_IF_ERROR(graph->AddConsumer(node->id, if_value->id));
2049     } else {
2050       RETURN_IF_ERROR(reader->AddInput(node, 1));
2051     }
2052     if (is_else_constant) {
2053       RETURN_IF_ERROR(NewConstNode(else_tensor, graph, &else_value));
2054       RETURN_IF_ERROR(graph->AddConsumer(node->id, else_value->id));
2055     } else {
2056       RETURN_IF_ERROR(reader->AddInput(node, 2));
2057     }
2058     RETURN_IF_ERROR(reader->AddOutputs(node));
2059     return absl::OkStatus();
2060   }
2061 };
2062 
2063 class SliceOperationParser : public TFLiteOperationParser {
2064  public:
2065   absl::Status IsSupported(const TfLiteContext* context,
2066                            const TfLiteNode* tflite_node,
2067                            const TfLiteRegistration* registration) final {
2068     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2069     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2070   }
2071 
2072   absl::Status Parse(const TfLiteNode* tflite_node,
2073                      const TfLiteRegistration* registration,
2074                      GraphFloat32* graph, ObjectReader* reader) final {
2075     Node* node = graph->NewNode();
2076     node->operation.type = ToString(OperationType::SLICE);
2077     RETURN_IF_ERROR(reader->AddOutputs(node));
2078     Value* input;
2079     RETURN_IF_ERROR(reader->ReadValue(0, &input));
2080     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2081 
2082     const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
2083     const int input_dims = tfl_input->dims->size;
2084 
2085     SliceAttributes attr;
2086     attr.strides = BHWC(1, 1, 1, 1);
2087     Tensor<Linear, DataType::INT32> starts, sizes;
2088     RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
2089     RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
2090     if (starts.data.size() != sizes.data.size()) {
2091       return absl::InvalidArgumentError("Starts count != sizes count.");
2092     }
2093     BHWC bhwc_starts(0, 0, 0, 0);
2094     BHWC bhwc_sizes = input->tensor.shape;
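         // Map the rank-3/rank-4 starts/sizes onto BHWC. Illustration
         // (assumed values): a 4D input with starts {0, 1, 1, 0} and sizes
         // {1, 2, 2, -1} starts at (h=1, w=1) and keeps all channels
         // (-1 means "to the end").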
2095     if (input_dims == 4) {
2096       // input in BHWC layout
2097       if (starts.data.size() == 4) {
2098         bhwc_starts.b = starts.data[0];
2099         bhwc_starts.h = starts.data[1];
2100         bhwc_starts.w = starts.data[2];
2101         bhwc_starts.c = starts.data[3];
2102         bhwc_sizes.b = sizes.data[0];
2103         bhwc_sizes.h = sizes.data[1];
2104         bhwc_sizes.w = sizes.data[2];
2105         bhwc_sizes.c = sizes.data[3];
2106       } else if (starts.data.size() == 3) {
2107         // If the input is 4D (BHWC) and args are 3D, assume args are in HWC layout.
2108         bhwc_starts.h = starts.data[0];
2109         bhwc_starts.w = starts.data[1];
2110         bhwc_starts.c = starts.data[2];
2111         bhwc_sizes.h = sizes.data[0];
2112         bhwc_sizes.w = sizes.data[1];
2113         bhwc_sizes.c = sizes.data[2];
2114       } else {
2115         return absl::UnimplementedError(
2116             "Slicing is supported for 3 or 4 dimensional tensors only.");
2117       }
2118     } else if (input_dims == 3) {
2119       // input in BWC layout
2120       if (starts.data.size() == 3) {
2121         bhwc_starts.b = starts.data[0];
2122         bhwc_starts.w = starts.data[1];
2123         bhwc_starts.c = starts.data[2];
2124         bhwc_sizes.b = sizes.data[0];
2125         bhwc_sizes.w = sizes.data[1];
2126         bhwc_sizes.c = sizes.data[2];
2127       } else {
2128         return absl::UnimplementedError(
2129             "Slicing is supported for 3 or 4 dimensional tensors only.");
2130       }
2131     } else {
2132       return absl::UnimplementedError(
2133           "Slicing is supported for 3 or 4 dimensional tensors only.");
2134     }
2135     const auto& in_shape = input->tensor.shape;
2136     if (bhwc_sizes.b == -1) {
2137       bhwc_sizes.b = in_shape.b - bhwc_starts.b;
2138     }
2139     if (bhwc_sizes.h == -1) {
2140       bhwc_sizes.h = in_shape.h - bhwc_starts.h;
2141     }
2142     if (bhwc_sizes.w == -1) {
2143       bhwc_sizes.w = in_shape.w - bhwc_starts.w;
2144     }
2145     if (bhwc_sizes.c == -1) {
2146       bhwc_sizes.c = in_shape.c - bhwc_starts.c;
2147     }
2148     attr.starts = bhwc_starts;
2149     attr.ends =
2150         BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
2151              bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
2152     RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
2153 
2154     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2155     if ((attr.ends.b - attr.starts.b) != out_shape.b) {
2156       return absl::UnimplementedError("Output batch doesn't match");
2157     }
2158     if ((attr.ends.h - attr.starts.h) != out_shape.h) {
2159       return absl::UnimplementedError("Output height doesn't match");
2160     }
2161     if ((attr.ends.w - attr.starts.w) != out_shape.w) {
2162       return absl::UnimplementedError("Output width doesn't match");
2163     }
2164     if ((attr.ends.c - attr.starts.c) != out_shape.c) {
2165       return absl::UnimplementedError("Output channels don't match");
2166     }
2167     node->operation.attributes = attr;
2168     return absl::OkStatus();
2169   }
2170 
2171  private:
2172   absl::Status UpdateIfNegative(const BHWC& input_shape,
2173                                 SliceAttributes* attr) {
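         // Negative ends count back from the corresponding dimension, e.g.
         // ends.h == -1 with input_shape.h == 8 becomes 7.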
2174     if (attr->ends.h < 0) {
2175       attr->ends.h = input_shape.h + attr->ends.h;
2176     }
2177     if (attr->ends.w < 0) {
2178       attr->ends.w = input_shape.w + attr->ends.w;
2179     }
2180     if (attr->ends.c < 0) {
2181       attr->ends.c = input_shape.c + attr->ends.c;
2182     }
2183     if (attr->ends.b < 0) {
2184       attr->ends.b = input_shape.b + attr->ends.b;
2185     }
2186     return absl::OkStatus();
2187   }
2188 };
2189 
2190 class SoftmaxOperationParser : public TFLiteOperationParser {
2191  public:
2192   absl::Status IsSupported(const TfLiteContext* context,
2193                            const TfLiteNode* tflite_node,
2194                            const TfLiteRegistration* registration) final {
2195     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2196     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2197   }
2198 
2199   absl::Status Parse(const TfLiteNode* tflite_node,
2200                      const TfLiteRegistration* registration,
2201                      GraphFloat32* graph, ObjectReader* reader) final {
2202     Node* node = graph->NewNode();
2203     node->operation.type = ToString(OperationType::SOFTMAX);
2204     RETURN_IF_ERROR(reader->AddInput(node, 0));
2205     RETURN_IF_ERROR(reader->AddOutputs(node));
2206 
2207     const TfLiteSoftmaxParams* tf_options;
2208     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2209     if (tf_options->beta != 1) {
2210       // There is a multiply-by-scalar operation fused into softmax. Make a
2211       // separate layer out of it before softmax.
2212       return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
2213       // auto mul_node = reader->NewPassthroughNode(node);
2214       // mul_node->operation.type = ToString(OperationType::MUL);
2215     }
2216     SoftmaxAttributes attr;
2217     attr.axis = Axis::CHANNELS;  // always by channels
2218     node->operation.attributes = attr;
2219     return absl::OkStatus();
2220   }
2221 };
2222 
2223 class SpaceToDepthOperationParser : public TFLiteOperationParser {
2224  public:
2225   absl::Status IsSupported(const TfLiteContext* context,
2226                            const TfLiteNode* tflite_node,
2227                            const TfLiteRegistration* registration) final {
2228     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2229     // TODO(impjdi): Dims check.
2230     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2231   }
2232 
2233   absl::Status Parse(const TfLiteNode* tflite_node,
2234                      const TfLiteRegistration* registration,
2235                      GraphFloat32* graph, ObjectReader* reader) final {
2236     Node* node = graph->NewNode();
2237     node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
2238     RETURN_IF_ERROR(reader->AddInput(node, 0));
2239     RETURN_IF_ERROR(reader->AddOutputs(node));
2240     const TfLiteSpaceToDepthParams* tf_options;
2241     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2242     SpaceToDepthAttributes attr;
2243     attr.block_size = tf_options->block_size;
2244     node->operation.attributes = attr;
2245     return absl::OkStatus();
2246   }
2247 };
2248 
2249 class SplitOperationParser : public TFLiteOperationParser {
2250  public:
2251   absl::Status IsSupported(const TfLiteContext* context,
2252                            const TfLiteNode* tflite_node,
2253                            const TfLiteRegistration* registration) final {
2254     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2255   }
2256 
2257   absl::Status Parse(const TfLiteNode* tflite_node,
2258                      const TfLiteRegistration* registration,
2259                      GraphFloat32* graph, ObjectReader* reader) final {
2260     const TfLiteSplitParams* split_params;
2261     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
2262     if (split_params->num_splits == 1) {
2263       // Add an identity Reshape that will be removed later.
2264       Node* node = graph->NewNode();
2265       node->operation.type = ToString(OperationType::RESHAPE);
2266       RETURN_IF_ERROR(reader->AddInput(node, 1));
2267       RETURN_IF_ERROR(reader->AddOutputs(node));
2268       // New shape comes from output shape.
2269       ReshapeAttributes attr;
2270       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2271       node->operation.attributes = attr;
2272       return absl::OkStatus();
2273     }
2274     const TfLiteTensor* input = reader->GetInputTensor(1);
2275     const TfLiteTensor* axis_tensor = reader->GetInputTensor(0);
2276     SplitAttributes attr;
2277     RETURN_IF_ERROR(
2278         ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
2279 
2280     Node* node = graph->NewNode();
2281     node->operation.type = ToString(OperationType::SPLIT);
2282     node->operation.attributes = attr;
2283     RETURN_IF_ERROR(reader->AddInput(node, 1));
2284     for (int i = 0; i < tflite_node->outputs->size; ++i) {
2285       RETURN_IF_ERROR(reader->AddOutput(node, i));
2286     }
2287     return absl::OkStatus();
2288   }
2289 };
2290 
2291 class SplitVOperationParser : public TFLiteOperationParser {
2292  public:
2293   absl::Status IsSupported(const TfLiteContext* context,
2294                            const TfLiteNode* tflite_node,
2295                            const TfLiteRegistration* registration) final {
2296     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2297   }
2298 
2299   absl::Status Parse(const TfLiteNode* tflite_node,
2300                      const TfLiteRegistration* registration,
2301                      GraphFloat32* graph, ObjectReader* reader) final {
2302     const TfLiteSplitVParams* split_params;
2303     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
2304     if (split_params->num_splits == 1) {
2305       // Add an identity Reshape that a later transformation removes.
2306       Node* node = graph->NewNode();
2307       node->operation.type = ToString(OperationType::RESHAPE);
2308       RETURN_IF_ERROR(reader->AddInput(node, 0));
2309       RETURN_IF_ERROR(reader->AddOutputs(node));
2310       // New shape comes from output shape.
2311       ReshapeAttributes attr;
2312       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2313       node->operation.attributes = attr;
2314       return absl::OkStatus();
2315     }
2316     const TfLiteTensor* input = reader->GetInputTensor(0);
2317     const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
2318     SplitAttributes attr;
2319     RETURN_IF_ERROR(
2320         ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
2321 
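    // SPLIT_V orders its inputs differently from SPLIT: tensor 0 is the data,
    // tensor 1 holds the per-output sizes, and tensor 2 holds the axis.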
2322     Node* node = graph->NewNode();
2323     node->operation.type = ToString(OperationType::SPLIT);
2324     node->operation.attributes = attr;
2325     RETURN_IF_ERROR(reader->AddInput(node, 0));
2326     for (int i = 0; i < tflite_node->outputs->size; ++i) {
2327       RETURN_IF_ERROR(reader->AddOutput(node, i));
2328     }
2329     return absl::OkStatus();
2330   }
2331 };
2332 
2333 class StridedSliceOperationParser : public TFLiteOperationParser {
2334  public:
2335   absl::Status IsSupported(const TfLiteContext* context,
2336                            const TfLiteNode* tflite_node,
2337                            const TfLiteRegistration* registration) final {
2338     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2339     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2340   }
2341 
2342   absl::Status Parse(const TfLiteNode* tflite_node,
2343                      const TfLiteRegistration* registration,
2344                      GraphFloat32* graph, ObjectReader* reader) final {
2345     Node* node = graph->NewNode();
2346     node->operation.type = ToString(OperationType::SLICE);
2347     RETURN_IF_ERROR(reader->AddOutputs(node));
2348     Value* input;
2349     RETURN_IF_ERROR(reader->ReadValue(0, &input));
2350     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2351 
2352     Tensor<Linear, DataType::INT32> tmp;
2353     RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
2354 
2355     bool read_without_batch = tmp.data.size() == 3;
2356     bool read_with_batch = tmp.data.size() == 4;
2357     if (!read_without_batch && !read_with_batch) {
2358       // Error: must be caught in IsSupported().
2359       return absl::UnimplementedError(
2360           "Slicing is supported for 3 or 4 dimensional tensors only.");
2361     }
2362 
2363     const TfLiteStridedSliceParams* tf_options;
2364     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2365     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
2366 
2367     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2368 
2369     SliceAttributes attr;
2370     if (read_without_batch) {
2371       RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
2372                                               input->tensor.shape, &attr));
2373     }
2374     if (read_with_batch) {
2375       RETURN_IF_ERROR(
2376           ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
2377     }
2378     if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
2379         attr.strides.c == 0) {
2380       return absl::InvalidArgumentError("stride values must be non-zero");
2381     }
2382     if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
2383         attr.strides.c < 0) {
2384       return absl::UnimplementedError("Reverse slices are not supported.");
2385     }
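    // Each check below verifies that ceil((end - start) / stride), the number
    // of elements the slice produces along an axis, matches the corresponding
    // output dimension; e.g. start=0, end=5, stride=2 gives
    // (5 - 0 + 2 - 1) / 2 = 3 elements.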
2386     if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
2387         out_shape.b) {
2388       return absl::UnimplementedError("Output batch don't match");
2389     }
2390     if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
2391         out_shape.h) {
2392       return absl::UnimplementedError("Output height doesn't match");
2393     }
2394     if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
2395         out_shape.w) {
2396       return absl::UnimplementedError("Output width doesn't match");
2397     }
2398     if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
2399         out_shape.c) {
2400       return absl::UnimplementedError("Output channels don't match");
2401     }
2402     node->operation.attributes = attr;
2403     return absl::OkStatus();
2404   }
2405 
2406  private:
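  // begin_mask/end_mask are bit fields: a set bit means "ignore the provided
  // begin/end for that axis and use the full range instead". The ignore_*
  // arguments carry the bit value the caller assigned to each BHWC axis.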
2407   absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
2408                               const BHWC& input_shape, int ignore_b,
2409                               int ignore_h, int ignore_w, int ignore_c,
2410                               SliceAttributes* attr) {
2411     if (tf_options->begin_mask & ignore_h) {
2412       attr->starts.h = 0;
2413     }
2414     if (tf_options->begin_mask & ignore_w) {
2415       attr->starts.w = 0;
2416     }
2417     if (tf_options->begin_mask & ignore_c) {
2418       attr->starts.c = 0;
2419     }
2420     if (tf_options->begin_mask & ignore_b) {
2421       attr->starts.b = 0;
2422     }
2423 
2424     if (tf_options->end_mask & ignore_h) {
2425       attr->ends.h = input_shape.h;
2426     }
2427     if (tf_options->end_mask & ignore_w) {
2428       attr->ends.w = input_shape.w;
2429     }
2430     if (tf_options->end_mask & ignore_c) {
2431       attr->ends.c = input_shape.c;
2432     }
2433     if (tf_options->end_mask & ignore_b) {
2434       attr->ends.b = input_shape.b;
2435     }
2436     return absl::OkStatus();
2437   }
2438 
2439   absl::Status UpdateIfNegative(const BHWC& input_shape,
2440                                 SliceAttributes* attr) {
2441     if (attr->ends.h < 0) {
2442       attr->ends.h = input_shape.h + attr->ends.h;
2443     }
2444     if (attr->ends.w < 0) {
2445       attr->ends.w = input_shape.w + attr->ends.w;
2446     }
2447     if (attr->ends.c < 0) {
2448       attr->ends.c = input_shape.c + attr->ends.c;
2449     }
2450     if (attr->ends.b < 0) {
2451       attr->ends.b = input_shape.b + attr->ends.b;
2452     }
2453 
2454     if (attr->starts.h < 0) {
2455       attr->starts.h = input_shape.h + attr->starts.h;
2456     }
2457     if (attr->starts.w < 0) {
2458       attr->starts.w = input_shape.w + attr->starts.w;
2459     }
2460     if (attr->starts.c < 0) {
2461       attr->starts.c = input_shape.c + attr->starts.c;
2462     }
2463     if (attr->starts.b < 0) {
2464       attr->starts.b = input_shape.b + attr->starts.b;
2465     }
2466 
2467     return absl::OkStatus();
2468   }
2469 
2470   absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2471                                     const TfLiteStridedSliceParams* tf_options,
2472                                     const BHWC& input_shape,
2473                                     SliceAttributes* attr) {
2474     auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2475       Tensor<Linear, DataType::INT32> t;
2476       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2477       *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2478       return absl::OkStatus();
2479     };
2480 
2481     RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2482     RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2483     RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2484     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
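    // For a 4D input the mask bits follow TFLite dimension order:
    // batch=1, height=2, width=4, channels=8.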
2485     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
2486     return absl::OkStatus();
2487   }
2488 
2489   absl::Status ReadAttribsWithoutBatch(
2490       const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2491       const BHWC& input_shape, SliceAttributes* attr) {
2492     auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2493       Tensor<Linear, DataType::INT32> t;
2494       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2495       *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2496       return absl::OkStatus();
2497     };
2498 
2499     RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2500     RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2501     RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2502     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
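    // For a 3D input the batch dimension is absent from the TFLite tensors, so
    // the mask bits shift down (height=1, width=2, channels=4) and batch gets
    // bit 0, which never matches; the batch range is filled in explicitly
    // below.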
2503     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2504     attr->starts.b = 0;
2505     attr->ends.b = input_shape.b;
2506     attr->strides.b = 1;
2507     return absl::OkStatus();
2508   }
2509   absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2510     if (tf_options->ellipsis_mask) {
2511       return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2512     }
2513     if (tf_options->new_axis_mask) {
2514       return absl::UnimplementedError("Slice does not support new_axis_mask.");
2515     }
2516     if (tf_options->shrink_axis_mask) {
2517       return absl::UnimplementedError(
2518           "Slice does not support shrink_axis_mask parameter. ");
2519     }
2520     return absl::OkStatus();
2521   }
2522 };
2523 
2524 class TileOperationParser : public TFLiteOperationParser {
2525  public:
2526   absl::Status IsSupported(const TfLiteContext* context,
2527                            const TfLiteNode* tflite_node,
2528                            const TfLiteRegistration* registration) final {
2529     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2530   }
2531 
2532   absl::Status Parse(const TfLiteNode* tflite_node,
2533                      const TfLiteRegistration* registration,
2534                      GraphFloat32* graph, ObjectReader* reader) final {
2535     Node* node = graph->NewNode();
2536     node->operation.type = ToString(OperationType::TILE);
2537     RETURN_IF_ERROR(reader->AddInput(node, 0));
2538     RETURN_IF_ERROR(reader->AddOutputs(node));
2539     return absl::OkStatus();
2540   }
2541 };
2542 
2543 // Builtin op version of TRANSPOSE_CONV.
2544 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2545  public:
2546   absl::Status IsSupported(const TfLiteContext* context,
2547                            const TfLiteNode* tflite_node,
2548                            const TfLiteRegistration* registration) final {
2549     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2550     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2551   }
2552 
2553   // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2554   // input, and an optional bias) and allows configurable padding & stride.
2555   // TODO(impjdi): Translate output_shape to attr.adjacent.
2556   absl::Status Parse(const TfLiteNode* tflite_node,
2557                      const TfLiteRegistration* registration,
2558                      GraphFloat32* graph, ObjectReader* reader) final {
2559     auto* node = graph->NewNode();
2560     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2561     Value* input;
2562     RETURN_IF_ERROR(reader->ReadValue(2, &input));
2563     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2564     RETURN_IF_ERROR(reader->AddOutputs(node));
2565 
2566     const TfLiteTransposeConvParams* tf_options;
2567     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2568 
2569     ConvolutionTransposedAttributes attr;
2570     attr.stride = tf_options
2571                       ? HW(tf_options->stride_height, tf_options->stride_width)
2572                       : HW(1, 1);
2573     const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
2574     if (runtime_inputs == 2) {
2575       RETURN_IF_ERROR(reader->AddInput(node, 1));
2576       auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2577       attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2578                                 weights_shape.w, weights_shape.c);
2579     } else {  // runtime_inputs == 1
2580       RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2581     }
2582     reader->ReadTensor(3, &attr.bias).IgnoreError();  // bias is optional
2583 
2584     UpdatePadding(tf_options->padding,
2585                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2586     node->operation.attributes = std::move(attr);
2587     return absl::OkStatus();
2588   }
2589 };
2590 
2591 // Custom op version of TRANSPOSE_CONV.
2592 class TransposeConvCustomOperationParser : public TFLiteOperationParser {
2593  public:
2594   absl::Status IsSupported(const TfLiteContext* context,
2595                            const TfLiteNode* tflite_node,
2596                            const TfLiteRegistration* registration) final {
2597     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2598   }
2599 
2600   absl::Status Parse(const TfLiteNode* tflite_node,
2601                      const TfLiteRegistration* registration,
2602                      GraphFloat32* graph, ObjectReader* reader) final {
2603     auto* node = graph->NewNode();
2604     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2605     RETURN_IF_ERROR(reader->AddInput(node, 0));
2606     RETURN_IF_ERROR(reader->AddOutputs(node));
2607 
2608     const TfLiteTransposeConvParams* tf_options;
2609     auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
2610 
2611     ConvolutionTransposedAttributes attr;
2612     attr.stride = status.ok()
2613                       ? HW(tf_options->stride_height, tf_options->stride_width)
2614                       : HW(1, 1);
2615     RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2616     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
2617 
2618     UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
2619                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2620     node->operation.attributes = std::move(attr);
2621     return absl::OkStatus();
2622   }
2623 };
2624 
2625 class TransposeOperationParser : public TFLiteOperationParser {
2626  public:
2627   absl::Status IsSupported(const TfLiteContext* context,
2628                            const TfLiteNode* tflite_node,
2629                            const TfLiteRegistration* registration) final {
2630     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2631     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2632   }
2633 
2634   absl::Status Parse(const TfLiteNode* tflite_node,
2635                      const TfLiteRegistration* registration,
2636                      GraphFloat32* graph, ObjectReader* reader) final {
2637     Node* node = graph->NewNode();
2638     node->operation.type = ToString(OperationType::TRANSPOSE);
2639     RETURN_IF_ERROR(reader->AddInput(node, 0));
2640     RETURN_IF_ERROR(reader->AddOutputs(node));
2641 
2642     TransposeAttributes attr;
2643     Tensor<Linear, DataType::INT32> perm;
2644     RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2645     std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2646                                          {Axis::HEIGHT, 1},
2647                                          {Axis::WIDTH, 2},
2648                                          {Axis::CHANNELS, 3}};
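    // For rank < 4 the permutation is remapped into canonical BHWC order.
    // E.g. a 3D perm of {2, 0, 1} over (batch, width, channels) becomes the
    // BHWC permutation (3, 1, 0, 2), with height pinned in place.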
2649     if (perm.data.size() == 4) {
2650       attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2651     } else if (perm.data.size() == 3) {
2652       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2653                                          Axis::CHANNELS};
2654       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2655       attr.perm.h = 1;
2656       attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2657       attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2658     } else if (perm.data.size() == 2) {
2659       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2660       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2661       attr.perm.h = 1;
2662       attr.perm.w = 2;
2663       attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2664     } else {
2665       return absl::InvalidArgumentError(
2666           "Permutation for transpose is invalid.");
2667     }
2668 
2669     node->operation.attributes = attr;
2670     return absl::OkStatus();
2671   }
2672 };
2673 
2674 class UnpackOperationParser : public TFLiteOperationParser {
2675  public:
2676   absl::Status IsSupported(const TfLiteContext* context,
2677                            const TfLiteNode* tflite_node,
2678                            const TfLiteRegistration* registration) final {
2679     return absl::OkStatus();
2680   }
2681 
2682   absl::Status Parse(const TfLiteNode* tflite_node,
2683                      const TfLiteRegistration* registration,
2684                      GraphFloat32* graph, ObjectReader* reader) final {
2685     const TfLiteUnpackParams* unpack_params;
2686     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &unpack_params));
2687     if (unpack_params->num == 1) {
2688       // Add an identity Reshape that a later transformation removes.
2689       Node* node = graph->NewNode();
2690       node->operation.type = ToString(OperationType::RESHAPE);
2691       RETURN_IF_ERROR(reader->AddInput(node, 0));
2692       RETURN_IF_ERROR(reader->AddOutputs(node));
2693       // New shape comes from output shape.
2694       ReshapeAttributes attr;
2695       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2696       node->operation.attributes = attr;
2697       return absl::OkStatus();
2698     }
2699     const TfLiteTensor* input = reader->GetInputTensor(0);
2700     BHWC input_shape;
2701     RETURN_IF_ERROR(ExtractTensorShape(*input, &input_shape));
2702     SplitAttributes attr;
2703     RETURN_IF_ERROR(
2704         ExtractAxisFromIndex(*input, unpack_params->axis, &attr.axis));
2705     BHWC output_required_shape = input_shape;
2706     output_required_shape.set(attr.axis, 1);
2707 
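    // UNPACK is modeled as a SPLIT into size-1 slices along the chosen axis;
    // outputs whose TFLite shape dropped that dimension get an explicit
    // Reshape appended in the loop below.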
2708     Node* node = graph->NewNode();
2709     node->operation.type = ToString(OperationType::SPLIT);
2710     node->operation.attributes = attr;
2711     RETURN_IF_ERROR(reader->AddInput(node, 0));
2712     auto input_value = graph->FindInputs(node->id)[0];
2713     for (int i = 0; i < tflite_node->outputs->size; ++i) {
2714       const TfLiteTensor* output = reader->GetOutputTensor(i);
2715       BHWC output_shape;
2716       RETURN_IF_ERROR(ExtractTensorShape(*output, &output_shape));
2717       if (output_shape != output_required_shape) {
2718         // The GPU delegate does not support implicit shape transformations,
2719         // so add an explicit Reshape.
2720         Value* copy_value = graph->NewValue();
2721         copy_value->tensor.type = input_value->tensor.type;
2722         copy_value->tensor.shape = output_required_shape;
2723         RETURN_IF_ERROR(graph->SetProducer(node->id, copy_value->id));
2724         Node* node_reshape = graph->NewNode();
2725         node_reshape->operation.type = ToString(OperationType::RESHAPE);
2726         ReshapeAttributes reshape_attr;
2727         reshape_attr.new_shape = output_shape;
2728         node_reshape->operation.attributes = reshape_attr;
2729         RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, copy_value->id));
2730         RETURN_IF_ERROR(reader->AddOutput(node_reshape, i));
2731       } else {
2732         RETURN_IF_ERROR(reader->AddOutput(node, i));
2733       }
2734     }
2735     return absl::OkStatus();
2736   }
2737 };
2738 
2739 class Unpooling2DOperationParser : public TFLiteOperationParser {
2740  public:
2741   absl::Status IsSupported(const TfLiteContext* context,
2742                            const TfLiteNode* tflite_node,
2743                            const TfLiteRegistration* registration) final {
2744     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2745   }
2746 
2747   absl::Status Parse(const TfLiteNode* tflite_node,
2748                      const TfLiteRegistration* registration,
2749                      GraphFloat32* graph, ObjectReader* reader) final {
2750     Node* node = graph->NewNode();
2751     node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
2752     RETURN_IF_ERROR(reader->AddInput(node, 0));
2753     RETURN_IF_ERROR(reader->AddInput(node, 1));
2754     RETURN_IF_ERROR(reader->AddOutputs(node));
2755     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
2756     MaxUnpooling2DAttributes attr;
2757 
2758     const TfLitePoolParams* tf_options;
2759     RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
2760 
2761     attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
2762     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
2763     UpdatePadding(tf_options->padding, input_shape, &attr);
2764 
2765     node->operation.attributes = attr;
2766 
2767     auto output_value = graph->FindOutputs(node->id)[0];
2768     output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
2769     return absl::OkStatus();
2770   }
2771 };
2772 
2773 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
2774 class BatchToSpaceOperationParser : public TFLiteOperationParser {
2775  public:
2776   absl::Status IsSupported(const TfLiteContext* context,
2777                            const TfLiteNode* tflite_node,
2778                            const TfLiteRegistration* registration) final {
2779     return absl::OkStatus();
2780   }
2781 
2782   absl::Status Parse(const TfLiteNode* tflite_node,
2783                      const TfLiteRegistration* registration,
2784                      GraphFloat32* graph, ObjectReader* reader) final {
2785     auto* node = graph->NewNode();
2786     node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
2787     RETURN_IF_ERROR(reader->AddInput(node, 0));
2788     RETURN_IF_ERROR(reader->AddOutputs(node));
2789 
2790     BatchToSpaceAttributes bs_attr;
2791     Tensor<Linear, DataType::INT32> block;
2792     RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2793     if (block.shape.v != 2) {
2794       return absl::InternalError("Space has to be HxW.");
2795     }
2796     bs_attr.block.h = block.data[0];
2797     bs_attr.block.w = block.data[1];
2798 
2799     Tensor<HW, DataType::INT32> crop;
2800     RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
2801     auto crop_shape = crop.shape;
2802     if (crop_shape.h != 2 || crop_shape.w != 2) {
2803       return absl::InternalError("Crop has to be 2x2 (HxW).");
2804     }
2805 
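    // The 2x2 crop tensor is laid out row-major as
    // {{h_prepended, h_appended}, {w_prepended, w_appended}}, hence the
    // data[0..3] indexing below.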
2806     bs_attr.crop.prepended.h = crop.data[0];
2807     bs_attr.crop.prepended.w = crop.data[2];
2808 
2809     bs_attr.crop.appended.h = crop.data[1];
2810     bs_attr.crop.appended.w = crop.data[3];
2811 
2812     node->operation.attributes = std::move(bs_attr);
2813     return absl::OkStatus();
2814   }
2815 };
2816 
2817 class SpaceToBatchOperationParser : public TFLiteOperationParser {
2818  public:
2819   absl::Status IsSupported(const TfLiteContext* context,
2820                            const TfLiteNode* tflite_node,
2821                            const TfLiteRegistration* registration) final {
2822     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2823   }
2824 
2825   absl::Status Parse(const TfLiteNode* tflite_node,
2826                      const TfLiteRegistration* registration,
2827                      GraphFloat32* graph, ObjectReader* reader) final {
2828     auto* node = graph->NewNode();
2829     node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
2830     RETURN_IF_ERROR(reader->AddInput(node, 0));
2831     RETURN_IF_ERROR(reader->AddOutputs(node));
2832     SpaceToBatchAttributes sb_attr;
2833     Tensor<Linear, DataType::INT32> block;
2834     RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2835     if (block.shape.v != 2) {
2836       return absl::InternalError("Space has to be HxW.");
2837     }
2838     sb_attr.block.h = block.data[0];
2839     sb_attr.block.w = block.data[1];
2840 
2841     Tensor<HW, DataType::INT32> padding;
2842     RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
2843     auto padding_shape = padding.shape;
2844 
2845     if (padding_shape.h != 2 || padding_shape.w != 2) {
2846       return absl::InternalError("Padding has to be 2x2 (HxW).");
2847     }
2848 
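    // Same row-major layout as BATCH_TO_SPACE's crop tensor:
    // {{h_prepended, h_appended}, {w_prepended, w_appended}}.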
2849     sb_attr.padding.prepended.h = padding.data[0];
2850     sb_attr.padding.prepended.w = padding.data[2];
2851 
2852     sb_attr.padding.appended.h = padding.data[1];
2853     sb_attr.padding.appended.w = padding.data[3];
2854 
2855     node->operation.attributes = std::move(sb_attr);
2856     return absl::OkStatus();
2857   }
2858 };
2859 
2860 class MeanOperationParser : public TFLiteOperationParser {
2861  public:
2862   absl::Status IsSupported(const TfLiteContext* context,
2863                            const TfLiteNode* tflite_node,
2864                            const TfLiteRegistration* registration) final {
2865     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2866   }
2867 
2868   absl::Status Parse(const TfLiteNode* tflite_node,
2869                      const TfLiteRegistration* registration,
2870                      GraphFloat32* graph, ObjectReader* reader) final {
2871     auto* node = graph->NewNode();
2872     node->operation.type = ToString(OperationType::MEAN);
2873     RETURN_IF_ERROR(reader->AddInput(node, 0));
2874     RETURN_IF_ERROR(reader->AddOutputs(node));
2875 
2876     MeanAttributes attr;
2877     const TfLiteTensor* input = reader->GetInputTensor(0);
2878     const TfLiteTensor* axes = reader->GetInputTensor(1);
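    // Input tensor 1 lists the axes to reduce over; each TFLite axis index is
    // translated into a BHWC axis and collected into attr.dims.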
2879     for (int i = 0; i < NumElements(axes->dims); i++) {
2880       Axis axis;
2881       RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
2882       attr.dims.insert(axis);
2883     }
2884     node->operation.attributes = attr;
2885     return absl::OkStatus();
2886   }
2887 };
2888 
2889 class UnsupportedOperationParser : public TFLiteOperationParser {
2890  public:
2891   absl::Status IsSupported(const TfLiteContext* context,
2892                            const TfLiteNode* tflite_node,
2893                            const TfLiteRegistration* registration) final {
2894     return absl::UnimplementedError("Operation is not supported.");
2895   }
2896 
2897   absl::Status Parse(const TfLiteNode* tflite_node,
2898                      const TfLiteRegistration* registration,
2899                      GraphFloat32* graph, ObjectReader* reader) final {
2900     return absl::UnimplementedError("Operation is not supported.");
2901   }
2902 };
2903 
2904 absl::Status IsSupported(
2905     const TfLiteContext* context, TfLiteNode* node,
2906     const TfLiteRegistration* registration, bool allow_quant_ops = false,
2907     const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops = nullptr) {
2908   return NewOperationParser(registration, allow_quant_ops, excluded_ops)
2909       ->IsSupported(context, node, registration);
2910 }
2911 
2912 bool IsAllAllowedTensors(TfLiteContext* context,
2913                          const TfLiteIntArray* tensor_indices,
2914                          const std::vector<TfLiteType>& allowed_types) {
2915   for (int i = 0; i < tensor_indices->size; ++i) {
2916     int tensor_idx = tensor_indices->data[i];
2917     if (tensor_idx == kTfLiteOptionalTensor) continue;
2918     const TfLiteTensor* t = &context->tensors[tensor_idx];
2919     if (t->dims && t->dims->size >= 5) {
2920       return false;
2921     }
2922     bool type_supported = false;
2923     for (auto allowed_type : allowed_types) {
2924       if (t->type == allowed_type) {
2925         type_supported = true;
2926         break;
2927       }
2928     }
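    // Only runtime (arena-allocated) tensors must match an allowed type;
    // constant tensors are expected to be converted (e.g. dequantized) while
    // the graph is parsed.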
2929     if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
2930       return false;
2931     }
2932   }
2933   return true;
2934 }
2935 }  // namespace
2936 
2937 std::unique_ptr<TFLiteOperationParser> NewOperationParser(
2938     const TfLiteRegistration* registration, bool allow_quant_ops,
2939     const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops) {
2940   const auto builtin_code = registration->builtin_code;
2941   if (excluded_ops != nullptr &&
2942       excluded_ops->contains(
2943           static_cast<TfLiteBuiltinOperator>(builtin_code))) {
2944     return std::make_unique<UnsupportedOperationParser>();
2945   }
2946   switch (builtin_code) {
2947     case kTfLiteBuiltinAbs:
2948       return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
2949     case kTfLiteBuiltinAdd:
2950       return std::make_unique<ElementwiseOperationParser>(OperationType::ADD);
2951     case kTfLiteBuiltinAveragePool2d:
2952       return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
2953     case kTfLiteBuiltinBatchMatmul:
2954       return std::make_unique<BatchedMatMulOperationParser>();
2955     case kTfLiteBuiltinCast:
2956       return std::make_unique<CastOperationParser>();
2957     case kTfLiteBuiltinConcatenation:
2958       return std::make_unique<ConcatenationOperationParser>();
2959     case kTfLiteBuiltinConv2d:
2960       return std::make_unique<Conv2DOperationParser>();
2961     case kTfLiteBuiltinCos:
2962       return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
2963     case kTfLiteBuiltinCumsum:
2964       return std::make_unique<CumsumOperationParser>();
2965     case kTfLiteBuiltinDensify:
2966       return std::make_unique<DensifyOperationParser>();
2967     case kTfLiteBuiltinDepthwiseConv2d:
2968       return std::make_unique<DepthwiseConvolutionOperationParser>();
2969     case kTfLiteBuiltinDepthToSpace:
2970       return std::make_unique<DepthToSpaceOperationParser>();
2971     case kTfLiteBuiltinDequantize:
2972       if (allow_quant_ops) {
2973         return std::make_unique<DequantizeOperationParser>();
2974       }
2975       break;
2976     case kTfLiteBuiltinDiv:
2977       return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
2978     case kTfLiteBuiltinEqual:
2979       return std::make_unique<ElementwiseOperationParser>(OperationType::EQUAL);
2980     case kTfLiteBuiltinElu:
2981       return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
2982     case kTfLiteBuiltinExp:
2983       return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
2984     case kTfLiteBuiltinFloor:
2985       return std::make_unique<ElementwiseOperationParser>(OperationType::FLOOR);
2986     case kTfLiteBuiltinFloorDiv:
2987       return std::make_unique<ElementwiseOperationParser>(
2988           OperationType::FLOOR_DIV);
2989     case kTfLiteBuiltinFloorMod:
2990       return std::make_unique<ElementwiseOperationParser>(
2991           OperationType::FLOOR_MOD);
2992     case kTfLiteBuiltinFullyConnected:
2993       return std::make_unique<FullyConnectedOperationParser>();
2994     case kTfLiteBuiltinGreater:
2995       return std::make_unique<ElementwiseOperationParser>(
2996           OperationType::GREATER);
2997     case kTfLiteBuiltinGreaterEqual:
2998       return std::make_unique<ElementwiseOperationParser>(
2999           OperationType::GREATER_EQUAL);
3000     case kTfLiteBuiltinHardSwish:
3001       return std::make_unique<HardSwishOperationParser>();
3002     case kTfLiteBuiltinLess:
3003       return std::make_unique<ElementwiseOperationParser>(OperationType::LESS);
3004     case kTfLiteBuiltinLessEqual:
3005       return std::make_unique<ElementwiseOperationParser>(
3006           OperationType::LESS_EQUAL);
3007     case kTfLiteBuiltinLogistic:
3008       return std::make_unique<ElementwiseOperationParser>(
3009           OperationType::SIGMOID);
3010     case kTfLiteBuiltinLog:
3011       return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
3012     case kTfLiteBuiltinLstm:
3013       return std::make_unique<LSTMOperationParser>();
3014     case kTfLiteBuiltinMaximum:
3015       return std::make_unique<ElementwiseOperationParser>(
3016           OperationType::MAXIMUM);
3017     case kTfLiteBuiltinMaxPool2d:
3018       return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
3019     case kTfLiteBuiltinMean:
3020       return std::make_unique<MeanOperationParser>();
3021     case kTfLiteBuiltinMinimum:
3022       return std::make_unique<ElementwiseOperationParser>(
3023           OperationType::MINIMUM);
3024     case kTfLiteBuiltinMirrorPad:
3025       return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
3026     case kTfLiteBuiltinMul:
3027       return std::make_unique<ElementwiseOperationParser>(OperationType::MUL);
3028     case kTfLiteBuiltinNeg:
3029       return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
3030     case kTfLiteBuiltinNotEqual:
3031       return std::make_unique<ElementwiseOperationParser>(
3032           OperationType::NOT_EQUAL);
3033     case kTfLiteBuiltinOneHot:
3034       return std::make_unique<OneHotOperationParser>();
3035     case kTfLiteBuiltinPack:
3036       return std::make_unique<PackOperationParser>();
3037     case kTfLiteBuiltinPad:
3038       return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
3039     case kTfLiteBuiltinPow:
3040       return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
3041     case kTfLiteBuiltinReduceMax:
3042       return std::make_unique<ReduceOperationParser>(
3043           OperationType::REDUCE_MAXIMUM);
3044     case kTfLiteBuiltinReduceMin:
3045       return std::make_unique<ReduceOperationParser>(
3046           OperationType::REDUCE_MINIMUM);
3047     case kTfLiteBuiltinReduceProd:
3048       return std::make_unique<ReduceOperationParser>(
3049           OperationType::REDUCE_PRODUCT);
3050     case kTfLiteBuiltinQuantize:
3051       if (allow_quant_ops) {
3052         return std::make_unique<QuantizeOperationParser>();
3053       }
3054       break;
3055     case kTfLiteBuiltinRelu:
3056       return std::make_unique<ReLUOperationParser>(0);
3057     case kTfLiteBuiltinRelu6:
3058       return std::make_unique<ReLUOperationParser>(6);
3059     case kTfLiteBuiltinReluN1To1:
3060       return std::make_unique<ClampOperationsParser>(-1.0, 1.0);
3061     case kTfLiteBuiltinLeakyRelu:
3062       return std::make_unique<ReLUOperationParser>(0);
3063     case kTfLiteBuiltinPrelu:
3064       return std::make_unique<PReLUOperationParser>();
3065     case kTfLiteBuiltinReshape:
3066       return std::make_unique<ReshapeOperationParser>();
3067     case kTfLiteBuiltinResizeBilinear:
3068       return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
3069     case kTfLiteBuiltinResizeNearestNeighbor:
3070       return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
3071     case kTfLiteBuiltinRsqrt:
3072       return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
3073     case kTfLiteBuiltinSelectV2:
3074       return std::make_unique<SelectV2OperationParser>();
3075     case kTfLiteBuiltinSin:
3076       return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
3077     case kTfLiteBuiltinSlice:
3078       return std::make_unique<SliceOperationParser>();
3079     case kTfLiteBuiltinSoftmax:
3080       return std::make_unique<SoftmaxOperationParser>();
3081     case kTfLiteBuiltinSpaceToDepth:
3082       return std::make_unique<SpaceToDepthOperationParser>();
3083     case kTfLiteBuiltinSplit:
3084       return std::make_unique<SplitOperationParser>();
3085     case kTfLiteBuiltinSplitV:
3086       return std::make_unique<SplitVOperationParser>();
3087     case kTfLiteBuiltinSqrt:
3088       return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
3089     case kTfLiteBuiltinSquare:
3090       return std::make_unique<ElementwiseOperationParser>(
3091           OperationType::SQUARE);
3092     case kTfLiteBuiltinSquaredDifference:
3093       return std::make_unique<ElementwiseOperationParser>(
3094           OperationType::SQUARED_DIFF);
3095     case kTfLiteBuiltinStridedSlice:
3096       return std::make_unique<StridedSliceOperationParser>();
3097     case kTfLiteBuiltinSub:
3098       return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
3099     case kTfLiteBuiltinSum:
3100       return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
3101     case kTfLiteBuiltinTanh:
3102       return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
3103     case kTfLiteBuiltinTile:
3104       return std::make_unique<TileOperationParser>();
3105     case kTfLiteBuiltinTranspose:
3106       return std::make_unique<TransposeOperationParser>();
3107     case kTfLiteBuiltinTransposeConv:
3108       return std::make_unique<TransposeConvBuiltinOperationParser>();
3109     case kTfLiteBuiltinUnpack:
3110       return std::make_unique<UnpackOperationParser>();
3111     case kTfLiteBuiltinCustom: {
3112       const absl::string_view custom_name = registration->custom_name;
3113       if (custom_name == "Convolution2DTransposeBias") {
3114         return std::make_unique<TransposeConvCustomOperationParser>();
3115       }
3116       if (custom_name == "MaxPoolingWithArgmax2D") {
3117         return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
3118       }
3119       if (custom_name == "MaxUnpooling2D") {
3120         return std::make_unique<Unpooling2DOperationParser>();
3121       }
3122       if (custom_name == "Resampler") {
3123         return std::make_unique<ResamplerOperationParser>();
3124       }
3125       return NewCustomOperationParser(registration->custom_name);
3126     }
3127   }
3128   return std::make_unique<UnsupportedOperationParser>();
3129 }
3130 
3131 // TODO(impjdi): Check number of input/output tensors and their dimensions.
3132 // TODO(impjdi): Check ops' parameters.
3133 TfLiteIntArray* GetOpsToReplace(
3134     TfLiteContext* context, bool allow_quant_ops, int max_delegated_partitions,
3135     const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops) {
3136   delegates::IsNodeSupportedFn node_supported_fn =
3137       [=](TfLiteContext* context, TfLiteNode* node,
3138           TfLiteRegistration* registration,
3139           std::string* unsupported_details) -> bool {
3140     const auto status =
3141         IsSupported(context, node, registration, allow_quant_ops, excluded_ops);
3142     if (!status.ok()) {
3143       if (unsupported_details) {
3144         *unsupported_details = std::string(status.message());
3145       }
3146       return false;
3147     }
3148 
3149     std::vector<TfLiteType> allowed_in_types = {kTfLiteFloat32, kTfLiteFloat16};
3150     std::vector<TfLiteType> allowed_out_types = {kTfLiteFloat32,
3151                                                  kTfLiteFloat16};
3152     if (allow_quant_ops) {
3153       // Since we only check non-constant tensors, type cannot be Int32.
3154       allowed_in_types.push_back(kTfLiteInt8);
3155       allowed_in_types.push_back(kTfLiteUInt8);
3156       allowed_out_types.push_back(kTfLiteInt8);
3157       allowed_out_types.push_back(kTfLiteUInt8);
3158     }
3159     if (IsLogicalCode(registration->builtin_code)) {
3160       allowed_out_types.push_back(kTfLiteBool);
3161     }
3162     if (registration->builtin_code == kTfLiteBuiltinCast) {
3163       allowed_in_types.push_back(kTfLiteBool);
3164       allowed_in_types.push_back(kTfLiteFloat32);
3165       allowed_in_types.push_back(kTfLiteInt32);
3166       allowed_out_types.push_back(kTfLiteFloat32);
3167       allowed_out_types.push_back(kTfLiteInt32);
3168       allowed_out_types.push_back(kTfLiteBool);
3169     }
3170     if (registration->builtin_code == kTfLiteBuiltinOneHot) {
3171       allowed_in_types.push_back(kTfLiteInt32);
3172     }
3173     if (registration->builtin_code == kTfLiteBuiltinSelectV2) {
3174       allowed_in_types.push_back(kTfLiteBool);
3175     }
3176     if (!IsAllAllowedTensors(context, node->inputs, allowed_in_types) ||
3177         !IsAllAllowedTensors(context, node->outputs, allowed_out_types)) {
3178       if (unsupported_details) {
3179         *unsupported_details =
3180             "OP is supported, but tensor type/shape isn't compatible.";
3181       }
3182       return false;
3183     }
3184     return true;
3185   };
3186 
3187   delegates::FP16GraphPartitionHelper partition_helper(context,
3188                                                        node_supported_fn);
3189   std::set<std::string> unsupported_nodes_info;
3190   if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
3191     return TfLiteIntArrayCreate(0);
3192   }
3193 
3194   // By default we take only the single largest partition, since
3195   // 'max_delegated_partitions' defaults to 1.
3196   std::vector<int> ops_to_replace =
3197       partition_helper.GetNodesOfFirstNLargestPartitions(
3198           max_delegated_partitions);
3199 
3200   if (!unsupported_nodes_info.empty() &&
3201       partition_helper.num_total_nodes() > ops_to_replace.size()) {
3202     std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
3203     std::string error_message = absl::StrCat(
3204         "Following operations are not supported by GPU delegate:\n",
3205         unsupported, "\n");
3206     if (!ops_to_replace.empty()) {
3207       absl::StrAppend(
3208           &error_message, ops_to_replace.size(),
3209           " operations will run on the GPU, and the remaining ",
3210           partition_helper.num_total_nodes() - ops_to_replace.size());
3211     } else {
3212       absl::StrAppend(&error_message,
3213                       "No operations will run on the GPU, and all ",
3214                       partition_helper.num_total_nodes());
3215     }
3216     absl::StrAppend(&error_message, " operations will run on the CPU.");
3217     TF_LITE_KERNEL_LOG(context, error_message.c_str());
3218   }
3219   return ConvertVectorToTfLiteIntArray(ops_to_replace);
3220 }
3221 
3222 // Pre-creates the inputs and outputs passed via the io_ids parameter in the
3223 // resulting graph. This is forced so that the delegated subgraph keeps the
3224 // same input/output order as the original model. When the delegated model is
3225 // built from the tflite model representation, tensors are created lazily, so
3226 // there is no guarantee that their order matches the source model's order.
3227 absl::Status PrecreateIOTensors(
3228     TfLiteContext* context, GraphFloat32* graph, const std::vector<int>& io_ids,
3229     absl::flat_hash_map<int, int>* quant_conversion_map,
3230     absl::flat_hash_map<int, Value*>* tensor_to_value) {
3231   for (const auto& id : io_ids) {
3232     const TfLiteTensor& tflite_tensor = context->tensors[id];
3233     if (tflite::IsConstantTensor(&tflite_tensor)) continue;
3234     RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
3235         context, tensor_to_value, quant_conversion_map, graph, id));
3236   }
3237   return absl::OkStatus();
3238 }
3239 
3240 absl::Status CopyVariableTensorOutputs(
3241     TfLiteNode* tflite_node, TfLiteRegistration* registration,
3242     GraphFloat32* graph, ObjectReader& reader,
3243     const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
3244   absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
3245       new_variable_tensor_values);
3246   // Retrieve the final value id for the variable input tensors.
3247   for (int i = 0; i < tflite_node->inputs->size; i++) {
3248     int tensor_idx = tflite_node->inputs->data[i];
3249     Value* value;
3250     if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
3251     if (value->tensor.is_variable_input) {
3252       if (new_variable_tensor_values_copy.find(i) ==
3253           new_variable_tensor_values_copy.end()) {
3254         return absl::InvalidArgumentError(
3255             absl::StrCat(GetOpNameByRegistration(*registration),
3256                          " did not provide a new value for the variable input "
3257                          "tensor with index ",
3258                          tensor_idx));
3259       } else {
3260         Node* node = graph->NewNode();
3261         node->operation.type = ToString(OperationType::COPY);
3262         RETURN_IF_ERROR(graph->AddConsumer(
3263             node->id, new_variable_tensor_values_copy.at(i)));
3264         RETURN_IF_ERROR(reader.AddUpdate(node, i));
3265         new_variable_tensor_values_copy.erase(
3266             new_variable_tensor_values_copy.find(i));
3267       }
3268     }
3269   }
3270   if (!new_variable_tensor_values_copy.empty()) {
3271     return absl::InvalidArgumentError(
3272         "More input variable tensors asked to be copied than present on the "
3273         "node");
3274   }
3275   return absl::OkStatus();
3276 }
3277 
3278 absl::Status BuildModel(TfLiteContext* context,
3279                         const TfLiteDelegateParams* delegate_params,
3280                         GraphFloat32* graph,
3281                         absl::flat_hash_map<int, int>* quant_conversion_map) {
3282   std::vector<int> inputs(delegate_params->input_tensors->size);
3283   std::vector<int> outputs(delegate_params->output_tensors->size);
3284   for (int i = 0; i < delegate_params->input_tensors->size; i++) {
3285     inputs[i] = delegate_params->input_tensors->data[i];
3286   }
3287   for (int i = 0; i < delegate_params->output_tensors->size; i++) {
3288     outputs[i] = delegate_params->output_tensors->data[i];
3289   }
3290   return BuildModelEnforceIO(context, delegate_params, inputs, outputs, graph,
3291                              quant_conversion_map);
3292 }
3293 
3294 absl::Status BuildModelEnforceIO(
3295     TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
3296     const std::vector<int>& input_ids, const std::vector<int>& output_ids,
3297     GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
3298   std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
3299   std::vector<int> tflite_nodes;
3300   for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
3301     TfLiteNode* tflite_node = nullptr;
3302     TfLiteRegistration* registration = nullptr;
3303     RETURN_IF_ERROR(GetNodeAndRegistration(
3304         context, delegate_params->nodes_to_replace->data[i], &tflite_node,
3305         &registration));
3306     if (registration->builtin_code == kTfLiteBuiltinDequantize &&
3307         context->tensors[tflite_node->inputs->data[0]].type ==
3308             TfLiteType::kTfLiteFloat16 &&
3309         context->tensors[tflite_node->inputs->data[0]].allocation_type ==
3310             TfLiteAllocationType::kTfLiteMmapRo) {
3311       // Ignore FP16 Dequantize nodes only if they read directly from the
3312       // weights, i.e. no other nodes (e.g. DENSIFY) precede them.
3313       continue;
3314     }
3315     auto op_parser = NewOperationParser(
3316         registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
3317     if (!op_parser) {
3318       return absl::UnimplementedError(
3319           absl::StrCat("Operation ", registration->builtin_code, "(",
3320                        registration->custom_name,
3321                        ") is not supported by TFLite GPU Delegate."));
3322     }
3323     operations.push_back(std::move(op_parser));
3324     tflite_nodes.push_back(i);
3325   }
3326   absl::flat_hash_map<int, Value*> tensor_to_value;
3327   std::vector<ValueId> variable_inputs_to_value_id;
3328 
3329   RETURN_IF_ERROR(PrecreateIOTensors(context, graph, input_ids,
3330                                      quant_conversion_map, &tensor_to_value));
3331   RETURN_IF_ERROR(PrecreateIOTensors(context, graph, output_ids,
3332                                      quant_conversion_map, &tensor_to_value));
3333   for (int i = 0; i < operations.size(); ++i) {
3334     TfLiteNode* tflite_node;
3335     TfLiteRegistration* registration;
3336     RETURN_IF_ERROR(GetNodeAndRegistration(
3337         context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
3338         &tflite_node, &registration));
3339     ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
3340                         quant_conversion_map);
3341     const auto status =
3342         operations[i]->Parse(tflite_node, registration, graph, &reader);
3343     if (!status.ok()) {
3344       return absl::InternalError(absl::StrCat(
3345           GetOpNameByRegistration(*registration), ": ", status.message()));
3346     }
3347 
3348     absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
3349         operations[i]->GetNewValueIdsForVariableInputNodes();
3350 
3351     RETURN_IF_ERROR(
3352         CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
3353                                   new_value_for_variable_input_tensors));
3354   }
3355 
3356   // Variable input tensors are expected to stay unchanged throughout model
3357   // execution. They must be graph outputs in order to remain unchanged.
3358   for (auto value_id : variable_inputs_to_value_id) {
3359     if (!graph->IsGraphOutput(value_id)) {
3360       return absl::InvalidArgumentError(
3361           absl::StrCat("Variable input tensors must be a graph output. Value ",
3362                        value_id, " is not a graph output"));
3363     }
3364   }
3365   return absl::OkStatus();
3366 }
3367 
3368 absl::Status BuildFinalModel(
3369     TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
3370     GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
3371   RETURN_IF_ERROR(
3372       BuildModel(context, delegate_params, graph, quant_conversion_map));
3373 
3374   // Apply general transformations on the graph.
3375   ModelTransformer transformer(graph);
3376   if (!ApplyModelTransformations(&transformer)) {
3377     return absl::InternalError("Graph transformations failed");
3378   }
3379   return absl::OkStatus();
3380 }
3381 
3382 namespace {
3383 
3384 class DelegateContext {
3385  public:
3386   struct DelegateData {
3387     std::vector<int> input_ids;
3388     std::vector<int> output_ids;
3389     GraphFloat32* graph;
3390     std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
3391   };
3392   bool Init(TfLiteContext* context,
3393             const TfLiteDelegateParams* delegate_params) {
3394     const auto* delegate_data =
3395         reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
3396     return delegate_data->graph &&
3397            BuildModelEnforceIO(context, delegate_params,
3398                                delegate_data->input_ids,
3399                                delegate_data->output_ids, delegate_data->graph,
3400                                delegate_data->quant_conversion_map.get())
3401                .ok();
3402   }
3403 };
3404 
3405 TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
3406   TfLiteRegistration registration{};
3407   registration.init = [](TfLiteContext* context, const char* buffer,
3408                          size_t) -> void* {
3409     auto* delegate_context = new DelegateContext();
3410     if (!delegate_context->Init(
3411             context, reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
3412       delete delegate_context;
3413       return nullptr;
3414     }
3415     return delegate_context;
3416   };
3417   registration.free = [](TfLiteContext* context, void* buffer) -> void {
3418     delete reinterpret_cast<DelegateContext*>(buffer);
3419   };
3420   registration.prepare = [](TfLiteContext* context,
3421                             TfLiteNode* node) -> TfLiteStatus {
3422     return node->user_data ? kTfLiteOk : kTfLiteError;
3423   };
3424 
3425   const auto* delegate_data =
3426       reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
3427   TfLiteIntArray* ops_to_replace = GetOpsToReplace(
3428       context, static_cast<bool>(delegate_data->quant_conversion_map));
3429   const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
3430       context, registration, ops_to_replace, delegate);
3431   TfLiteIntArrayFree(ops_to_replace);
3432   return status;
3433 }
3434 
3435 }  // namespace
3436 
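// Minimal usage sketch (illustrative only: error handling is elided, and the
// file name and resolver choice are assumptions):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver op_resolver;
//   GraphFloat32 graph;
//   absl::Status status = BuildFromFlatBuffer(*model, op_resolver, &graph,
//                                             /*allow_quant_ops=*/false);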
3437 absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
3438                                  const tflite::OpResolver& op_resolver,
3439                                  GraphFloat32* graph, bool allow_quant_ops) {
3440   std::unique_ptr<tflite::Interpreter> interpreter;
3441   tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
3442   if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
3443     return absl::InternalError("Unable to prepare TfLite interpreter.");
3444   }
3445   TfLiteDelegate delegate;
3446 
3447   DelegateContext::DelegateData delegate_data{interpreter->inputs(),
3448                                               interpreter->outputs(), graph};
3449   if (allow_quant_ops) {
3450     delegate_data.quant_conversion_map =
3451         std::make_unique<absl::flat_hash_map<int, int>>();
3452   }
3453 
3454   delegate.data_ = &delegate_data;
3455   delegate.flags = kTfLiteDelegateFlagsNone;
3456   delegate.Prepare = DelegatePrepare;
3457   delegate.CopyFromBufferHandle = nullptr;
3458   delegate.CopyToBufferHandle = nullptr;
3459   delegate.FreeBufferHandle = nullptr;
3460 
3461   if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
3462     return absl::InternalError("Conversion from TfLite model failed.");
3463   }
3464 
3465   ModelTransformer transformer(graph);
3466   if (!ApplyModelTransformations(&transformer)) {
3467     return absl::InternalError("Graph transformations failed");
3468   }
3469 
3470   return absl::OkStatus();
3471 }
3472 
3473 }  // namespace gpu
3474 }  // namespace tflite
3475