/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_internal.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace gpu {
namespace {

absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
                                         int bias_tensor_id,
                                         ObjectReader* reader,
                                         FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.i = weights.shape.w;
  reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
  return absl::OkStatus();
}
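
// For example (a sketch of the mapping above): a TFLite fully-connected
// weights tensor of shape [64, 128] (output x input) becomes
// OHWI(o=64, h=1, w=1, i=128), i.e. a 1x1 spatial kernel.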

template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

// Creates a simple node that holds a tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
  ConstTensorAttributes attr;
  attr.tensor = std::move(t);
  Node* node = graph->NewNode();
  node->operation.attributes = attr;
  node->operation.type = ToString(OperationType::CONSTANT);
  *value = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
  // Keep data inside this tensor.
  (*value)->tensor.ref = attr.tensor.id;
  (*value)->tensor.type = attr.tensor.kType;
  (*value)->tensor.shape = attr.tensor.shape;
  return absl::OkStatus();
}

absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;

  // Determine runtime/constant tensors.
  const TfLiteTensor* input0 = reader->GetInputTensor(0);
  if (!input0) {
    return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
                                      opname);
  }
  const TfLiteTensor* input1 = reader->GetInputTensor(1);
  if (!input1) {
    return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
                                      opname);
  }
  const bool constant_tensor0 = IsConstantTensor(input0);
  const bool constant_tensor1 = IsConstantTensor(input1);
  if (constant_tensor0 && constant_tensor1) {
    return absl::InvalidArgumentError("No runtime input tensors for " + opname);
  }
  const bool runtime_tensor0 = !constant_tensor0;
  const bool runtime_tensor1 = !constant_tensor1;

  if (runtime_tensor0 && runtime_tensor1) {
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
  } else {
    int runtime_tensor = 0;
    int constant_tensor = 1;
    TfLiteIntArray* constant_dims = input1->dims;
    if (constant_tensor0 && runtime_tensor1) {
      runtime_tensor = 1;
      constant_tensor = 0;
      constant_dims = input0->dims;
    }
    RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
    if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
      Tensor<Scalar, DataType::FLOAT32> tensor;
      RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
      *tensor_or_scalar = tensor.data[0];
    } else {
      if (CheckIfLinearConvertible(constant_dims).ok()) {
        Tensor<Linear, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      } else if (constant_dims->size == 2) {
        Tensor<HW, DataType::FLOAT32> tensor_hw;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor_hw));
        Tensor<HWC, DataType::FLOAT32> tensor;
        tensor.id = tensor_hw.id;
        tensor.shape = HWC(1, tensor_hw.shape.h, tensor_hw.shape.w);
        tensor.data = tensor_hw.data;
        *tensor_or_scalar = std::move(tensor);
      } else {
        Tensor<HWC, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      }
    }
  }
  return absl::OkStatus();
}
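
// For example, ADD(runtime[1, 8, 8, 16], const[16]) stores the constant as a
// Linear tensor in *tensor_or_scalar, while a single-element constant
// collapses to a plain scalar.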

absl::Status MaybeFuseActivationForElementwiseNode(
    OperationType operation_type, const TfLiteNode* tflite_node,
    GraphFloat32* graph, Node* node) {
  TfLiteFusedActivation activation = kTfLiteActNone;
  switch (operation_type) {
    case OperationType::MUL: {
      const TfLiteMulParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::ADD: {
      const TfLiteAddParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::SUB: {
      const TfLiteSubParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    case OperationType::DIV: {
      const TfLiteDivParams* tf_options;
      if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
        activation = tf_options->activation;
      }
      break;
    }
    default:
      // No activation expected.
      activation = kTfLiteActNone;
  }

  if (activation) {
    return MaybeFuseActivation(activation, graph, node);
  }
  return absl::OkStatus();
}
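
// Note: kTfLiteActNone is 0, so the `if (activation)` check above fuses an
// activation only when the builtin op actually carries one.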

struct TensorInfo {
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> producers;
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> consumers;
};

absl::Status GetTensorInfo(const TfLiteContext* context, int tensor_id,
                           TensorInfo* result) {
  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(const_cast<TfLiteContext*>(context),
                                &execution_plan) != kTfLiteOk) {
    return absl::UnavailableError("Unable to get graph execution plan.");
  }
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(const_cast<TfLiteContext*>(context),
                                        node_index, &node,
                                        &registration) != kTfLiteOk) {
      return absl::UnavailableError(
          "Unable to get node and registration for node.");
    }
    for (int j = 0; j < node->inputs->size; ++j) {
      if (tensor_id == node->inputs->data[j]) {
        result->consumers.push_back({node, registration});
      }
    }
    for (int j = 0; j < node->outputs->size; ++j) {
      if (tensor_id == node->outputs->data[j]) {
        result->producers.push_back({node, registration});
      }
    }
  }
  return absl::OkStatus();
}

bool IsLogicalCode(int32_t builtin_code) {
  return builtin_code == kTfLiteBuiltinGreater ||
         builtin_code == kTfLiteBuiltinGreaterEqual ||
         builtin_code == kTfLiteBuiltinLess ||
         builtin_code == kTfLiteBuiltinLessEqual ||
         builtin_code == kTfLiteBuiltinEqual ||
         builtin_code == kTfLiteBuiltinNotEqual;
}

bool IsLogicalOp(tflite::gpu::OperationType op_type) {
  return op_type == tflite::gpu::OperationType::GREATER ||
         op_type == tflite::gpu::OperationType::GREATER_EQUAL ||
         op_type == tflite::gpu::OperationType::LESS ||
         op_type == tflite::gpu::OperationType::LESS_EQUAL ||
         op_type == tflite::gpu::OperationType::EQUAL ||
         op_type == tflite::gpu::OperationType::NOT_EQUAL;
}

class BatchedMatMulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    if (reader->GetNumberOfRuntimeInputs() == 2) {
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::BATCHED_MATMUL);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      return absl::OkStatus();
    } else if (reader->GetNumberOfRuntimeInputs() == 1) {
      // Second input is constant; replace with Convolution2D.
      const TfLiteTensor* second_input = reader->GetInputTensor(1);
      if (!IsConstantTensor(second_input) || second_input->dims->size != 2) {
        // The first input must be runtime and the second must be a 2D
        // constant tensor.
        return absl::UnavailableError("Not supported batched mat mul case");
      }
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddOutputs(node));

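      // The constant right-hand side is a [k, n] matrix: the matmul is
      // modeled as a 1x1 convolution with n output and k input channels.
      // The loop below transposes the row-major [k, n] data into
      // OHWI [n, 1, 1, k] order.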
      Tensor<HW, DataType::FLOAT32> weights;
      RETURN_IF_ERROR(reader->ReadTensor(1, &weights));
      Convolution2DAttributes attr;
      attr.weights.data.resize(weights.shape.w * weights.shape.h);
      for (int i = 0; i < weights.shape.w; ++i) {
        for (int j = 0; j < weights.shape.h; ++j) {
          attr.weights.data[i * weights.shape.h + j] =
              weights.data[j * weights.shape.w + i];
        }
      }
      attr.weights.id = weights.id;
      attr.weights.shape.h = 1;
      attr.weights.shape.w = 1;
      attr.weights.shape.o = weights.shape.w;
      attr.weights.shape.i = weights.shape.h;
      attr.strides = HW(1, 1);
      attr.dilations = HW(1, 1);
      attr.padding.appended = HW(0, 0);
      attr.padding.prepended = HW(0, 0);
      node->operation.attributes = std::move(attr);
      return absl::OkStatus();
    } else {
      return absl::UnavailableError("Not supported batched mat mul case");
    }
    return absl::OkStatus();
  }
};

class CastOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    TfLiteType src_type = context->tensors[tflite_node->inputs->data[0]].type;
    TfLiteType dst_type = context->tensors[tflite_node->outputs->data[0]].type;
    if (src_type == kTfLiteBool &&
        (dst_type == kTfLiteFloat16 || dst_type == kTfLiteFloat32)) {
      // Check that we have the following sequence:
      //   logical_op -> bool_tensor -> CAST -> float_tensor.
      TensorInfo input_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->inputs->data[0],
                                    &input_tensor_info));
      if (input_tensor_info.producers.size() != 1 ||
          input_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported cast case");
      }
      // If the cast is an output, do the cast to float on CPU.
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError(
            "Cast from bool not supported for outputs");
      }
      if (IsLogicalCode(input_tensor_info.producers[0].second->builtin_code)) {
        return absl::OkStatus();
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CAST);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

class ClampOperationsParser : public TFLiteOperationParser {
 public:
  explicit ClampOperationsParser(float clamp_a, float clamp_b)
      : clamp_a_(clamp_a), clamp_b_(clamp_b) {}
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // clamp(v, a, b) = clamp(v - a, 0.0, b - a) + a;
    // We replace clamp(...) with a sequence of elementwise ops:
    //   subtraction -> usual relu with alpha = 0.0 -> addition.
    // node_sub  = v0 = v - a                  // add op (add -a)
    // node_relu = v1 = clamp(v0, 0.0, clip);  // relu op, alpha = 0.0,
    //                                         // clip = b - a
    // node_add  = v2 = v1 + a                 // add op (add a)
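    //
    // For example, with clamp_a = -1.0f and clamp_b = 5.0f:
    //   v0 = v + 1.0;  v1 = clamp(v0, 0.0, 6.0);  v2 = v1 - 1.0.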
    Node* node_sub = graph->NewNode();
    Node* node_relu = graph->NewNode();
    Node* node_add = graph->NewNode();

    ElementwiseAttributes sub_attr;
    sub_attr.param = -clamp_a_;
    node_sub->operation.type = ToString(OperationType::ADD);
    node_sub->operation.attributes = std::move(sub_attr);

    ReLUAttributes relu_attr;
    relu_attr.alpha = 0.0f;
    relu_attr.clip = clamp_b_ - clamp_a_;
    node_relu->operation.type = ToString(OperationType::RELU);
    node_relu->operation.attributes = relu_attr;

    ElementwiseAttributes add_attr;
    add_attr.param = clamp_a_;
    node_add->operation.type = ToString(OperationType::ADD);
    node_add->operation.attributes = std::move(add_attr);

    RETURN_IF_ERROR(reader->AddInput(node_sub, 0));
    auto input = graph->FindInputs(node_sub->id)[0];

    Value* v0 = graph->NewValue();
    Value* v1 = graph->NewValue();
    v0->tensor.type = input->tensor.type;
    v0->tensor.shape = input->tensor.shape;
    v1->tensor.type = input->tensor.type;
    v1->tensor.shape = input->tensor.shape;

    RETURN_IF_ERROR(graph->SetProducer(node_sub->id, v0->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_relu->id, v0->id));
    RETURN_IF_ERROR(graph->SetProducer(node_relu->id, v1->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_add->id, v1->id));

    RETURN_IF_ERROR(reader->AddOutputs(node_add));
    return absl::OkStatus();
  }

 private:
  const float clamp_a_, clamp_b_;
};

class ConcatenationOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));

    // TODO(eignasheva): add proper tensor availability checking
    // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
    //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
    // }
    // TODO(eignasheva): add axis checking.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    ConcatAttributes attr;
    // Read inputs first to make sure const node is added to a graph before
    // concat node to ensure topological order.
    std::vector<const Value*> inputs;
    for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* value;
      const auto status = reader->ReadValue(idx, &value);
      if (status.ok()) {
        inputs.push_back(value);
      } else {
        TensorFloat32 tensor;
        RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
        Value* value;
        RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
        inputs.push_back(value);
      }
    }

    for (int i = 0; i < inputs.size(); ++i) {
      for (int j = 0; j < i; ++j) {
        if (inputs[i] == inputs[j]) {
          Node* node_copy = graph->NewNode();
          node_copy->operation.type = ToString(OperationType::COPY);
          RETURN_IF_ERROR(graph->AddConsumer(node_copy->id, inputs[j]->id));
          Value* copy_value = graph->NewValue();
          copy_value->tensor.type = inputs[j]->tensor.type;
          copy_value->tensor.shape = inputs[j]->tensor.shape;
          RETURN_IF_ERROR(graph->SetProducer(node_copy->id, copy_value->id));
          inputs[i] = copy_value;
          break;
        }
      }
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (int i = 0; i < inputs.size(); ++i) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, inputs[i]->id));
    }

    std::vector<BHWC> input_shapes;
    for (auto input : graph->FindInputs(node->id)) {
      input_shapes.push_back(input->tensor.shape);
    }
    RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));

    // Guess axis.
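    // For example, inputs [1, 8, 8, 3] and [1, 8, 8, 5] with output
    // [1, 8, 8, 8] differ from the output only in channels, so the loop
    // below picks Axis::CHANNELS.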
    BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    for (auto input : graph->FindInputs(node->id)) {
      if (input->tensor.shape.h != output_shape.h) {
        attr.axis = Axis::HEIGHT;
        break;
      }
      if (input->tensor.shape.w != output_shape.w) {
        attr.axis = Axis::WIDTH;
        break;
      }
      if (input->tensor.shape.c != output_shape.c) {
        attr.axis = Axis::CHANNELS;
        break;
      }
    }
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
    *axis = Axis::BATCH;
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::HEIGHT;
        break;
      }
    }
    if (*axis == Axis::BATCH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::WIDTH;
        break;
      }
    }
    if (*axis == Axis::HEIGHT) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::CHANNELS;
        break;
      }
    }
    if (*axis == Axis::WIDTH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].h != input_shapes[i].h) {
        return absl::UnimplementedError(
            "Can concatenate tensors only by batch, height, width, or "
            "channels.");
      }
    }
    return absl::OkStatus();
  }
};

class Conv2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    Convolution2DAttributes attr;
    RETURN_IF_ERROR(ReadAttributes(tflite_node, tf_options, reader, &attr));

    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      // Weights are the second runtime input.
      const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
      const TfLiteTensor* weights_tensor = reader->GetInputTensor(1);
      BHWC src_shape, weights_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
      RETURN_IF_ERROR(ExtractTensorShape(*weights_tensor, &weights_shape));
      if (src_shape.c != weights_shape.c) {
        return absl::InternalError(
            "No support of CONVOLUTION_2D with runtime grouped weights.");
      }

      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      node->operation.attributes = std::move(attr);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
      return absl::OkStatus();
    } else {
      // Weights are constants.
      const int src_group_size = attr.weights.shape.i;
      const int dst_group_size = attr.weights.shape.o / attr.groups;
      const bool supported_grouped_conv =
          src_group_size % 4 == 0 && dst_group_size % 4 == 0;
      if (attr.groups != 1 && !supported_grouped_conv) {
        // Not a supported case; replace with usual convolutions:
        return ResolveGroupedConvolution(attr, tf_options, reader, graph);
      } else {
        Node* node = graph->NewNode();
        node->operation.type = ToString(OperationType::CONVOLUTION_2D);
        node->operation.attributes = std::move(attr);
        RETURN_IF_ERROR(reader->AddInput(node, 0));
        RETURN_IF_ERROR(reader->AddOutputs(node));
        RETURN_IF_ERROR(
            MaybeFuseActivation(tf_options->activation, graph, node));
        return absl::OkStatus();
      }
    }
  }

 private:
  absl::Status ReadAttributes(const TfLiteNode* tflite_node,
                              const TfLiteConvParams* tf_options,
                              ObjectReader* reader,
                              Convolution2DAttributes* attr) {
    const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
    BHWC src_shape;
    RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 1) {
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr->weights));
      attr->groups = src_shape.c / attr->weights.shape.i;
    } else {
      const TfLiteTensor* weights_tensor = reader->GetInputTensor(1);
      if (!weights_tensor) {
        return absl::InternalError("Expected second runtime tensor.");
      }
      BHWC weights_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*weights_tensor, &weights_shape));
      attr->weights.shape = OHWI(weights_shape.b, weights_shape.h,
                                 weights_shape.w, weights_shape.c);
      attr->groups = 1;
    }
    reader->ReadTensor(2, &attr->bias).IgnoreError();  // bias is optional
    attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr->dilations = HW(tf_options->dilation_height_factor,
                         tf_options->dilation_width_factor);
    UpdatePadding(tf_options->padding, src_shape, attr);
    return absl::OkStatus();
  }

  // Replace a single grouped convolution (N = groups count) with this
  // sequence:
  //   split the input into N tensors along the channels dim,
  //   run N usual convs,
  //   concat the N results into 1 output along the channels dim.
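  // For example, groups = 3, input channels = 9, output channels = 6:
  //   SPLIT -> three [.., .., .., 3] tensors, three CONVOLUTION_2D nodes
  //   each producing [.., .., .., 2], CONCAT -> [.., .., .., 6].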
  absl::Status ResolveGroupedConvolution(const Convolution2DAttributes& attr,
                                         const TfLiteConvParams* tf_options,
                                         ObjectReader* reader,
                                         GraphFloat32* graph) {
    const TfLiteTensor* src_tensor = reader->GetInputTensor(0);
    const TfLiteTensor* dst_tensor = reader->GetOutputTensor(0);
    BHWC src_shape, dst_shape;
    RETURN_IF_ERROR(ExtractTensorShape(*src_tensor, &src_shape));
    RETURN_IF_ERROR(ExtractTensorShape(*dst_tensor, &dst_shape));

    DataType src_type = DataType::FLOAT32;
    if (src_tensor->type == kTfLiteFloat16) {
      src_type = DataType::FLOAT16;
    }
    DataType dst_type = DataType::FLOAT32;
    if (dst_tensor->type == kTfLiteFloat16) {
      dst_type = DataType::FLOAT16;
    }

    const int src_group_size = attr.weights.shape.i;
    const int dst_group_size = attr.weights.shape.o / attr.groups;

    Node* split_node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(split_node, 0));
    {
      SplitAttributes split_attr;
      split_attr.axis = Axis::CHANNELS;
      split_node->operation.type = ToString(OperationType::SPLIT);
      split_node->operation.attributes = split_attr;
    }

    std::vector<Node*> conv_nodes(attr.groups);
    std::vector<Value*> conv_src(attr.groups);
    std::vector<Value*> conv_dst(attr.groups);
    for (int i = 0; i < attr.groups; ++i) {
      conv_nodes[i] = graph->NewNode();
      conv_src[i] = graph->NewValue();
      conv_dst[i] = graph->NewValue();
      conv_src[i]->tensor.shape = src_shape;
      conv_src[i]->tensor.type = src_type;
      conv_src[i]->tensor.shape.c = src_group_size;
      conv_dst[i]->tensor.shape = dst_shape;
      conv_dst[i]->tensor.type = dst_type;
      conv_dst[i]->tensor.shape.c = dst_group_size;
      Convolution2DAttributes conv_attr;
      conv_attr = attr;
      conv_attr.groups = 1;
      conv_attr.weights.id = -1;
      conv_attr.weights.shape.o = dst_group_size;
      conv_attr.weights.data.resize(
          conv_attr.weights.shape.DimensionsProduct());
      for (int out_i = 0; out_i < dst_group_size; ++out_i) {
        for (int in_i = 0; in_i < src_group_size; ++in_i) {
          for (int ky = 0; ky < attr.weights.shape.h; ++ky) {
            for (int kx = 0; kx < attr.weights.shape.w; ++kx) {
              const int src_index = attr.weights.shape.LinearIndex(
                  {{i * dst_group_size + out_i, ky, kx, in_i}});
              const int dst_index =
                  conv_attr.weights.shape.LinearIndex({{out_i, ky, kx, in_i}});
              conv_attr.weights.data[dst_index] = attr.weights.data[src_index];
            }
          }
        }
      }
      conv_attr.bias.shape.v = dst_group_size;
      conv_attr.bias.data.resize(conv_attr.bias.shape.DimensionsProduct());
      for (int out_i = 0; out_i < dst_group_size; ++out_i) {
        if (i * dst_group_size + out_i < attr.bias.data.size()) {
          conv_attr.bias.data[out_i] =
              attr.bias.data[i * dst_group_size + out_i];
        } else {
          conv_attr.bias.data[out_i] = 0.0f;
        }
      }
      conv_nodes[i]->operation.type = ToString(OperationType::CONVOLUTION_2D);
      conv_nodes[i]->operation.attributes = conv_attr;

      RETURN_IF_ERROR(graph->SetProducer(split_node->id, conv_src[i]->id));
      RETURN_IF_ERROR(graph->AddConsumer(conv_nodes[i]->id, conv_src[i]->id));
      RETURN_IF_ERROR(graph->SetProducer(conv_nodes[i]->id, conv_dst[i]->id));
    }

    Node* concat_node = graph->NewNode();
    {
      ConcatAttributes concat_attr;
      concat_attr.axis = Axis::CHANNELS;
      concat_node->operation.type = ToString(OperationType::CONCAT);
      concat_node->operation.attributes = concat_attr;
    }
    for (int i = 0; i < attr.groups; ++i) {
      RETURN_IF_ERROR(graph->AddConsumer(concat_node->id, conv_dst[i]->id));
    }
    RETURN_IF_ERROR(reader->AddOutputs(concat_node));
    RETURN_IF_ERROR(
        MaybeFuseActivation(tf_options->activation, graph, concat_node));
    return absl::OkStatus();
  }
};

class CumsumOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    CumsumAttributes attr;
    const TfLiteTensor* input_tensor = reader->GetInputTensor(0);
    const TfLiteTensor* axis_tensor = reader->GetInputTensor(1);
    const TfLiteIntArray* shape = input_tensor->dims;
    const int tflite_axis = GetTensorData<int32_t>(axis_tensor)[0];
    const Axis axes[4] = {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH,
                          Axis::CHANNELS};
    attr.axis = axes[tflite_axis + 4 - shape->size];
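    // For example, a 4D NHWC input with tflite_axis == 3 maps to
    // Axis::CHANNELS; a 3D input shifts the lookup by one, so axis 2 also
    // maps to Axis::CHANNELS.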
    node->operation.type = ToString(OperationType::CUMSUM);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

// Doesn't have a kernel implementation.
class DensifyOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DENSIFY);
    const TfLiteTensor* const_tensor = reader->GetInputTensor(0);
    if (!const_tensor->sparsity) {
      return absl::InvalidArgumentError("Input tensor must be sparse.");
    }
    TensorFloat32 sparse_tensor;
    RETURN_IF_ERROR(reader->ReadTensor(0, &sparse_tensor));
    DensifyAttributes attributes;
    attributes.tensor = std::move(sparse_tensor);
    node->operation.attributes = attributes;
    return reader->AddOutputs(node);
  }
};

class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    DepthwiseConvolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
      attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
                                weights_shape.w, weights_shape.c);
    } else {  // runtime_inputs == 1
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
                        std::max(1, tf_options->dilation_width_factor));
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    const int depth_multiplier = tf_options->depth_multiplier;
    if (depth_multiplier != 1) {
      const TfLiteTensor* input = reader->GetInputTensor(0);
      const TfLiteTensor* filter = reader->GetInputTensor(1);
      const TfLiteTensor* output = reader->GetOutputTensor(0);
      TransposeWeights(input, filter, output, depth_multiplier, &attr);
    }
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }

 private:
  // TFLite CPU stores weights as:
  //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
  // TFLite GPU stores weights as:
  //   [depth_multiplier, kernel_height, kernel_width, input_depth]
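  //
  // For example, with a 3x3 kernel and output_depth = 4, the CPU layout keeps
  // all 4 output channels interleaved at each spatial position; the loop
  // below gathers each output channel j into its own contiguous 3x3 block.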
  static void TransposeWeights(const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* output, int depth_multiplier,
                               DepthwiseConvolution2DAttributes* attr) {
    const int input_depth = input->dims->data[3];
    const int filter_height = filter->dims->data[1];
    const int filter_width = filter->dims->data[2];
    const int output_depth = output->dims->data[3];
    Tensor<OHWI, DataType::FLOAT32> weights;
    weights.id = attr->weights.id;
    weights.shape =
        OHWI(output_depth, filter_height, filter_width, input_depth);
    weights.data.resize(weights.shape.DimensionsProduct());
    float* dst = &weights.data[0];
    for (int j = 0; j < output_depth; ++j) {
      const float* src = attr->weights.data.data() + j;
      for (int i = 0; i < filter_height * filter_width; ++i) {
        *dst = *src;
        dst++;
        src += output_depth;
      }
    }
    attr->weights = std::move(weights);
  }
};

class DepthToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const TfLiteDepthToSpaceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    SpaceToDepthAttributes attr;
    attr.block_size = tf_options->block_size;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class DequantizeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
    // with floating-point versions of the original tensors.
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 0) {
      // Constant input; can be dequantized here.
      ConstTensorAttributes attr;
      RETURN_IF_ERROR(reader->ReadTensor(0, &attr.tensor));
      Node* node = graph->NewNode();
      node->operation.attributes = attr;
      node->operation.type = ToString(OperationType::CONSTANT);
      return reader->AddOutputs(node);
    }
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    // Non-constant dequantization.
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Quantization attributes should already be present in the input tensor.
    auto input_value = graph->FindInputs(node->id)[0];
    if (!input_value->quant_params) {
      if (runtime_inputs == 1) {
        // The DEQUANTIZE op is preceded by a DENSIFY op and doesn't have any
        // quantization params. The DEQUANTIZE op will be removed from the
        // graph later, in the `MergeDensify` graph transformation.
        return absl::OkStatus();
      }
      return absl::InvalidArgumentError(
          "Encountered Dequantize input with no quant params");
    }
    QuantizeAndDequantizeAttributes attr;
    attr.min = input_value->quant_params.value().min;
    attr.max = input_value->quant_params.value().max;
    attr.scale = input_value->quant_params.value().scale;

    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ElementwiseOperationParser : public TFLiteOperationParser {
 public:
  explicit ElementwiseOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    const int kMaxSupportedOpVersion =
        operation_type_ == OperationType::MUL ? 3 : 2;
    RETURN_IF_ERROR(
        CheckMaxSupportedOpVersion(registration, kMaxSupportedOpVersion));
    if (IsLogicalOp(operation_type_)) {
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.producers.size() != 1 ||
          output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported logical op case");
      }
      const auto& next_node = output_tensor_info.consumers[0];
      TfLiteType dst_type =
          context->tensors[next_node.first->outputs->data[0]].type;
      if (next_node.second->builtin_code == kTfLiteBuiltinCast &&
          (dst_type == kTfLiteFloat16 || dst_type == kTfLiteFloat32)) {
        return absl::OkStatus();
      } else {
        return absl::UnimplementedError("Not supported logical op case.");
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);
    if (operation_type_ == OperationType::ADD) {
      ElementwiseAttributes attr;
      node->operation.attributes = std::move(attr);
    }

    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/0,
                                                        /*outputs=*/1));

      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else if (IsTwoArgumentOperation() &&
               reader
                   ->VerifyInputsConstsOutputs(tflite_node,
                                               /*runtime_inputs=*/2,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1)
                   .ok()) {
      if (tflite_node->inputs->size != 2) {
        return absl::InvalidArgumentError("Applies only two input tensors");
      }
      const TfLiteTensor* input0 = reader->GetInputTensor(0);
      const TfLiteTensor* input1 = reader->GetInputTensor(1);

      // TODO(b/166831113): Support the same inputs for operations.
      if (input0 == input1) {
        if (operation_type_ == OperationType::MUL) {
          // Replace MUL(A, A) with SQUARE(A).
          node->operation.type = ToString(OperationType::SQUARE);
          RETURN_IF_ERROR(reader->AddInput(node, 0));
        } else if (operation_type_ == OperationType::ADD) {
          // Replace ADD(A, A) with MUL(A, 2.0).
          node->operation.type = ToString(OperationType::MUL);
          ElementwiseAttributes attr;
          attr.param = 2.0f;
          node->operation.attributes = std::move(attr);
          RETURN_IF_ERROR(reader->AddInput(node, 0));
        } else {
          return absl::UnimplementedError(
              "No support of few identical inputs in the same operation.");
        }
      } else {
        int input_tensor0 = 0;
        int input_tensor1 = 1;
        if (operation_type_ == OperationType::MUL ||
            operation_type_ == OperationType::ADD) {
          // The "larger" input tensor must be bound to the 1st input and the
          // "smaller" input tensor must be bound to the 2nd input.
          BHWC shape0;
          RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
          BHWC shape1;
          RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
          if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
              shape0.c == shape1.c) {
            input_tensor0 = 1;
            input_tensor1 = 0;
          }
        }

        RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
        RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
      }
    } else if (IsTwoArgumentOperationWithConst()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/1,
                                                        /*outputs=*/1));
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      attr.runtime_tensor_is_second =
          IsConstantTensor(reader->GetInputTensor(0));
      node->operation.attributes = std::move(attr);
    } else {
      return absl::InvalidArgumentError("Incorrect operation type passed");
    }

    RETURN_IF_ERROR(reader->AddOutputs(node));
    return MaybeFuseActivationForElementwiseNode(operation_type_, tflite_node,
                                                 graph, node);
  }

 private:
  absl::Status GetActivation(const TfLiteNode* tflite_node,
                             TfLiteFusedActivation* activation) const {
    if (operation_type_ == OperationType::DIV) {
      const TfLiteDivParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }
    if (operation_type_ == OperationType::SUB) {
      const TfLiteSubParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }

    // Return kTfLiteActNone, as other ops either do not have TfLiteXxxParams
    // or do not have TfLiteXxxParams.activation.
    *activation = kTfLiteActNone;
    return absl::OkStatus();
  }

  bool IsOneArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::ABS:
      case OperationType::COPY:
      case OperationType::COS:
      case OperationType::ELU:
      case OperationType::EXP:
      case OperationType::FLOOR:
      case OperationType::LOG:
      case OperationType::NEG:
      case OperationType::RSQRT:
      case OperationType::SIGMOID:
      case OperationType::SIN:
      case OperationType::SQRT:
      case OperationType::SQUARE:
      case OperationType::TANH:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::ADD:
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::MUL:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperationWithConst() const {
    switch (operation_type_) {
      case OperationType::ADD:
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::MUL:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  OperationType operation_type_;
};

class FullyConnectedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
    // TODO(eignasheva): check input shape
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteFullyConnectedParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));

    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Create Convolution2D, as it supports runtime weights.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));

      const TfLiteTensor* input_tensor = reader->GetInputTensor(0);
      BHWC input_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*input_tensor, &input_shape));
      const TfLiteTensor* input2_tensor = reader->GetInputTensor(1);
      BHWC input2_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*input2_tensor, &input2_shape));
      const TfLiteTensor* output_tensor = reader->GetOutputTensor(0);
      BHWC output_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*output_tensor, &output_shape));
      BHWC output_ref_shape = input_shape;
      output_ref_shape.c = input2_shape.b;
      if (output_ref_shape != output_shape) {
        Value* copy_value = graph->NewValue();
        auto input_value = graph->FindInputs(node->id)[0];
        copy_value->tensor.type = input_value->tensor.type;
        copy_value->tensor.shape = output_ref_shape;
        Node* node_reshape = graph->NewNode();
        node_reshape->operation.type = ToString(OperationType::RESHAPE);
        ReshapeAttributes reshape_attr;
        reshape_attr.new_shape = output_shape;
        node_reshape->operation.attributes = reshape_attr;
        RETURN_IF_ERROR(graph->SetProducer(node->id, copy_value->id));
        RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, copy_value->id));
        RETURN_IF_ERROR(reader->AddOutputs(node_reshape));
      } else {
        RETURN_IF_ERROR(reader->AddOutputs(node));
      }
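      // The reference output keeps the input's spatial shape and takes its
      // channel count from the weights; the RESHAPE above is inserted only
      // when the model's declared output shape differs from that reference.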
1243
1244 Convolution2DAttributes attr;
1245 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
1246
1247 attr.strides = HW(1, 1);
1248 attr.dilations = HW(1, 1);
1249 attr.padding.appended = HW(0, 0);
1250 attr.padding.prepended = HW(0, 0);
1251 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1252 node->operation.attributes = std::move(attr);
1253 return absl::OkStatus();
1254 }
1255 Node* node = graph->NewNode();
1256 RETURN_IF_ERROR(reader->AddInput(node, 0));
1257
1258 if (tf_options->weights_format !=
1259 kTfLiteFullyConnectedWeightsFormatDefault) {
1260 return absl::UnimplementedError(
1261 "Unsupported FullyConnected weights format.");
1262 }
1263
1264 FullyConnectedAttributes attr;
1265 RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
1266
1267 auto input = graph->FindInputs(node->id)[0];
1268 if (input->tensor.shape.c != attr.weights.shape.i) {
1269 return absl::UnimplementedError(
1270 "Amount of input channels should match weights width");
1271 }
1272
1273 Node* conv = node;
1274 if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
1275 // In Gpu delegates assume that height and width = 1 for FullyConnected
1276 // Using usual convolution2d when height or width != 1
1277 Convolution2DAttributes conv_attr;
1278 conv_attr.strides = HW(1, 1);
1279 conv_attr.dilations = HW(1, 1);
1280 conv_attr.padding.appended = HW(0, 0);
1281 conv_attr.padding.prepended = HW(0, 0);
1282 conv_attr.weights = attr.weights;
1283 conv_attr.bias = attr.bias;
1284 conv->operation.type = ToString(OperationType::CONVOLUTION_2D);
1285 conv->operation.attributes = std::move(conv_attr);
1286 } else {
1287 conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
1288 conv->operation.attributes = std::move(attr);
1289 }
1290 RETURN_IF_ERROR(reader->AddOutputs(conv));
1291 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));
1292 return absl::OkStatus();
1293 }
1294 };
1295
1296 class HardSwishOperationParser : public TFLiteOperationParser {
1297 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1298 absl::Status IsSupported(const TfLiteContext* context,
1299 const TfLiteNode* tflite_node,
1300 const TfLiteRegistration* registration) final {
1301 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1302 }
1303
Parse(const TfLiteNode *,const TfLiteRegistration *,GraphFloat32 * graph,ObjectReader * reader)1304 absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
1305 GraphFloat32* graph, ObjectReader* reader) final {
1306 Node* node = graph->NewNode();
1307 node->operation.type = ToString(OperationType::HARD_SWISH);
1308 RETURN_IF_ERROR(reader->AddInput(node, 0));
1309 return reader->AddOutputs(node);
1310 }
1311 };
1312
1313 // Basic LSTM Cell:
1314 //
1315 // 1name = name is at input index 1
1316 // name1 = name is at output index 1
1317 //
1318 // 0input 1prev_activ
1319 // \ /
1320 // [[concat]]
1321 // \
1322 // concat_temp2 2weights 3biases
1323 // \ / /
1324 // [[fully-connected]]
1325 // \
1326 // activ_temp3 4prev_state
1327 // \ /
1328 // [[LSTM]]
1329 // / \
1330 // new_state1 activation0
1331 //
1332 // For full LSTM cells, see this blog post:
1333 // https://colah.github.io/posts/2015-08-Understanding-LSTMs/
1334 // In addition to Peephole connections and Combined Input Forget Gates (CIFG)
1335 // described in that post, this code also adds the following optional features:
1336 // - Configurable activations (sigmoid or TANH)
1337 // - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
1338 // - Output projection:
1339 // https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
1340 // - Configurable clipping of cell state and output state.
1341 class LSTMOperationParser : public TFLiteOperationParser {
1342 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1343 absl::Status IsSupported(const TfLiteContext* context,
1344 const TfLiteNode* tflite_node,
1345 const TfLiteRegistration* registration) final {
1346 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
1347 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1348 }
1349
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1350 absl::Status Parse(const TfLiteNode* tflite_node,
1351 const TfLiteRegistration* registration,
1352 GraphFloat32* graph, ObjectReader* reader) final {
1353 const TfLiteLSTMParams* tf_options;
1354 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1355 switch (tf_options->kernel_type) {
1356 case kTfLiteLSTMFullKernel:
1357 return ParseFull(tflite_node, registration, graph, reader, tf_options);
1358 case kTfLiteLSTMBasicKernel:
1359 return ParseBasic(tflite_node, registration, graph, reader, tf_options);
1360 }
1361 }
1362
1363 absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
1364 final {
1365 return new_variable_input_value_map_;
1366 }
1367
1368 private:
1369 absl::Status ParseBasic(const TfLiteNode* tflite_node,
1370 const TfLiteRegistration* registration,
1371 GraphFloat32* graph, ObjectReader* reader,
1372 const TfLiteLSTMParams* tf_options) {
1373 if (tflite_node->inputs->size != 5) {
1374 return absl::InvalidArgumentError("LSTM should have 5 input tensors");
1375 }
1376 if (tflite_node->outputs->size != 4) {
1377 return absl::InvalidArgumentError("LSTM should have 4 output tensors");
1378 }
1379 RETURN_IF_ERROR(CheckBasicParameters(tf_options));
1380
1381 Node* concat_node = graph->NewNode();
1382 concat_node->operation.type = ToString(OperationType::CONCAT);
1383 ConcatAttributes concat_attr;
1384 concat_attr.axis = Axis::CHANNELS;
1385 concat_node->operation.attributes = concat_attr;
1386
1387 Node* fc_node = graph->NewNode();
1388 fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
1389 FullyConnectedAttributes fc_attr;
1390 RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
1391 fc_node->operation.attributes = std::move(fc_attr);
1392
1393 Node* lstm_node = graph->NewNode();
1394 lstm_node->operation.type = ToString(OperationType::LSTM);
1395 LstmAttributes lstm_attr;
1396 lstm_attr.kernel_type = LstmKernelType::BASIC;
1397 lstm_node->operation.attributes = lstm_attr;
1398
1399 Value* concat_temp;
1400 int concat_tensor_idx = tflite_node->outputs->data[2];
1401 RETURN_IF_ERROR(
1402 reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
1403 Value* activ_temp;
1404 int activ_tensor_idx = tflite_node->outputs->data[3];
1405 RETURN_IF_ERROR(
1406 reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
1407
1408 RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input
1409 RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ
1410 RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
1411
1412 RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
1413 RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
1414
1415 RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
1416 RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state
1417 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state
1418 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation
1419
1420 return absl::OkStatus();
1421 }
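// The wiring above reproduces the diagram before this class:
//   {0:input, 1:prev_activ} -> CONCAT -> concat_temp
//   concat_temp (with 2:weights, 3:biases) -> FULLY_CONNECTED -> activ_temp
//   activ_temp + 4:prev_state -> LSTM -> {1:new_state, 0:activation}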
1422
1423 absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
1424 if (tf_options->activation != kTfLiteActTanh) {
1425 return absl::UnimplementedError("Only TANH activation is supported.");
1426 }
1427 if (tf_options->cell_clip != 0.0f) {
1428 return absl::UnimplementedError("cell_clip is not supported.");
1429 }
1430 if (tf_options->proj_clip != 0.0f) {
1431 return absl::UnimplementedError("proj_clip is not supported.");
1432 }
1433 return absl::OkStatus();
1434 }
1435
1436 absl::Status ParseFull(const TfLiteNode* tflite_node,
1437 const TfLiteRegistration* registration,
1438 GraphFloat32* graph, ObjectReader* reader,
1439 const TfLiteLSTMParams* tf_options) {
1440 // Invoke full LSTM parser
1441 RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
1442 reader, tf_options,
1443 &new_variable_input_value_map_));
1444 return absl::OkStatus();
1445 }
1446
1447 absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
1448 if (tf_options->activation != kTfLiteActSigmoid &&
1449 tf_options->activation != kTfLiteActTanh) {
1450 return absl::UnimplementedError(
1451 "Only sigmoid or tanh activation is supported.");
1452 }
1453
1454 return absl::OkStatus();
1455 }
1456
1457 absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
1458 };
1459
1460 class OneHotOperationParser : public TFLiteOperationParser {
1461 public:
1462 absl::Status IsSupported(const TfLiteContext* context,
1463 const TfLiteNode* tflite_node,
1464 const TfLiteRegistration* registration) final {
1465 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1466 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1467 }
1468
1469 absl::Status Parse(const TfLiteNode* tflite_node,
1470 const TfLiteRegistration* registration,
1471 GraphFloat32* graph, ObjectReader* reader) final {
1472 Node* node = graph->NewNode();
1473 OneHotAttributes attr;
1474 const TfLiteTensor* on_tensor = reader->GetInputTensor(2);
1475 const TfLiteTensor* off_tensor = reader->GetInputTensor(3);
1476 attr.on_value = GetTensorData<float>(on_tensor)[0];
1477 attr.off_value = GetTensorData<float>(off_tensor)[0];
1478 node->operation.type = ToString(OperationType::ONE_HOT);
1479 node->operation.attributes = std::move(attr);
1480 RETURN_IF_ERROR(reader->AddInput(node, 0));
1481 RETURN_IF_ERROR(reader->AddOutputs(node));
1482 return absl::OkStatus();
1483 }
1484 };
1485
1486 class PackOperationParser : public TFLiteOperationParser {
1487 public:
1488 absl::Status IsSupported(const TfLiteContext* context,
1489 const TfLiteNode* tflite_node,
1490 const TfLiteRegistration* registration) final {
1491 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1492 }
1493
1494 absl::Status Parse(const TfLiteNode* tflite_node,
1495 const TfLiteRegistration* registration,
1496 GraphFloat32* graph, ObjectReader* reader) final {
1497 if (tflite_node->inputs->size == 1) {
1498 // Pack with a single input can be replaced with Reshape.
1499 Node* node = graph->NewNode();
1500 node->operation.type = ToString(OperationType::RESHAPE);
1501 RETURN_IF_ERROR(reader->AddInput(node, 0));
1502 RETURN_IF_ERROR(reader->AddOutputs(node));
1503 // New shape comes from output shape.
1504 ReshapeAttributes attr;
1505 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1506 node->operation.attributes = attr;
1507 return absl::OkStatus();
1508 } else {
1509 // Pack with multiple inputs can be replaced with Concat.
1510 const TfLitePackParams* tf_options;
1511 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1512
1513 // Read the inputs first to make sure const nodes are added to the graph
1514 // before the concat node, preserving topological order.
1515 std::vector<const Value*> inputs;
1516 for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1517 Value* value;
1518 const auto status = reader->ReadValue(idx, &value);
1519 if (status.ok()) {
1520 inputs.push_back(value);
1521 } else {
1522 TensorFloat32 tensor;
1523 RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1524 Value* value;
1525 RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1526 inputs.push_back(value);
1527 }
1528 }
1529
1530 const TfLiteTensor* output = reader->GetOutputTensor(0);
1531 ConcatAttributes attr;
1532 RETURN_IF_ERROR(
1533 ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1534 BHWC output_shape;
1535 RETURN_IF_ERROR(ExtractTensorShape(*output, &output_shape));
1536 BHWC input_required_shape = output_shape;
1537 input_required_shape.set(attr.axis, 1);
1538 for (int i = 0; i < inputs.size(); ++i) {
1539 BHWC input_shape = inputs[i]->tensor.shape;
1540 if (input_shape != input_required_shape) {
1541 // GPU delegates do not support implicit shape transformations,
1542 // so add an explicit Reshape.
1543 Node* node_reshape = graph->NewNode();
1544 node_reshape->operation.type = ToString(OperationType::RESHAPE);
1545 ReshapeAttributes reshape_attr;
1546 reshape_attr.new_shape = input_required_shape;
1547 node_reshape->operation.attributes = reshape_attr;
1548 RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, inputs[i]->id));
1549 Value* copy_value = graph->NewValue();
1550 copy_value->tensor.type = inputs[i]->tensor.type;
1551 copy_value->tensor.shape = input_required_shape;
1552 RETURN_IF_ERROR(graph->SetProducer(node_reshape->id, copy_value->id));
1553 inputs[i] = copy_value;
1554 }
1555 }
1556
1557 Node* node = graph->NewNode();
1558 node->operation.type = ToString(OperationType::CONCAT);
1559 RETURN_IF_ERROR(reader->AddOutputs(node));
1560 for (const Value* input : inputs) {
1561 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1562 }
1563 node->operation.attributes = attr;
1564 return absl::OkStatus();
1565 }
1566 }
1567 };
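// Sketch of the PACK-as-CONCAT lowering above (hypothetical shapes): packing
// two BHWC(1, 1, 2, 8) inputs along Axis::HEIGHT into a BHWC(1, 2, 2, 8)
// output needs no extra nodes, since input_required_shape is already
// BHWC(1, 1, 2, 8); an input arriving as, say, BHWC(1, 2, 1, 8) would first
// get an explicit RESHAPE to BHWC(1, 1, 2, 8) before feeding the CONCAT.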
1568
1569 class PReLUOperationParser : public TFLiteOperationParser {
1570 public:
1571 absl::Status IsSupported(const TfLiteContext* context,
1572 const TfLiteNode* tflite_node,
1573 const TfLiteRegistration* registration) final {
1574 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1575 // TODO(eignasheva): add params check
1576 return absl::OkStatus();
1577 }
1578 absl::Status Parse(const TfLiteNode* tflite_node,
1579 const TfLiteRegistration* registration,
1580 GraphFloat32* graph, ObjectReader* reader) final {
1581 Node* node = graph->NewNode();
1582 node->operation.type = ToString(OperationType::PRELU);
1583 RETURN_IF_ERROR(reader->AddInput(node, 0));
1584 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1585
1586 PReLUAttributes attr;
1587 Tensor<Linear, DataType::FLOAT32> linear_alpha;
1588 absl::Status status = reader->ReadTensor(1, &linear_alpha);
1589 if (status.ok()) {
1590 if (linear_alpha.shape.v != input_shape.c) {
1591 return absl::InvalidArgumentError(
1592 "Linear alpha shape does not match the number of input channels.");
1593 }
1594 attr.alpha = std::move(linear_alpha);
1595 } else {
1596 Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1597 RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1598 if (hwc_alpha.shape.h != input_shape.h ||
1599 hwc_alpha.shape.w != input_shape.w ||
1600 hwc_alpha.shape.c != input_shape.c) {
1601 return absl::InvalidArgumentError(
1602 "Alpha shape does not match input shape.");
1603 }
1604 attr.alpha = std::move(hwc_alpha);
1605 }
1606 node->operation.attributes = std::move(attr);
1607 return reader->AddOutputs(node);
1608 }
1609 };
1610
1611 class PadOperationParser : public TFLiteOperationParser {
1612 public:
1613 explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1614
1615 absl::Status IsSupported(const TfLiteContext* context,
1616 const TfLiteNode* tflite_node,
1617 const TfLiteRegistration* registration) final {
1618 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1619 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1620 }
1621
1622 absl::Status Parse(const TfLiteNode* tflite_node,
1623 const TfLiteRegistration* registration,
1624 GraphFloat32* graph, ObjectReader* reader) final {
1625 Node* node = graph->NewNode();
1626 node->operation.type = ToString(OperationType::PAD);
1627 RETURN_IF_ERROR(reader->AddInput(node, 0));
1628 RETURN_IF_ERROR(reader->AddOutputs(node));
1629
1630 PadAttributes attr;
1631 if (mirror_pad_) {
1632 attr.type = PaddingContentType::REFLECT;
1633 } else /*zero pad*/ {
1634 attr.type = PaddingContentType::ZEROS;
1635 }
1636
1637 Tensor<HW, DataType::INT32> paddings;
1638 RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1639
1640 if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1641 // 4x2 tensor with paddings.
1642 attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1643 paddings.data[4], paddings.data[6]);
1644 attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1645 paddings.data[7]);
1646 } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1647 // 3x2 tensor with paddings.
1648 attr.prepended =
1649 BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1650 attr.appended =
1651 BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1652 } else {
1653 // It shouldn't fail here since it's checked in IsSupported().
1654 return absl::InvalidArgumentError(
1655 "Paddings tensor has unexpected shape.");
1656 }
1657 node->operation.attributes = attr;
1658 return absl::OkStatus();
1659 }
1660
1661 private:
1662 bool mirror_pad_ = false;
1663 };
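// Example for the paddings layout above (hypothetical values): a 4x2 paddings
// tensor stored row-major as {0, 0, 1, 1, 2, 2, 0, 0} yields
// attr.prepended = BHWC(0, 1, 2, 0) and attr.appended = BHWC(0, 1, 2, 0),
// i.e. one padded row on top and bottom plus two padded columns on each side.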
1664
1665 class Pooling2DOperationParser : public TFLiteOperationParser {
1666 public:
1667 absl::Status IsSupported(const TfLiteContext* context,
1668 const TfLiteNode* tflite_node,
1669 const TfLiteRegistration* registration) final {
1670 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1671 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1672 }
1673
1675 explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1676
1677 absl::Status Parse(const TfLiteNode* tflite_node,
1678 const TfLiteRegistration* registration,
1679 GraphFloat32* graph, ObjectReader* reader) final {
1680 Node* node = graph->NewNode();
1681 node->operation.type = ToString(OperationType::POOLING_2D);
1682 RETURN_IF_ERROR(reader->AddInput(node, 0));
1683 RETURN_IF_ERROR(reader->AddOutput(node, 0));
1684
1685 Pooling2DAttributes attr;
1686 attr.type = type_;
1687
1688 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1689
1690 // Check whether there are custom options encoded. This happens when the
1691 // operation is MaxPoolingWithArgmax2D. There is no way to read
1692 // tflite_node->builtin_code, so simply check whether custom data is
1693 // available.
1694 const TfLitePoolParams* tf_options;
1695 if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
1696 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1697 }
1698
1699 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1700 // The second output is optional. It is not required, but it must be added
1701 // after MaybeFuseActivation is called.
1702 reader->AddOutput(node, 1).IgnoreError();
1703
1704 // First output is the result of pooling operation, while second output is
1705 // indices used for pooling.
1706 auto outputs = graph->FindOutputs(node->id);
1707 attr.output_indices = outputs.size() == 2;
1708 if (attr.output_indices) {
1709 // Fix data type for output indices. In the model it is set as float32.
1710 outputs[1]->tensor.type = DataType::INT32;
1711 }
1712 RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1713 node->operation.attributes = attr;
1714 return absl::OkStatus();
1715 }
1716
1717 private:
1718 const PoolingType type_;
1719 };
1720
1721 class ReduceOperationParser : public TFLiteOperationParser {
1722 public:
1723 explicit ReduceOperationParser(OperationType operation_type)
1724 : operation_type_(operation_type) {}
1725
1726 absl::Status IsSupported(const TfLiteContext* context,
1727 const TfLiteNode* tflite_node,
1728 const TfLiteRegistration* registration) final {
1729 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1730 }
1731
1732 absl::Status Parse(const TfLiteNode* tflite_node,
1733 const TfLiteRegistration* registration,
1734 GraphFloat32* graph, ObjectReader* reader) final {
1735 Node* node = graph->NewNode();
1736 node->operation.type = ToString(operation_type_);
1737 RETURN_IF_ERROR(reader->AddInput(node, 0));
1738
1739 const TfLiteReducerParams* tf_options;
1740 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1741
1742 ReduceAttributes attr;
1743 const TfLiteTensor* input = reader->GetInputTensor(0);
1744 const TfLiteTensor* axes = reader->GetInputTensor(1);
1745 for (int i = 0; i < NumElements(axes->dims); i++) {
1746 Axis axis;
1747 RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1748 attr.dims.insert(axis);
1749 }
1750 node->operation.attributes = attr;
1751
1752 if (!tf_options->keep_dims) {
1753 // GPU delegates do not support implicit shape transformations,
1754 // so add an explicit Reshape.
1755 const auto& input_tensor = graph->FindInputs(node->id)[0]->tensor;
1756 auto reduce_output_shape = input_tensor.shape;
1757 for (auto axis : attr.dims) {
1758 reduce_output_shape.set(axis, 1);
1759 }
1760 Node* node_reshape = graph->NewNode();
1761 node_reshape->operation.type = ToString(OperationType::RESHAPE);
1762 ReshapeAttributes reshape_attr;
1763 const TfLiteTensor* output = reader->GetOutputTensor(0);
1764 RETURN_IF_ERROR(ExtractTensorShape(*output, &reshape_attr.new_shape));
1765 node_reshape->operation.attributes = reshape_attr;
1766 Value* reduce_result = graph->NewValue();
1767 reduce_result->tensor.type = input_tensor.type;
1768 reduce_result->tensor.shape = reduce_output_shape;
1769
1770 RETURN_IF_ERROR(graph->SetProducer(node->id, reduce_result->id));
1771 RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, reduce_result->id));
1772 RETURN_IF_ERROR(reader->AddOutputs(node_reshape));
1773 } else {
1774 RETURN_IF_ERROR(reader->AddOutputs(node));
1775 }
1776 return absl::OkStatus();
1777 }
1778
1779 private:
1780 const OperationType operation_type_;
1781 };
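// Example for the keep_dims handling above (hypothetical shapes): MEAN over
// {HEIGHT, WIDTH} of a BHWC(1, 8, 8, 16) input with keep_dims == false
// produces an intermediate reduce_result of shape BHWC(1, 1, 1, 16); the
// trailing RESHAPE then maps it to the shape TFLite declared for the output.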
1782
1783 class QuantizeOperationParser : public TFLiteOperationParser {
1784 public:
1785 absl::Status IsSupported(const TfLiteContext* context,
1786 const TfLiteNode* tflite_node,
1787 const TfLiteRegistration* registration) final {
1788 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1789 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1790 }
1791
1792 absl::Status Parse(const TfLiteNode* tflite_node,
1793 const TfLiteRegistration* registration,
1794 GraphFloat32* graph, ObjectReader* reader) final {
1795 // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1796 // with floating-point versions of the original tensors.
1797 Node* node = graph->NewNode();
1798 node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1799 RETURN_IF_ERROR(reader->AddInput(node, 0));
1800 RETURN_IF_ERROR(reader->AddOutputs(node));
1801
1802 // Quantization attributes should already be present in the output tensor.
1803 auto output_value = graph->FindOutputs(node->id)[0];
1804 if (!output_value->quant_params) {
1805 return absl::InvalidArgumentError(
1806 "Encountered Quantize output with no quant params");
1807 }
1808 QuantizeAndDequantizeAttributes attr;
1809 attr.min = output_value->quant_params.value().min;
1810 attr.max = output_value->quant_params.value().max;
1811 attr.scale = output_value->quant_params.value().scale;
1812
1813 node->operation.attributes = attr;
1814 return absl::OkStatus();
1815 }
1816 };
1817
1818 class ReLUOperationParser : public TFLiteOperationParser {
1819 public:
1820 explicit ReLUOperationParser(int clip) : clip_(clip) {}
1821
1822 absl::Status IsSupported(const TfLiteContext* context,
1823 const TfLiteNode* tflite_node,
1824 const TfLiteRegistration* registration) final {
1825 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1826 return absl::OkStatus();
1827 }
1828
1829 absl::Status Parse(const TfLiteNode* tflite_node,
1830 const TfLiteRegistration* registration,
1831 GraphFloat32* graph, ObjectReader* reader) final {
1832 Node* node = graph->NewNode();
1833 node->operation.type = ToString(OperationType::RELU);
1834 RETURN_IF_ERROR(reader->AddInput(node, 0));
1835
1836 ReLUAttributes attr;
1837 const TfLiteLeakyReluParams* tf_options;
1838 auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1839 attr.alpha = status.ok() ? tf_options->alpha : 0;
1840 attr.clip = clip_;
1841 node->operation.attributes = attr;
1842 return reader->AddOutputs(node);
1843 }
1844
1845 private:
1846 const int clip_;
1847 };
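// A plausible reading of clip_ above (assumption: this mirrors how the parser
// is registered for each builtin elsewhere in this file): clip == 0 gives
// plain RELU, clip == 6 gives RELU6, and for LEAKY_RELU the builtin alpha is
// picked up when the params are present; otherwise alpha stays 0.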
1848
1849 class ResamplerOperationParser : public TFLiteOperationParser {
1850 public:
1851 absl::Status IsSupported(const TfLiteContext* context,
1852 const TfLiteNode* tflite_node,
1853 const TfLiteRegistration* registration) final {
1854 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1855 }
1856
1857 absl::Status Parse(const TfLiteNode* tflite_node,
1858 const TfLiteRegistration* registration,
1859 GraphFloat32* graph, ObjectReader* reader) final {
1860 Node* node = graph->NewNode();
1861 RETURN_IF_ERROR(reader->AddInput(node, 0)); // src
1862 RETURN_IF_ERROR(reader->AddInput(node, 1)); // warp
1863 RETURN_IF_ERROR(reader->AddOutputs(node));
1864
1865 node->operation.type = ToString(OperationType::RESAMPLER);
1866
1867 auto src_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1868 auto warp_shape = graph->FindInputs(node->id)[1]->tensor.shape;
1869
1870 auto output_value = graph->FindOutputs(node->id)[0];
1871 output_value->tensor.shape =
1872 BHWC(src_shape.b, warp_shape.h, warp_shape.w, src_shape.c);
1873 return absl::OkStatus();
1874 }
1875 };
1876
1877 class ReshapeOperationParser : public TFLiteOperationParser {
1878 public:
1879 absl::Status IsSupported(const TfLiteContext* context,
1880 const TfLiteNode* tflite_node,
1881 const TfLiteRegistration* registration) final {
1882 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1883 // TODO(eignasheva): add shape checking
1884 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1885 }
1886
1887 absl::Status Parse(const TfLiteNode* tflite_node,
1888 const TfLiteRegistration* registration,
1889 GraphFloat32* graph, ObjectReader* reader) final {
1890 Node* node = graph->NewNode();
1891 node->operation.type = ToString(OperationType::RESHAPE);
1892 RETURN_IF_ERROR(reader->AddInput(node, 0));
1893 RETURN_IF_ERROR(reader->AddOutputs(node));
1894 // Here we may have extra inputs. Other tensors were supposed to
1895 // define new shape, but in TFLite these are ignored.
1896 // TODO(akulik): check that shapes match?
1897
1898 // New shape comes from output shape.
1899 ReshapeAttributes attr;
1900 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1901 node->operation.attributes = attr;
1902 return absl::OkStatus();
1903 }
1904 };
1905
1906 class Resize2DOperationParser : public TFLiteOperationParser {
1907 public:
1908 explicit Resize2DOperationParser(SamplingType sampling_type)
1909 : sampling_type_(sampling_type) {}
1910
1911 absl::Status IsSupported(const TfLiteContext* context,
1912 const TfLiteNode* tflite_node,
1913 const TfLiteRegistration* registration) final {
1914 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1915 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1916 }
1917
1918 absl::Status Parse(const TfLiteNode* tflite_node,
1919 const TfLiteRegistration* registration,
1920 GraphFloat32* graph, ObjectReader* reader) final {
1921 Node* node = graph->NewNode();
1922 node->operation.type = ToString(OperationType::RESIZE);
1923 RETURN_IF_ERROR(reader->AddInput(node, 0));
1924 RETURN_IF_ERROR(reader->AddOutputs(node));
1925 // Here we may have extra inputs. Other tensors were supposed to
1926 // define new shape, but in TFLite these are ignored.
1927
1928 Resize2DAttributes attr;
1929 RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1930 RETURN_IF_ERROR(
1931 GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1932 attr.type = sampling_type_;
1933 attr.new_shape.CopyAllDefinedAxis(
1934 graph->FindOutputs(node->id)[0]->tensor.shape);
1935 node->operation.attributes = attr;
1936 return absl::OkStatus();
1937 }
1938
1939 private:
1940 absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1941 bool* align_corners) {
1942 switch (sampling_type_) {
1943 case SamplingType::BILINEAR:
1944 return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1945 tflite_node, align_corners);
1946 case SamplingType::NEAREST:
1947 return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1948 tflite_node, align_corners);
1949 case SamplingType::UNKNOWN:
1950 return absl::InternalError("Sampling type is not specified");
1951 }
1952 return absl::OkStatus();
1953 }
1954
1955 template <class T>
1956 absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1957 bool* align_corners) {
1958 const T* tf_options;
1959 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1960 *align_corners = tf_options->align_corners;
1961 return absl::OkStatus();
1962 }
1963
1964 absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1965 bool* half_pixel_centers) {
1966 if (sampling_type_ == SamplingType::BILINEAR) {
1967 const TfLiteResizeBilinearParams* tf_options;
1968 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1969 if (tf_options->align_corners && tf_options->half_pixel_centers) {
1970 return absl::InternalError(
1971 "If half_pixel_centers is True, align_corners must be False.");
1972 }
1973 *half_pixel_centers = tf_options->half_pixel_centers;
1974 } else {
1975 const TfLiteResizeNearestNeighborParams* tf_options;
1976 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1977 *half_pixel_centers = tf_options->half_pixel_centers;
1978 }
1979 return absl::OkStatus();
1980 }
1981
1982 SamplingType sampling_type_ = SamplingType::UNKNOWN;
1983 };
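// Example for GetHalfPixelCentersValue above: a BILINEAR resize with both
// align_corners and half_pixel_centers set is rejected as an internal error,
// since the two coordinate mappings are mutually exclusive; NEAREST resizes
// accept half_pixel_centers without that check.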
1984
1985 class SelectV2OperationParser : public TFLiteOperationParser {
1986 public:
1987 absl::Status IsSupported(const TfLiteContext* context,
1988 const TfLiteNode* tflite_node,
1989 const TfLiteRegistration* registration) final {
1990 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1991 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1992 }
1993
1994 absl::Status Parse(const TfLiteNode* tflite_node,
1995 const TfLiteRegistration* registration,
1996 GraphFloat32* graph, ObjectReader* reader) final {
1997 Node* node = graph->NewNode();
1998 SelectV2Attributes attr;
1999 const TfLiteTensor* cond_tensor = reader->GetInputTensor(0);
2000 const TfLiteTensor* true_tensor = reader->GetInputTensor(1);
2001 const TfLiteTensor* false_tensor = reader->GetInputTensor(2);
2002 const bool is_if_constant = true_tensor->allocation_type == kTfLiteMmapRo;
2003 const bool is_else_constant =
2004 false_tensor->allocation_type == kTfLiteMmapRo;
2005 BHWC cond_shape, true_shape, false_shape;
2006 RETURN_IF_ERROR(ExtractTensorShape(*cond_tensor, &cond_shape));
2007 if (true_tensor->dims->size == 0) {
2008 attr.broadcast_true = true;
2009 } else {
2010 RETURN_IF_ERROR(ExtractTensorShape(*true_tensor, &true_shape));
2011 attr.broadcast_true = true_shape.DimensionsProduct() == 1;
2012 }
2013 if (false_tensor->dims->size == 0) {
2014 attr.broadcast_false = true;
2015 } else {
2016 RETURN_IF_ERROR(ExtractTensorShape(*false_tensor, &false_shape));
2017 attr.broadcast_false = false_shape.DimensionsProduct() == 1;
2018 }
2019 node->operation.type = ToString(OperationType::SELECT_V2);
2020 Value* if_value;
2021 Value* else_value;
2022 Tensor<BHWC, DataType::FLOAT32> if_tensor;
2023 Tensor<BHWC, DataType::FLOAT32> else_tensor;
2024 if (!attr.broadcast_true) {
2025 if (is_if_constant) {
2026 RETURN_IF_ERROR(reader->ReadTensor(1, &if_tensor));
2027 }
2028 } else {
2029 Tensor<Scalar, DataType::FLOAT32> if_scalar_tensor;
2030 RETURN_IF_ERROR(reader->ReadTensor(1, &if_scalar_tensor));
2031 if_tensor.shape = BHWC(1, 1, 1, 1);
2032 if_tensor.data.push_back(if_scalar_tensor.data[0]);
2033 }
2034 if (!attr.broadcast_false) {
2035 if (is_else_constant) {
2036 RETURN_IF_ERROR(reader->ReadTensor(2, &else_tensor));
2037 }
2038 } else {
2039 Tensor<Scalar, DataType::FLOAT32> else_scalar_tensor;
2040 RETURN_IF_ERROR(reader->ReadTensor(2, &else_scalar_tensor));
2041 else_tensor.shape = BHWC(1, 1, 1, 1);
2042 else_tensor.data.push_back(else_scalar_tensor.data[0]);
2043 }
2044 node->operation.attributes = std::move(attr);
2045 RETURN_IF_ERROR(reader->AddInput(node, 0));
2046 if (is_if_constant) {
2047 RETURN_IF_ERROR(NewConstNode(if_tensor, graph, &if_value));
2048 RETURN_IF_ERROR(graph->AddConsumer(node->id, if_value->id));
2049 } else {
2050 RETURN_IF_ERROR(reader->AddInput(node, 1));
2051 }
2052 if (is_else_constant) {
2053 RETURN_IF_ERROR(NewConstNode(else_tensor, graph, &else_value));
2054 RETURN_IF_ERROR(graph->AddConsumer(node->id, else_value->id));
2055 } else {
2056 RETURN_IF_ERROR(reader->AddInput(node, 2));
2057 }
2058 RETURN_IF_ERROR(reader->AddOutputs(node));
2059 return absl::OkStatus();
2060 }
2061 };
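// Sketch of the broadcast handling above (hypothetical inputs): for
// select_v2(cond, scalar_true, tensor_false) where scalar_true is a constant
// scalar, broadcast_true is set and the scalar is materialized as a
// BHWC(1, 1, 1, 1) const node; tensor_false is wired in as a runtime input
// unless it is kTfLiteMmapRo, in which case it also becomes a const node.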
2062
2063 class SliceOperationParser : public TFLiteOperationParser {
2064 public:
2065 absl::Status IsSupported(const TfLiteContext* context,
2066 const TfLiteNode* tflite_node,
2067 const TfLiteRegistration* registration) final {
2068 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2069 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2070 }
2071
2072 absl::Status Parse(const TfLiteNode* tflite_node,
2073 const TfLiteRegistration* registration,
2074 GraphFloat32* graph, ObjectReader* reader) final {
2075 Node* node = graph->NewNode();
2076 node->operation.type = ToString(OperationType::SLICE);
2077 RETURN_IF_ERROR(reader->AddOutputs(node));
2078 Value* input;
2079 RETURN_IF_ERROR(reader->ReadValue(0, &input));
2080 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2081
2082 const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
2083 const int input_dims = tfl_input->dims->size;
2084
2085 SliceAttributes attr;
2086 attr.strides = BHWC(1, 1, 1, 1);
2087 Tensor<Linear, DataType::INT32> starts, sizes;
2088 RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
2089 RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
2090 if (starts.data.size() != sizes.data.size()) {
2091 return absl::InvalidArgumentError("Starts count != sizes count.");
2092 }
2093 BHWC bhwc_starts(0, 0, 0, 0);
2094 BHWC bhwc_sizes = input->tensor.shape;
2095 if (input_dims == 4) {
2096 // input in BHWC layout
2097 if (starts.data.size() == 4) {
2098 bhwc_starts.b = starts.data[0];
2099 bhwc_starts.h = starts.data[1];
2100 bhwc_starts.w = starts.data[2];
2101 bhwc_starts.c = starts.data[3];
2102 bhwc_sizes.b = sizes.data[0];
2103 bhwc_sizes.h = sizes.data[1];
2104 bhwc_sizes.w = sizes.data[2];
2105 bhwc_sizes.c = sizes.data[3];
2106 } else if (starts.data.size() == 3) {
2107 // If input is 4D (BHWC) and args are 3D, assume args are in HWC layout.
2108 bhwc_starts.h = starts.data[0];
2109 bhwc_starts.w = starts.data[1];
2110 bhwc_starts.c = starts.data[2];
2111 bhwc_sizes.h = sizes.data[0];
2112 bhwc_sizes.w = sizes.data[1];
2113 bhwc_sizes.c = sizes.data[2];
2114 } else {
2115 return absl::UnimplementedError(
2116 "Slicing is supported for 3 or 4 dimensional tensors only.");
2117 }
2118 } else if (input_dims == 3) {
2119 // input in BWC layout
2120 if (starts.data.size() == 3) {
2121 bhwc_starts.b = starts.data[0];
2122 bhwc_starts.w = starts.data[1];
2123 bhwc_starts.c = starts.data[2];
2124 bhwc_sizes.b = sizes.data[0];
2125 bhwc_sizes.w = sizes.data[1];
2126 bhwc_sizes.c = sizes.data[2];
2127 } else {
2128 return absl::UnimplementedError(
2129 "Slicing is supported for 3 or 4 dimensional tensors only.");
2130 }
2131 } else {
2132 return absl::UnimplementedError(
2133 "Slicing is supported for 3 or 4 dimensional tensors only.");
2134 }
2135 const auto& in_shape = input->tensor.shape;
2136 if (bhwc_sizes.b == -1) {
2137 bhwc_sizes.b = in_shape.b - bhwc_starts.b;
2138 }
2139 if (bhwc_sizes.h == -1) {
2140 bhwc_sizes.h = in_shape.h - bhwc_starts.h;
2141 }
2142 if (bhwc_sizes.w == -1) {
2143 bhwc_sizes.w = in_shape.w - bhwc_starts.w;
2144 }
2145 if (bhwc_sizes.c == -1) {
2146 bhwc_sizes.c = in_shape.c - bhwc_starts.c;
2147 }
2148 attr.starts = bhwc_starts;
2149 attr.ends =
2150 BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
2151 bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
2152 RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
2153
2154 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2155 if ((attr.ends.b - attr.starts.b) != out_shape.b) {
2156 return absl::UnimplementedError("Output batch doesn't match");
2157 }
2158 if ((attr.ends.h - attr.starts.h) != out_shape.h) {
2159 return absl::UnimplementedError("Output height doesn't match");
2160 }
2161 if ((attr.ends.w - attr.starts.w) != out_shape.w) {
2162 return absl::UnimplementedError("Output width doesn't match");
2163 }
2164 if ((attr.ends.c - attr.starts.c) != out_shape.c) {
2165 return absl::UnimplementedError("Output channels don't match");
2166 }
2167 node->operation.attributes = attr;
2168 return absl::OkStatus();
2169 }
2170
2171 private:
2172 absl::Status UpdateIfNegative(const BHWC& input_shape,
2173 SliceAttributes* attr) {
2174 if (attr->ends.h < 0) {
2175 attr->ends.h = input_shape.h + attr->ends.h;
2176 }
2177 if (attr->ends.w < 0) {
2178 attr->ends.w = input_shape.w + attr->ends.w;
2179 }
2180 if (attr->ends.c < 0) {
2181 attr->ends.c = input_shape.c + attr->ends.c;
2182 }
2183 if (attr->ends.b < 0) {
2184 attr->ends.b = input_shape.b + attr->ends.b;
2185 }
2186 return absl::OkStatus();
2187 }
2188 };
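// Worked example for the size normalization above (hypothetical values):
// slicing a BHWC(1, 10, 10, 8) input with starts = {0, 2, 2, 0} and
// sizes = {-1, 4, 4, -1} expands each -1 to the remaining extent, giving
// attr.starts = BHWC(0, 2, 2, 0) and attr.ends = BHWC(1, 6, 6, 8).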
2189
2190 class SoftmaxOperationParser : public TFLiteOperationParser {
2191 public:
2192 absl::Status IsSupported(const TfLiteContext* context,
2193 const TfLiteNode* tflite_node,
2194 const TfLiteRegistration* registration) final {
2195 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2196 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2197 }
2198
2199 absl::Status Parse(const TfLiteNode* tflite_node,
2200 const TfLiteRegistration* registration,
2201 GraphFloat32* graph, ObjectReader* reader) final {
2202 Node* node = graph->NewNode();
2203 node->operation.type = ToString(OperationType::SOFTMAX);
2204 RETURN_IF_ERROR(reader->AddInput(node, 0));
2205 RETURN_IF_ERROR(reader->AddOutputs(node));
2206
2207 const TfLiteSoftmaxParams* tf_options;
2208 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2209 if (tf_options->beta != 1) {
2210 // There is a multiply-by-scalar operation fused into softmax. Make a
2211 // separate layer out of it before softmax.
2212 return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
2213 // auto mul_node = reader->NewPassthroughNode(node);
2214 // mul_node->operation.type = ToString(OperationType::MUL);
2215 }
2216 SoftmaxAttributes attr;
2217 attr.axis = Axis::CHANNELS; // always by channels
2218 node->operation.attributes = attr;
2219 return absl::OkStatus();
2220 }
2221 };
2222
2223 class SpaceToDepthOperationParser : public TFLiteOperationParser {
2224 public:
2225 absl::Status IsSupported(const TfLiteContext* context,
2226 const TfLiteNode* tflite_node,
2227 const TfLiteRegistration* registration) final {
2228 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2229 // TODO(impjdi): Dims check.
2230 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2231 }
2232
2233 absl::Status Parse(const TfLiteNode* tflite_node,
2234 const TfLiteRegistration* registration,
2235 GraphFloat32* graph, ObjectReader* reader) final {
2236 Node* node = graph->NewNode();
2237 node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
2238 RETURN_IF_ERROR(reader->AddInput(node, 0));
2239 RETURN_IF_ERROR(reader->AddOutputs(node));
2240 const TfLiteSpaceToDepthParams* tf_options;
2241 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2242 SpaceToDepthAttributes attr;
2243 attr.block_size = tf_options->block_size;
2244 node->operation.attributes = attr;
2245 return absl::OkStatus();
2246 }
2247 };
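// Example: block_size == 2 rearranges a BHWC(1, 4, 4, 3) input into a
// BHWC(1, 2, 2, 12) output, folding each 2x2 spatial block into channels.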
2248
2249 class SplitOperationParser : public TFLiteOperationParser {
2250 public:
2251 absl::Status IsSupported(const TfLiteContext* context,
2252 const TfLiteNode* tflite_node,
2253 const TfLiteRegistration* registration) final {
2254 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2255 }
2256
2257 absl::Status Parse(const TfLiteNode* tflite_node,
2258 const TfLiteRegistration* registration,
2259 GraphFloat32* graph, ObjectReader* reader) final {
2260 const TfLiteSplitParams* split_params;
2261 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
2262 if (split_params->num_splits == 1) {
2263 // Adding Identity reshape that will be removed.
2264 Node* node = graph->NewNode();
2265 node->operation.type = ToString(OperationType::RESHAPE);
2266 RETURN_IF_ERROR(reader->AddInput(node, 1));
2267 RETURN_IF_ERROR(reader->AddOutputs(node));
2268 // New shape comes from output shape.
2269 ReshapeAttributes attr;
2270 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2271 node->operation.attributes = attr;
2272 return absl::OkStatus();
2273 }
2274 const TfLiteTensor* input = reader->GetInputTensor(1);
2275 const TfLiteTensor* axis_tensor = reader->GetInputTensor(0);
2276 SplitAttributes attr;
2277 RETURN_IF_ERROR(
2278 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
2279
2280 Node* node = graph->NewNode();
2281 node->operation.type = ToString(OperationType::SPLIT);
2282 node->operation.attributes = attr;
2283 RETURN_IF_ERROR(reader->AddInput(node, 1));
2284 for (int i = 0; i < tflite_node->outputs->size; ++i) {
2285 RETURN_IF_ERROR(reader->AddOutput(node, i));
2286 }
2287 return absl::OkStatus();
2288 }
2289 };
2290
2291 class SplitVOperationParser : public TFLiteOperationParser {
2292 public:
2293 absl::Status IsSupported(const TfLiteContext* context,
2294 const TfLiteNode* tflite_node,
2295 const TfLiteRegistration* registration) final {
2296 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2297 }
2298
2299 absl::Status Parse(const TfLiteNode* tflite_node,
2300 const TfLiteRegistration* registration,
2301 GraphFloat32* graph, ObjectReader* reader) final {
2302 const TfLiteSplitVParams* split_params;
2303 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
2304 if (split_params->num_splits == 1) {
2305 // Adding Identity reshape that will be removed.
2306 Node* node = graph->NewNode();
2307 node->operation.type = ToString(OperationType::RESHAPE);
2308 RETURN_IF_ERROR(reader->AddInput(node, 0));
2309 RETURN_IF_ERROR(reader->AddOutputs(node));
2310 // New shape comes from output shape.
2311 ReshapeAttributes attr;
2312 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2313 node->operation.attributes = attr;
2314 return absl::OkStatus();
2315 }
2316 const TfLiteTensor* input = reader->GetInputTensor(0);
2317 const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
2318 SplitAttributes attr;
2319 RETURN_IF_ERROR(
2320 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
2321
2322 Node* node = graph->NewNode();
2323 node->operation.type = ToString(OperationType::SPLIT);
2324 node->operation.attributes = attr;
2325 RETURN_IF_ERROR(reader->AddInput(node, 0));
2326 for (int i = 0; i < tflite_node->outputs->size; ++i) {
2327 RETURN_IF_ERROR(reader->AddOutput(node, i));
2328 }
2329 return absl::OkStatus();
2330 }
2331 };
2332
2333 class StridedSliceOperationParser : public TFLiteOperationParser {
2334 public:
2335 absl::Status IsSupported(const TfLiteContext* context,
2336 const TfLiteNode* tflite_node,
2337 const TfLiteRegistration* registration) final {
2338 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2339 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2340 }
2341
2342 absl::Status Parse(const TfLiteNode* tflite_node,
2343 const TfLiteRegistration* registration,
2344 GraphFloat32* graph, ObjectReader* reader) final {
2345 Node* node = graph->NewNode();
2346 node->operation.type = ToString(OperationType::SLICE);
2347 RETURN_IF_ERROR(reader->AddOutputs(node));
2348 Value* input;
2349 RETURN_IF_ERROR(reader->ReadValue(0, &input));
2350 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2351
2352 Tensor<Linear, DataType::INT32> tmp;
2353 RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
2354
2355 bool read_without_batch = tmp.data.size() == 3;
2356 bool read_with_batch = tmp.data.size() == 4;
2357 if (!read_without_batch && !read_with_batch) {
2358 // Error: must be caught in IsSupported().
2359 return absl::UnimplementedError(
2360 "Slicing is supported for 3 or 4 dimensional tensors only.");
2361 }
2362
2363 const TfLiteStridedSliceParams* tf_options;
2364 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2365 RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
2366
2367 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
2368
2369 SliceAttributes attr;
2370 if (read_without_batch) {
2371 RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
2372 input->tensor.shape, &attr));
2373 }
2374 if (read_with_batch) {
2375 RETURN_IF_ERROR(
2376 ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
2377 }
2378 if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
2379 attr.strides.c == 0) {
2380 return absl::InvalidArgumentError("stride values must be non-zero");
2381 }
2382 if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
2383 attr.strides.c < 0) {
2384 return absl::UnimplementedError("Reverse slices are not supported.");
2385 }
2386 if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
2387 out_shape.b) {
2388 return absl::UnimplementedError("Output batch doesn't match");
2389 }
2390 if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
2391 out_shape.h) {
2392 return absl::UnimplementedError("Output height doesn't match");
2393 }
2394 if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
2395 out_shape.w) {
2396 return absl::UnimplementedError("Output width doesn't match");
2397 }
2398 if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
2399 out_shape.c) {
2400 return absl::UnimplementedError("Output channels don't match");
2401 }
2402 node->operation.attributes = attr;
2403 return absl::OkStatus();
2404 }
2405
2406 private:
2407 absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
2408 const BHWC& input_shape, int ignore_b,
2409 int ignore_h, int ignore_w, int ignore_c,
2410 SliceAttributes* attr) {
2411 if (tf_options->begin_mask & ignore_h) {
2412 attr->starts.h = 0;
2413 }
2414 if (tf_options->begin_mask & ignore_w) {
2415 attr->starts.w = 0;
2416 }
2417 if (tf_options->begin_mask & ignore_c) {
2418 attr->starts.c = 0;
2419 }
2420 if (tf_options->begin_mask & ignore_b) {
2421 attr->starts.b = 0;
2422 }
2423
2424 if (tf_options->end_mask & ignore_h) {
2425 attr->ends.h = input_shape.h;
2426 }
2427 if (tf_options->end_mask & ignore_w) {
2428 attr->ends.w = input_shape.w;
2429 }
2430 if (tf_options->end_mask & ignore_c) {
2431 attr->ends.c = input_shape.c;
2432 }
2433 if (tf_options->end_mask & ignore_b) {
2434 attr->ends.b = input_shape.b;
2435 }
2436 return absl::OkStatus();
2437 }
2438
2439 absl::Status UpdateIfNegative(const BHWC& input_shape,
2440 SliceAttributes* attr) {
2441 if (attr->ends.h < 0) {
2442 attr->ends.h = input_shape.h + attr->ends.h;
2443 }
2444 if (attr->ends.w < 0) {
2445 attr->ends.w = input_shape.w + attr->ends.w;
2446 }
2447 if (attr->ends.c < 0) {
2448 attr->ends.c = input_shape.c + attr->ends.c;
2449 }
2450 if (attr->ends.b < 0) {
2451 attr->ends.b = input_shape.b + attr->ends.b;
2452 }
2453
2454 if (attr->starts.h < 0) {
2455 attr->starts.h = input_shape.h + attr->starts.h;
2456 }
2457 if (attr->starts.w < 0) {
2458 attr->starts.w = input_shape.w + attr->starts.w;
2459 }
2460 if (attr->starts.c < 0) {
2461 attr->starts.c = input_shape.c + attr->starts.c;
2462 }
2463 if (attr->starts.b < 0) {
2464 attr->starts.b = input_shape.b + attr->starts.b;
2465 }
2466
2467 return absl::OkStatus();
2468 }
2469
2470 absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2471 const TfLiteStridedSliceParams* tf_options,
2472 const BHWC& input_shape,
2473 SliceAttributes* attr) {
2474 auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2475 Tensor<Linear, DataType::INT32> t;
2476 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2477 *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2478 return absl::OkStatus();
2479 };
2480
2481 RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2482 RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2483 RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2484 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2485 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
2486 return absl::OkStatus();
2487 }
2488
2489 absl::Status ReadAttribsWithoutBatch(
2490 const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2491 const BHWC& input_shape, SliceAttributes* attr) {
2492 auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2493 Tensor<Linear, DataType::INT32> t;
2494 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2495 *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2496 return absl::OkStatus();
2497 };
2498
2499 RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2500 RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2501 RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2502 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2503 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2504 attr->starts.b = 0;
2505 attr->ends.b = input_shape.b;
2506 attr->strides.b = 1;
2507 return absl::OkStatus();
2508 }
2509 absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2510 if (tf_options->ellipsis_mask) {
2511 return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2512 }
2513 if (tf_options->new_axis_mask) {
2514 return absl::UnimplementedError("Slice does not support new_axis_mask.");
2515 }
2516 if (tf_options->shrink_axis_mask) {
2517 return absl::UnimplementedError(
2518 "Slice does not support shrink_axis_mask parameter. ");
2519 }
2520 return absl::OkStatus();
2521 }
2522 };
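// Example for the mask handling above: with a batch dimension present,
// UpdateWithMask is called with (ignore_b, ignore_h, ignore_w, ignore_c) =
// (1, 2, 4, 8), so begin_mask == 0b0110 resets starts.h and starts.w to 0,
// and the corresponding end_mask bits snap ends back to the input extents.
// Without batch the bits map to (h, w, c) = (1, 2, 4) and the batch range is
// pinned to [0, input_shape.b) with stride 1.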
2523
2524 class TileOperationParser : public TFLiteOperationParser {
2525 public:
2526 absl::Status IsSupported(const TfLiteContext* context,
2527 const TfLiteNode* tflite_node,
2528 const TfLiteRegistration* registration) final {
2529 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2530 }
2531
2532 absl::Status Parse(const TfLiteNode* tflite_node,
2533 const TfLiteRegistration* registration,
2534 GraphFloat32* graph, ObjectReader* reader) final {
2535 Node* node = graph->NewNode();
2536 node->operation.type = ToString(OperationType::TILE);
2537 RETURN_IF_ERROR(reader->AddInput(node, 0));
2538 RETURN_IF_ERROR(reader->AddOutputs(node));
2539 return absl::OkStatus();
2540 }
2541 };
2542
2543 // Builtin op version of TRANSPOSE_CONV.
2544 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2545 public:
2546 absl::Status IsSupported(const TfLiteContext* context,
2547 const TfLiteNode* tflite_node,
2548 const TfLiteRegistration* registration) final {
2549 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2550 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2551 }
2552
2553 // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2554 // input, and an optional bias) and allows configurable padding & stride.
2555 // TODO(impjdi): Translate output_shape to attr.adjacent.
2556 absl::Status Parse(const TfLiteNode* tflite_node,
2557 const TfLiteRegistration* registration,
2558 GraphFloat32* graph, ObjectReader* reader) final {
2559 auto* node = graph->NewNode();
2560 node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2561 Value* input;
2562 RETURN_IF_ERROR(reader->ReadValue(2, &input));
2563 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2564 RETURN_IF_ERROR(reader->AddOutputs(node));
2565
2566 const TfLiteTransposeConvParams* tf_options;
2567 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2568
2569 ConvolutionTransposedAttributes attr;
2570 attr.stride = tf_options
2571 ? HW(tf_options->stride_height, tf_options->stride_width)
2572 : HW(1, 1);
2573 const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
2574 if (runtime_inputs == 2) {
2575 RETURN_IF_ERROR(reader->AddInput(node, 1));
2576 auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2577 attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2578 weights_shape.w, weights_shape.c);
2579 } else { // runtime_inputs == 1
2580 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2581 }
2582 reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional
2583
2584 UpdatePadding(tf_options->padding,
2585 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2586 node->operation.attributes = std::move(attr);
2587 return absl::OkStatus();
2588 }
2589 };
2590
2591 // Custom op version of TRANSPOSE_CONV.
2592 class TransposeConvCustomOperationParser : public TFLiteOperationParser {
2593 public:
2594 absl::Status IsSupported(const TfLiteContext* context,
2595 const TfLiteNode* tflite_node,
2596 const TfLiteRegistration* registration) final {
2597 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2598 }
2599
2600 absl::Status Parse(const TfLiteNode* tflite_node,
2601 const TfLiteRegistration* registration,
2602 GraphFloat32* graph, ObjectReader* reader) final {
2603 auto* node = graph->NewNode();
2604 node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2605 RETURN_IF_ERROR(reader->AddInput(node, 0));
2606 RETURN_IF_ERROR(reader->AddOutputs(node));
2607
2608 const TfLiteTransposeConvParams* tf_options;
2609 auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
2610
2611 ConvolutionTransposedAttributes attr;
2612 attr.stride = status.ok()
2613 ? HW(tf_options->stride_height, tf_options->stride_width)
2614 : HW(1, 1);
2615 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2616 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
2617
2618 UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
2619 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2620 node->operation.attributes = std::move(attr);
2621 return absl::OkStatus();
2622 }
2623 };
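// Note the tensor-index difference between the two TRANSPOSE_CONV parsers
// above: the builtin op reads {2: input, 1: weights, 3: optional bias}
// (index 0 holds the output shape), while the custom op reads {0: input,
// 1: weights, 2: optional bias} and falls back to stride HW(1, 1) and
// kTfLitePaddingUnknown when no custom initial data is attached.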
2624
2625 class TransposeOperationParser : public TFLiteOperationParser {
2626 public:
2627 absl::Status IsSupported(const TfLiteContext* context,
2628 const TfLiteNode* tflite_node,
2629 const TfLiteRegistration* registration) final {
2630 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2631 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2632 }
2633
2634 absl::Status Parse(const TfLiteNode* tflite_node,
2635 const TfLiteRegistration* registration,
2636 GraphFloat32* graph, ObjectReader* reader) final {
2637 Node* node = graph->NewNode();
2638 node->operation.type = ToString(OperationType::TRANSPOSE);
2639 RETURN_IF_ERROR(reader->AddInput(node, 0));
2640 RETURN_IF_ERROR(reader->AddOutputs(node));
2641
2642 TransposeAttributes attr;
2643 Tensor<Linear, DataType::INT32> perm;
2644 RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2645 std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2646 {Axis::HEIGHT, 1},
2647 {Axis::WIDTH, 2},
2648 {Axis::CHANNELS, 3}};
2649 if (perm.data.size() == 4) {
2650 attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2651 } else if (perm.data.size() == 3) {
2652 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2653 Axis::CHANNELS};
2654 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2655 attr.perm.h = 1;
2656 attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2657 attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2658 } else if (perm.data.size() == 2) {
2659 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2660 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2661 attr.perm.h = 1;
2662 attr.perm.w = 2;
2663 attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2664 } else {
2665 return absl::InvalidArgumentError(
2666 "Permutation for transpose is invalid.");
2667 }
2668
2669 node->operation.attributes = attr;
2670 return absl::OkStatus();
2671 }
2672 };
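
// Worked example for the rank-3 branch above: a TFLite perm = {0, 2, 1} is
// read over (BATCH, WIDTH, CHANNELS), and the absent HEIGHT axis is pinned to
// index 1, so the resulting BHWC permutation is (0, 1, 3, 2): W and C swap
// while B and the implicit unit H stay in place.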

class UnpackOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteUnpackParams* unpack_params;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &unpack_params));
    if (unpack_params->num == 1) {
      // Add an identity reshape that will be removed later.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::RESHAPE);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      // The new shape comes from the output shape.
      ReshapeAttributes attr;
      attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
      node->operation.attributes = attr;
      return absl::OkStatus();
    }
    const TfLiteTensor* input = reader->GetInputTensor(0);
    BHWC input_shape;
    RETURN_IF_ERROR(ExtractTensorShape(*input, &input_shape));
    SplitAttributes attr;
    RETURN_IF_ERROR(
        ExtractAxisFromIndex(*input, unpack_params->axis, &attr.axis));
    BHWC output_required_shape = input_shape;
    output_required_shape.set(attr.axis, 1);

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SPLIT);
    node->operation.attributes = attr;
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    auto input_value = graph->FindInputs(node->id)[0];
    for (int i = 0; i < tflite_node->outputs->size; ++i) {
      const TfLiteTensor* output = reader->GetOutputTensor(i);
      BHWC output_shape;
      RETURN_IF_ERROR(ExtractTensorShape(*output, &output_shape));
      if (output_shape != output_required_shape) {
        // The GPU delegate does not support implicit shape transformations,
        // so add an explicit Reshape.
        Value* copy_value = graph->NewValue();
        copy_value->tensor.type = input_value->tensor.type;
        copy_value->tensor.shape = output_required_shape;
        RETURN_IF_ERROR(graph->SetProducer(node->id, copy_value->id));
        Node* node_reshape = graph->NewNode();
        node_reshape->operation.type = ToString(OperationType::RESHAPE);
        ReshapeAttributes reshape_attr;
        reshape_attr.new_shape = output_shape;
        node_reshape->operation.attributes = reshape_attr;
        RETURN_IF_ERROR(graph->AddConsumer(node_reshape->id, copy_value->id));
        RETURN_IF_ERROR(reader->AddOutput(node_reshape, i));
      } else {
        RETURN_IF_ERROR(reader->AddOutput(node, i));
      }
    }
    return absl::OkStatus();
  }
};
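
// Sketch of the lowering above: UNPACK with num > 1 becomes a single SPLIT
// whose outputs keep a size-1 slot on the unpacked axis
// (output_required_shape). Whenever a TFLite output declares a different
// shape (the unpacked axis squeezed away), a fresh intermediate value plus an
// explicit RESHAPE node bridges the two shapes.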

class Unpooling2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
    MaxUnpooling2DAttributes attr;

    const TfLitePoolParams* tf_options;
    RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));

    attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    UpdatePadding(tf_options->padding, input_shape, &attr);

    node->operation.attributes = attr;

    auto output_value = graph->FindOutputs(node->id)[0];
    output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
    return absl::OkStatus();
  }
};
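
// The "MaxUnpooling2D" custom op is assumed to consume the pair produced by
// its companion "MaxPoolingWithArgmax2D" custom op: input 0 carries the
// pooled features and input 1 the argmax indices, which is why both runtime
// inputs are wired into the node above.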

// TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
class BatchToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    BatchToSpaceAttributes bs_attr;
    Tensor<Linear, DataType::INT32> block;
    RETURN_IF_ERROR(reader->ReadTensor(1, &block));
    if (block.shape.v != 2) {
      return absl::InternalError("Block size has to be HxW.");
    }
    bs_attr.block.h = block.data[0];
    bs_attr.block.w = block.data[1];

    Tensor<HW, DataType::INT32> crop;
    RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
    auto crop_shape = crop.shape;
    if (crop_shape.h != 2 || crop_shape.w != 2) {
      return absl::InternalError("Crop has to be 2x2 (HxW).");
    }

    bs_attr.crop.prepended.h = crop.data[0];
    bs_attr.crop.prepended.w = crop.data[2];

    bs_attr.crop.appended.h = crop.data[1];
    bs_attr.crop.appended.w = crop.data[3];

    node->operation.attributes = std::move(bs_attr);
    return absl::OkStatus();
  }
};
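
// Layout assumed for the 2x2 crop tensor read above, matching the TFLite
// BATCH_TO_SPACE_ND convention of one row per spatial axis:
//   crop = [[crop_top, crop_bottom],
//           [crop_left, crop_right]]
// hence data[0]/data[1] feed prepended.h/appended.h and data[2]/data[3] feed
// prepended.w/appended.w. SPACE_TO_BATCH below reads its padding tensor the
// same way.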

class SpaceToBatchOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    SpaceToBatchAttributes sb_attr;
    Tensor<Linear, DataType::INT32> block;
    RETURN_IF_ERROR(reader->ReadTensor(1, &block));
    if (block.shape.v != 2) {
      return absl::InternalError("Block size has to be HxW.");
    }
    sb_attr.block.h = block.data[0];
    sb_attr.block.w = block.data[1];

    Tensor<HW, DataType::INT32> padding;
    RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
    auto padding_shape = padding.shape;

    if (padding_shape.h != 2 || padding_shape.w != 2) {
      return absl::InternalError("Padding has to be 2x2 (HxW).");
    }

    sb_attr.padding.prepended.h = padding.data[0];
    sb_attr.padding.prepended.w = padding.data[2];

    sb_attr.padding.appended.h = padding.data[1];
    sb_attr.padding.appended.w = padding.data[3];

    node->operation.attributes = std::move(sb_attr);
    return absl::OkStatus();
  }
};

class MeanOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    MeanAttributes attr;
    const TfLiteTensor* input = reader->GetInputTensor(0);
    const TfLiteTensor* axes = reader->GetInputTensor(1);
    for (int i = 0; i < NumElements(axes->dims); i++) {
      Axis axis;
      RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
      attr.dims.insert(axis);
    }
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};
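
// Example of the axis mapping above: a MEAN over a BHWC input with axes
// tensor {1, 2} resolves, via ExtractAxisFromIndex(), to Axis::HEIGHT and
// Axis::WIDTH, so attr.dims = {HEIGHT, WIDTH} expresses the common
// global-average-pooling pattern.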

class UnsupportedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::UnimplementedError("Operation is not supported.");
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    return absl::UnimplementedError("Operation is not supported.");
  }
};

absl::Status IsSupported(
    const TfLiteContext* context, TfLiteNode* node,
    const TfLiteRegistration* registration, bool allow_quant_ops = false,
    const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops = nullptr) {
  return NewOperationParser(registration, allow_quant_ops, excluded_ops)
      ->IsSupported(context, node, registration);
}

bool IsAllAllowedTensors(TfLiteContext* context,
                         const TfLiteIntArray* tensor_indices,
                         const std::vector<TfLiteType>& allowed_types) {
  for (int i = 0; i < tensor_indices->size; ++i) {
    int tensor_idx = tensor_indices->data[i];
    if (tensor_idx == kTfLiteOptionalTensor) continue;
    const TfLiteTensor* t = &context->tensors[tensor_idx];
    if (t->dims && t->dims->size >= 5) {
      return false;
    }
    bool type_supported = false;
    for (auto allowed_type : allowed_types) {
      if (t->type == allowed_type) {
        type_supported = true;
        break;
      }
    }
    if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
      return false;
    }
  }
  return true;
}
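
// Note on IsAllAllowedTensors(): tensors of rank >= 5 are rejected outright,
// but the allowed-type list is only enforced for kTfLiteArenaRw (runtime)
// tensors; constant tensors and other allocation types pass regardless of
// element type, since their data is converted while the graph is built.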
}  // namespace

std::unique_ptr<TFLiteOperationParser> NewOperationParser(
    const TfLiteRegistration* registration, bool allow_quant_ops,
    const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops) {
  const auto builtin_code = registration->builtin_code;
  if (excluded_ops != nullptr &&
      excluded_ops->contains(
          static_cast<TfLiteBuiltinOperator>(builtin_code))) {
    return std::make_unique<UnsupportedOperationParser>();
  }
  switch (builtin_code) {
    case kTfLiteBuiltinAbs:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
    case kTfLiteBuiltinAdd:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ADD);
    case kTfLiteBuiltinAveragePool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
    case kTfLiteBuiltinBatchMatmul:
      return std::make_unique<BatchedMatMulOperationParser>();
    case kTfLiteBuiltinCast:
      return std::make_unique<CastOperationParser>();
    case kTfLiteBuiltinConcatenation:
      return std::make_unique<ConcatenationOperationParser>();
    case kTfLiteBuiltinConv2d:
      return std::make_unique<Conv2DOperationParser>();
    case kTfLiteBuiltinCos:
      return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
    case kTfLiteBuiltinCumsum:
      return std::make_unique<CumsumOperationParser>();
    case kTfLiteBuiltinDensify:
      return std::make_unique<DensifyOperationParser>();
    case kTfLiteBuiltinDepthwiseConv2d:
      return std::make_unique<DepthwiseConvolutionOperationParser>();
    case kTfLiteBuiltinDepthToSpace:
      return std::make_unique<DepthToSpaceOperationParser>();
    case kTfLiteBuiltinDequantize:
      if (allow_quant_ops) {
        return std::make_unique<DequantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinDiv:
      return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
    case kTfLiteBuiltinEqual:
      return std::make_unique<ElementwiseOperationParser>(OperationType::EQUAL);
    case kTfLiteBuiltinElu:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
    case kTfLiteBuiltinExp:
      return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
    case kTfLiteBuiltinFloor:
      return std::make_unique<ElementwiseOperationParser>(OperationType::FLOOR);
    case kTfLiteBuiltinFloorDiv:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::FLOOR_DIV);
    case kTfLiteBuiltinFloorMod:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::FLOOR_MOD);
    case kTfLiteBuiltinFullyConnected:
      return std::make_unique<FullyConnectedOperationParser>();
    case kTfLiteBuiltinGreater:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::GREATER);
    case kTfLiteBuiltinGreaterEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::GREATER_EQUAL);
    case kTfLiteBuiltinHardSwish:
      return std::make_unique<HardSwishOperationParser>();
    case kTfLiteBuiltinLess:
      return std::make_unique<ElementwiseOperationParser>(OperationType::LESS);
    case kTfLiteBuiltinLessEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::LESS_EQUAL);
    case kTfLiteBuiltinLogistic:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SIGMOID);
    case kTfLiteBuiltinLog:
      return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
    case kTfLiteBuiltinLstm:
      return std::make_unique<LSTMOperationParser>();
    case kTfLiteBuiltinMaximum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MAXIMUM);
    case kTfLiteBuiltinMaxPool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
    case kTfLiteBuiltinMean:
      return std::make_unique<MeanOperationParser>();
    case kTfLiteBuiltinMinimum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MINIMUM);
    case kTfLiteBuiltinMirrorPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
    case kTfLiteBuiltinMul:
      return std::make_unique<ElementwiseOperationParser>(OperationType::MUL);
    case kTfLiteBuiltinNeg:
      return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
    case kTfLiteBuiltinNotEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::NOT_EQUAL);
    case kTfLiteBuiltinOneHot:
      return std::make_unique<OneHotOperationParser>();
    case kTfLiteBuiltinPack:
      return std::make_unique<PackOperationParser>();
    case kTfLiteBuiltinPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
    case kTfLiteBuiltinPow:
      return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
    case kTfLiteBuiltinReduceMax:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MAXIMUM);
    case kTfLiteBuiltinReduceMin:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MINIMUM);
    case kTfLiteBuiltinReduceProd:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_PRODUCT);
    case kTfLiteBuiltinQuantize:
      if (allow_quant_ops) {
        return std::make_unique<QuantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinRelu6:
      return std::make_unique<ReLUOperationParser>(6);
    case kTfLiteBuiltinReluN1To1:
      return std::make_unique<ClampOperationsParser>(-1.0, 1.0);
    case kTfLiteBuiltinLeakyRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinPrelu:
      return std::make_unique<PReLUOperationParser>();
    case kTfLiteBuiltinReshape:
      return std::make_unique<ReshapeOperationParser>();
    case kTfLiteBuiltinResizeBilinear:
      return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
    case kTfLiteBuiltinResizeNearestNeighbor:
      return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
    case kTfLiteBuiltinRsqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
    case kTfLiteBuiltinSelectV2:
      return std::make_unique<SelectV2OperationParser>();
    case kTfLiteBuiltinSin:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
    case kTfLiteBuiltinSlice:
      return std::make_unique<SliceOperationParser>();
    case kTfLiteBuiltinSoftmax:
      return std::make_unique<SoftmaxOperationParser>();
    case kTfLiteBuiltinSpaceToDepth:
      return std::make_unique<SpaceToDepthOperationParser>();
    case kTfLiteBuiltinSplit:
      return std::make_unique<SplitOperationParser>();
    case kTfLiteBuiltinSplitV:
      return std::make_unique<SplitVOperationParser>();
    case kTfLiteBuiltinSqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
    case kTfLiteBuiltinSquare:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARE);
    case kTfLiteBuiltinSquaredDifference:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARED_DIFF);
    case kTfLiteBuiltinStridedSlice:
      return std::make_unique<StridedSliceOperationParser>();
    case kTfLiteBuiltinSub:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
    case kTfLiteBuiltinSum:
      return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
    case kTfLiteBuiltinTanh:
      return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
    case kTfLiteBuiltinTile:
      return std::make_unique<TileOperationParser>();
    case kTfLiteBuiltinTranspose:
      return std::make_unique<TransposeOperationParser>();
    case kTfLiteBuiltinTransposeConv:
      return std::make_unique<TransposeConvBuiltinOperationParser>();
    case kTfLiteBuiltinUnpack:
      return std::make_unique<UnpackOperationParser>();
    case kTfLiteBuiltinCustom: {
      const absl::string_view custom_name = registration->custom_name;
      if (custom_name == "Convolution2DTransposeBias") {
        return std::make_unique<TransposeConvCustomOperationParser>();
      }
      if (custom_name == "MaxPoolingWithArgmax2D") {
        return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
      }
      if (custom_name == "MaxUnpooling2D") {
        return std::make_unique<Unpooling2DOperationParser>();
      }
      if (custom_name == "Resampler") {
        return std::make_unique<ResamplerOperationParser>();
      }
      return NewCustomOperationParser(registration->custom_name);
    }
  }
  return std::make_unique<UnsupportedOperationParser>();
}
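
// Minimal usage sketch for the factory above (the registration value is
// illustrative); as written, every path returns a parser, with unknown
// builtins falling back to UnsupportedOperationParser, so the result is safe
// to dereference:
//
//   const TfLiteRegistration* reg = ...;  // e.g. from GetNodeAndRegistration
//   auto parser = NewOperationParser(reg, /*allow_quant_ops=*/false);
//   absl::Status supported = parser->IsSupported(context, node, reg);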

// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(
    TfLiteContext* context, bool allow_quant_ops, int max_delegated_partitions,
    const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops) {
  delegates::IsNodeSupportedFn node_supported_fn =
      [=](TfLiteContext* context, TfLiteNode* node,
          TfLiteRegistration* registration,
          std::string* unsupported_details) -> bool {
    const auto status =
        IsSupported(context, node, registration, allow_quant_ops, excluded_ops);
    if (!status.ok()) {
      if (unsupported_details) {
        *unsupported_details = std::string(status.message());
      }
      return false;
    }

    std::vector<TfLiteType> allowed_in_types = {kTfLiteFloat32, kTfLiteFloat16};
    std::vector<TfLiteType> allowed_out_types = {kTfLiteFloat32,
                                                 kTfLiteFloat16};
    if (allow_quant_ops) {
      // Since we only check non-constant tensors, type cannot be Int32.
      allowed_in_types.push_back(kTfLiteInt8);
      allowed_in_types.push_back(kTfLiteUInt8);
      allowed_out_types.push_back(kTfLiteInt8);
      allowed_out_types.push_back(kTfLiteUInt8);
    }
    if (IsLogicalCode(registration->builtin_code)) {
      allowed_out_types.push_back(kTfLiteBool);
    }
    if (registration->builtin_code == kTfLiteBuiltinCast) {
      allowed_in_types.push_back(kTfLiteBool);
      allowed_in_types.push_back(kTfLiteFloat32);
      allowed_in_types.push_back(kTfLiteInt32);
      allowed_out_types.push_back(kTfLiteFloat32);
      allowed_out_types.push_back(kTfLiteInt32);
      allowed_out_types.push_back(kTfLiteBool);
    }
    if (registration->builtin_code == kTfLiteBuiltinOneHot) {
      allowed_in_types.push_back(kTfLiteInt32);
    }
    if (registration->builtin_code == kTfLiteBuiltinSelectV2) {
      allowed_in_types.push_back(kTfLiteBool);
    }
    if (!IsAllAllowedTensors(context, node->inputs, allowed_in_types) ||
        !IsAllAllowedTensors(context, node->outputs, allowed_out_types)) {
      if (unsupported_details) {
        *unsupported_details =
            "OP is supported, but tensor type/shape isn't compatible.";
      }
      return false;
    }
    return true;
  };

  delegates::FP16GraphPartitionHelper partition_helper(context,
                                                       node_supported_fn);
  std::set<std::string> unsupported_nodes_info;
  if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
    return TfLiteIntArrayCreate(0);
  }

  // By default, we simply take the single largest partition, as
  // 'max_delegated_partitions' defaults to 1.
  std::vector<int> ops_to_replace =
      partition_helper.GetNodesOfFirstNLargestPartitions(
          max_delegated_partitions);

  if (!unsupported_nodes_info.empty() &&
      partition_helper.num_total_nodes() > ops_to_replace.size()) {
    std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
    std::string error_message = absl::StrCat(
        "Following operations are not supported by GPU delegate:\n",
        unsupported, "\n");
    if (!ops_to_replace.empty()) {
      absl::StrAppend(
          &error_message, ops_to_replace.size(),
          " operations will run on the GPU, and the remaining ",
          partition_helper.num_total_nodes() - ops_to_replace.size());
    } else {
      absl::StrAppend(&error_message,
                      "No operations will run on the GPU, and all ",
                      partition_helper.num_total_nodes());
    }
    absl::StrAppend(&error_message, " operations will run on the CPU.");
    TF_LITE_KERNEL_LOG(context, error_message.c_str());
  }
  return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
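
// Typical call pattern for GetOpsToReplace() (mirrored by DelegatePrepare()
// further below); the caller owns the returned array and must free it after
// handing the partition to the runtime:
//
//   TfLiteIntArray* ops = GetOpsToReplace(context, /*allow_quant_ops=*/false);
//   context->ReplaceNodeSubsetsWithDelegateKernels(context, registration,
//                                                  ops, delegate);
//   TfLiteIntArrayFree(ops);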

// Creates graph values for the tensors listed in io_ids in the resulting
// graph. This is enforced so that the delegated subgraph keeps the same
// input/output order as the original model: when the delegated model is
// built from the tflite model representation, tensors are created lazily, so
// without precreation there is no guarantee that their order matches the
// source model's tensor order.
absl::Status PrecreateIOTensors(
    TfLiteContext* context, GraphFloat32* graph, const std::vector<int>& io_ids,
    absl::flat_hash_map<int, int>* quant_conversion_map,
    absl::flat_hash_map<int, Value*>* tensor_to_value) {
  for (const auto& id : io_ids) {
    const TfLiteTensor& tflite_tensor = context->tensors[id];
    if (tflite::IsConstantTensor(&tflite_tensor)) continue;
    RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
        context, tensor_to_value, quant_conversion_map, graph, id));
  }
  return absl::OkStatus();
}

absl::Status CopyVariableTensorOutputs(
    TfLiteNode* tflite_node, TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader& reader,
    const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
  absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
      new_variable_tensor_values);
  // Retrieve the final value id for the variable input tensors.
  for (int i = 0; i < tflite_node->inputs->size; i++) {
    int tensor_idx = tflite_node->inputs->data[i];
    Value* value;
    if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
    if (value->tensor.is_variable_input) {
      if (new_variable_tensor_values_copy.find(i) ==
          new_variable_tensor_values_copy.end()) {
        return absl::InvalidArgumentError(
            absl::StrCat(GetOpNameByRegistration(*registration),
                         " did not provide a new value for the variable input "
                         "tensor with index ",
                         tensor_idx));
      } else {
        Node* node = graph->NewNode();
        node->operation.type = ToString(OperationType::COPY);
        RETURN_IF_ERROR(graph->AddConsumer(
            node->id, new_variable_tensor_values_copy.at(i)));
        RETURN_IF_ERROR(reader.AddUpdate(node, i));
        new_variable_tensor_values_copy.erase(
            new_variable_tensor_values_copy.find(i));
      }
    }
  }
  if (!new_variable_tensor_values_copy.empty()) {
    return absl::InvalidArgumentError(
        "More input variable tensors asked to be copied than present on the "
        "node");
  }
  return absl::OkStatus();
}

absl::Status BuildModel(TfLiteContext* context,
                        const TfLiteDelegateParams* delegate_params,
                        GraphFloat32* graph,
                        absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<int> inputs(delegate_params->input_tensors->size);
  std::vector<int> outputs(delegate_params->output_tensors->size);
  for (int i = 0; i < delegate_params->input_tensors->size; i++) {
    inputs[i] = delegate_params->input_tensors->data[i];
  }
  for (int i = 0; i < delegate_params->output_tensors->size; i++) {
    outputs[i] = delegate_params->output_tensors->data[i];
  }
  return BuildModelEnforceIO(context, delegate_params, inputs, outputs, graph,
                             quant_conversion_map);
}

absl::Status BuildModelEnforceIO(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    const std::vector<int>& input_ids, const std::vector<int>& output_ids,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
  std::vector<int> tflite_nodes;
  for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
    TfLiteNode* tflite_node = nullptr;
    TfLiteRegistration* registration = nullptr;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[i], &tflite_node,
        &registration));
    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        context->tensors[tflite_node->inputs->data[0]].type ==
            TfLiteType::kTfLiteFloat16 &&
        context->tensors[tflite_node->inputs->data[0]].allocation_type ==
            TfLiteAllocationType::kTfLiteMmapRo) {
      // Ignore Fp16 Dequantize nodes only when they sit directly on top of
      // constant weights, i.e. no other node (e.g. DENSIFY) precedes them.
      continue;
    }
    auto op_parser = NewOperationParser(
        registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
    if (!op_parser) {
      return absl::UnimplementedError(
          absl::StrCat("Operation ", registration->builtin_code, "(",
                       registration->custom_name,
                       ") is not supported by TFLite GPU Delegate."));
    }
    operations.push_back(std::move(op_parser));
    tflite_nodes.push_back(i);
  }
  absl::flat_hash_map<int, Value*> tensor_to_value;
  std::vector<ValueId> variable_inputs_to_value_id;

  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, input_ids,
                                     quant_conversion_map, &tensor_to_value));
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, output_ids,
                                     quant_conversion_map, &tensor_to_value));
  for (int i = 0; i < operations.size(); ++i) {
    TfLiteNode* tflite_node;
    TfLiteRegistration* registration;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
        &tflite_node, &registration));
    ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
                        quant_conversion_map);
    const auto status =
        operations[i]->Parse(tflite_node, registration, graph, &reader);
    if (!status.ok()) {
      return absl::InternalError(absl::StrCat(
          GetOpNameByRegistration(*registration), ": ", status.message()));
    }

    absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
        operations[i]->GetNewValueIdsForVariableInputNodes();

    RETURN_IF_ERROR(
        CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
                                  new_value_for_variable_input_tensors));
  }

  // Variable input tensors are expected to stay unchanged throughout model
  // execution; making them graph outputs is what guarantees that.
  for (auto value_id : variable_inputs_to_value_id) {
    if (!graph->IsGraphOutput(value_id)) {
      return absl::InvalidArgumentError(
          absl::StrCat("Variable input tensors must be a graph output. Value ",
                       value_id, " is not a graph output"));
    }
  }
  return absl::OkStatus();
}

absl::Status BuildFinalModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  RETURN_IF_ERROR(
      BuildModel(context, delegate_params, graph, quant_conversion_map));

  // Apply general transformations on the graph.
  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }
  return absl::OkStatus();
}

namespace {

class DelegateContext {
 public:
  struct DelegateData {
    std::vector<int> input_ids;
    std::vector<int> output_ids;
    GraphFloat32* graph;
    std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
  };
  bool Init(TfLiteContext* context,
            const TfLiteDelegateParams* delegate_params) {
    const auto* delegate_data =
        reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
    return delegate_data->graph &&
           BuildModelEnforceIO(context, delegate_params,
                               delegate_data->input_ids,
                               delegate_data->output_ids, delegate_data->graph,
                               delegate_data->quant_conversion_map.get())
               .ok();
  }
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteRegistration registration{};
  registration.init = [](TfLiteContext* context, const char* buffer,
                         size_t) -> void* {
    auto* delegate_context = new DelegateContext();
    if (!delegate_context->Init(
            context, reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
      delete delegate_context;
      return nullptr;
    }
    return delegate_context;
  };
  registration.free = [](TfLiteContext* context, void* buffer) -> void {
    delete reinterpret_cast<DelegateContext*>(buffer);
  };
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    return node->user_data ? kTfLiteOk : kTfLiteError;
  };

  const auto* delegate_data =
      reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, static_cast<bool>(delegate_data->quant_conversion_map));
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, registration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace

absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
                                 const tflite::OpResolver& op_resolver,
                                 GraphFloat32* graph, bool allow_quant_ops) {
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
  if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    return absl::InternalError("Unable to prepare TfLite interpreter.");
  }
  TfLiteDelegate delegate;

  DelegateContext::DelegateData delegate_data{interpreter->inputs(),
                                              interpreter->outputs(), graph};
  if (allow_quant_ops) {
    delegate_data.quant_conversion_map =
        std::make_unique<absl::flat_hash_map<int, int>>();
  }

  delegate.data_ = &delegate_data;
  delegate.flags = kTfLiteDelegateFlagsNone;
  delegate.Prepare = DelegatePrepare;
  delegate.CopyFromBufferHandle = nullptr;
  delegate.CopyToBufferHandle = nullptr;
  delegate.FreeBufferHandle = nullptr;

  if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
    return absl::InternalError("Conversion from TfLite model failed.");
  }

  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }

  return absl::OkStatus();
}
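
// Minimal end-to-end sketch for BuildFromFlatBuffer(); the file name and the
// resolver choice are illustrative, not prescribed by this API:
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   tflite::gpu::GraphFloat32 graph;
//   absl::Status s = tflite::gpu::BuildFromFlatBuffer(
//       *model, resolver, &graph, /*allow_quant_ops=*/false);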

}  // namespace gpu
}  // namespace tflite