/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/delegate_test_util.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace delegates {
namespace test_utils {

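// Returns a registration for a custom "my_add" op that outputs the
// element-wise sum of its two (same-shaped) float inputs.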
TfLiteRegistration AddOpRegistration() {
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

  reg.custom_name = "my_add";
  reg.builtin_code = tflite::BuiltinOperator_CUSTOM;

  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* input1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
    const TfLiteTensor* input2;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

    // Verify that the two inputs have the same shape.
    TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
    for (int i = 0; i < input1->dims->size; ++i) {
      TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
    }

    // Set output shape to match input shape.
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(
        context, output, TfLiteIntArrayCopy(input1->dims)));
    return kTfLiteOk;
  };

  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* a0;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
    TF_LITE_ENSURE(context, a0);
    TF_LITE_ENSURE(context, a0->data.f);
    const TfLiteTensor* a1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
    TF_LITE_ENSURE(context, a1);
    TF_LITE_ENSURE(context, a1->data.f);
    TfLiteTensor* out;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
    TF_LITE_ENSURE(context, out);
    TF_LITE_ENSURE(context, out->data.f);
    // Set output data to element-wise sum of input data.
    int num = a0->dims->data[0];
    for (int i = 0; i < num; i++) {
      out->data.f[i] = a0->data.f[i] + a1->data.f[i];
    }
    return kTfLiteOk;
  };
  return reg;
}

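// Builds a graph of three "my_add" nodes over five float tensors of shape {3}:
// node 0 computes tensor2 = tensor0 + tensor0, node 1 computes
// tensor3 = tensor1 + tensor1, and node 2 computes tensor4 = tensor2 + tensor1.
// Tensors 0 & 1 are the subgraph inputs; tensors 3 & 4 are its outputs.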
void TestDelegation::SetUpSubgraph(Subgraph* subgraph) {
  subgraph->AddTensors(5);
  subgraph->SetInputs({0, 1});
  subgraph->SetOutputs({3, 4});
  std::vector<int> dims({3});
  TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
  subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  TfLiteRegistration reg = AddOpRegistration();
  int node_index_ignored;
  subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
}

void TestDelegation::AddSubgraphs(int subgraphs_to_add,
                                  int* first_new_subgraph_index) {
  interpreter_->AddSubgraphs(subgraphs_to_add, first_new_subgraph_index);
}

void TestDelegate::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  SetUpSubgraph(&interpreter_->primary_subgraph());
}

void TestDelegate::TearDown() {
  // The interpreter relies on the delegate to free its resources properly, so
  // the delegate must outlive the interpreter: reset the interpreter first.
  interpreter_.reset();
  delegate_.reset();
  delegate2_.reset();
}

void TestTwoDelegates::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  SetUpSubgraph(&interpreter_->primary_subgraph());
}

void TestTwoDelegates::TearDown() {
  // The interpreter relies on the delegate to free its resources properly, so
  // the delegate must outlive the interpreter: reset the interpreter first.
  interpreter_.reset();
  delegate_.reset();
  delegate2_.reset();
}

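// SimpleDelegate claims the nodes listed in `nodes`; the behaviour of its fake
// fused kernel (failing Prepare/Invoke, shape propagation, dynamic output) is
// selected by the remaining constructor arguments.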
SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
                               int64_t delegate_flags, bool fail_node_prepare,
                               int min_ops_per_subset, bool fail_node_invoke,
                               bool automatic_shape_propagation, bool custom_op,
                               bool set_output_tensor_dynamic)
    : nodes_(nodes),
      fail_delegate_node_prepare_(fail_node_prepare),
      min_ops_per_subset_(min_ops_per_subset),
      fail_delegate_node_invoke_(fail_node_invoke),
      automatic_shape_propagation_(automatic_shape_propagation),
      custom_op_(custom_op),
      set_output_tensor_dynamic_(set_output_tensor_dynamic) {
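  // Prepare: verify the claimed nodes, optionally drop partitions smaller than
  // min_ops_per_subset_, then replace the surviving nodes with the fake fused
  // kernel from FakeFusedRegistration().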
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
    TfLiteIntArray* nodes_to_separate =
        TfLiteIntArrayCreate(simple->nodes_.size());
    // Mark nodes that we want in TfLiteIntArray* structure.
    int index = 0;
    for (auto node_index : simple->nodes_) {
      nodes_to_separate->data[index++] = node_index;
      // make sure node is added
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (simple->custom_op_) {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
        TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
      } else {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
      }
    }
    // Check that all nodes are available
    TfLiteIntArray* execution_plan;
    TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
    for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
      int node_index = execution_plan->data[exec_index];
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (exec_index == node_index) {
        // Check op details only if it wasn't delegated already.
        if (simple->custom_op_) {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
          TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
        } else {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
        }
      }
    }

    // Get preview of delegate partitioning from the context.
    TfLiteDelegateParams* params_array;
    int num_partitions;
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    if (simple->min_ops_per_subset() > 0) {
      // Build a new vector of ops from subsets with at least the minimum
      // size.
      std::vector<int> allowed_ops;
      for (int idx = 0; idx < num_partitions; ++idx) {
        const auto* nodes_in_subset = params_array[idx].nodes_to_replace;
        if (nodes_in_subset->size < simple->min_ops_per_subset()) continue;
        allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data,
                           nodes_in_subset->data + nodes_in_subset->size);
      }

      // Free existing nodes_to_separate & initialize a new array with
      // allowed_ops.
      TfLiteIntArrayFree(nodes_to_separate);
      nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size());
      memcpy(nodes_to_separate->data, allowed_ops.data(),
             sizeof(int) * nodes_to_separate->size);
    }

    // Another call to PreviewDelegatePartitioning should be okay, since
    // partitioning memory is managed by context.
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, simple->FakeFusedRegistration(), nodes_to_separate, delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyToBufferHandle = [](TfLiteContext* context,
                                    TfLiteDelegate* delegate,
                                    TfLiteBufferHandle buffer_handle,
                                    TfLiteTensor* tensor) -> TfLiteStatus {
    // TODO(b/156586986): Implement tests to test buffer copying logic.
    return kTfLiteOk;
  };
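  // CopyFromBufferHandle: fill the output with the constant 6.0f so tests can
  // verify that the buffer-handle path was actually taken.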
  delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
                                      TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* output) -> TfLiteStatus {
    TFLITE_CHECK_GE(buffer_handle, -1);
    TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle);
    const float floats[] = {6., 6., 6.};
    int num = output->dims->data[0];
    for (int i = 0; i < num; i++) {
      output->data.f[i] = floats[i];
    }
    return kTfLiteOk;
  };

  delegate_.FreeBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
  // Store a type-punned pointer to this SimpleDelegate instance.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = delegate_flags;
}

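// Registration for the kernel that replaces delegated nodes: Invoke computes
// the element-wise sum of its inputs (or doubles a single input), and Prepare
// is swapped out depending on the test flags set in the constructor.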
TfLiteRegistration SimpleDelegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fused_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      // Compute output data as elementwise sum of the two input arguments:
      // func(x, y) = x + y
      // or for a single argument compute 2 * x:
      // func(x) = x + x
      const TfLiteTensor* a0;
      const TfLiteTensor* a1;
      if (node->inputs->size == 2) {
        a0 = GetInput(context, node, 0);
        a1 = GetInput(context, node, 1);
      } else {
        a0 = GetInput(context, node, 0);
        a1 = a0;
      }
      TfLiteTensor* out;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
      int num = 1;
      for (int i = 0; i < a0->dims->size; ++i) {
        num *= a0->dims->data[i];
      }
      for (int i = 0; i < num; i++) {
        out->data.f[i] = a0->data.f[i] + a1->data.f[i];
      }
      if (out->buffer_handle != kTfLiteNullBufferHandle) {
        // Make the data stale so that CopyFromBufferHandle can be invoked.
        out->data_is_stale = true;
      }
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (automatic_shape_propagation_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Shapes should already be propagated by the runtime; we just need to
      // check them.
      const TfLiteTensor* input1;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
      const int input_dims_size = input1->dims->size;
      TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
      for (int i = 0; i < input_dims_size; ++i) {
        TF_LITE_ENSURE(context, output->dims->data[i] == input1->dims->data[i]);
      }
      return kTfLiteOk;
    };
  } else if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else if (set_output_tensor_dynamic_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
      SetTensorToDynamic(output);
      return kTfLiteOk;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size
      const TfLiteTensor* input1;
      const TfLiteTensor* input2;
      if (node->inputs->size == 2) {
        input1 = GetInput(context, node, 0);
        input2 = GetInput(context, node, 1);
      } else {
        input1 = GetInput(context, node, 0);
        input2 = input1;
      }
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input1->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

std::unique_ptr<SimpleDelegate>
SimpleDelegate::DelegateWithRuntimeShapePropagation(
    const std::vector<int>& nodes, int64_t delegate_flags,
    int min_ops_per_subset) {
  return std::make_unique<SimpleDelegate>(
      nodes, delegate_flags, false /**fail_node_prepare**/,
      min_ops_per_subset /**min_ops_per_subset**/, false /**fail_node_invoke**/,
      true /**automatic_shape_propagation**/);
}

std::unique_ptr<SimpleDelegate> SimpleDelegate::DelegateWithDynamicOutput(
    const std::vector<int>& nodes) {
  // All params default except nodes & set_output_tensor_dynamic.
  return std::make_unique<SimpleDelegate>(
      nodes, kTfLiteDelegateFlagsAllowDynamicTensors,
      false /**fail_node_prepare**/, 0 /**min_ops_per_subset**/,
      false /**fail_node_invoke**/, false /**automatic_shape_propagation**/,
      true /**custom_op**/, true /**set_output_tensor_dynamic**/);
}

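// Builds a graph where each binary op takes the running value and a
// dequantized fp16 constant (value 2.0):
//   input (tensor 0) -> ADD -> ADD -> MUL -> ADD -> output (tensor 12).
// Tensors 1/4/7/10 are the fp16 constants; 2/5/8/11 hold their dequantized
// fp32 values.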
void TestFP16Delegation::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  interpreter_->AddTensors(13);
  interpreter_->SetInputs({0});
  interpreter_->SetOutputs({12});

  float16_const_ = Eigen::half(2.f);

  // TENSORS.
  TfLiteQuantizationParams quant;
  // Input.
  interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add0 output.
  interpreter_->SetTensorParametersReadOnly(
      1, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add1 output.
  interpreter_->SetTensorParametersReadOnly(
      4, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Mul0 output.
  interpreter_->SetTensorParametersReadOnly(
      7, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(8, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(9, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add2 output.
  interpreter_->SetTensorParametersReadOnly(
      10, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(11, kTfLiteFloat32, "", {1},
                                             quant);
  interpreter_->SetTensorParametersReadWrite(12, kTfLiteFloat32, "", {1},
                                             quant);

  // NODES.
  auto* add_reg = ops::builtin::Register_ADD();
  auto* mul_reg = ops::builtin::Register_MUL();
  auto* deq_reg = ops::builtin::Register_DEQUANTIZE();
  add_reg->builtin_code = kTfLiteBuiltinAdd;
  deq_reg->builtin_code = kTfLiteBuiltinDequantize;
  mul_reg->builtin_code = kTfLiteBuiltinMul;
  TfLiteAddParams* builtin_data0 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteAddParams* builtin_data1 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteMulParams* builtin_data2 =
      reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
  TfLiteAddParams* builtin_data3 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data0->activation = kTfLiteActNone;
  builtin_data1->activation = kTfLiteActNone;
  builtin_data2->activation = kTfLiteActNone;
  builtin_data3->activation = kTfLiteActNone;
  interpreter_->AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({0, 2}, {3}, nullptr, 0, builtin_data0,
                                      add_reg);
  interpreter_->AddNodeWithParameters({4}, {5}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({3, 5}, {6}, nullptr, 0, builtin_data1,
                                      add_reg);
  interpreter_->AddNodeWithParameters({7}, {8}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({6, 8}, {9}, nullptr, 0, builtin_data2,
                                      mul_reg);
  interpreter_->AddNodeWithParameters({10}, {11}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({9, 11}, {12}, nullptr, 0, builtin_data3,
                                      add_reg);
}

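// With input 3.0 the graph above computes ((3 + 2) + 2) * 2 + 2 = 16.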
void TestFP16Delegation::VerifyInvoke() {
  std::vector<float> input = {3.0f};
  std::vector<float> expected_output = {16.0f};

  const int input_tensor_idx = interpreter_->inputs()[0];
  const int output_tensor_idx = interpreter_->outputs()[0];

  memcpy(interpreter_->typed_tensor<float>(input_tensor_idx), input.data(),
         sizeof(float));
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx);
  for (int i = 0; i < 1; ++i) {
    EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
  }
}

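// Delegate that claims only ADD nodes, partitioned with
// FP16GraphPartitionHelper so the fp16 DEQUANTIZE producers are handled;
// num_delegated_subsets limits how many of the largest partitions are taken.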
TestFP16Delegation::FP16Delegate::FP16Delegate(int num_delegated_subsets,
                                               bool fail_node_prepare,
                                               bool fail_node_invoke)
    : num_delegated_subsets_(num_delegated_subsets),
      fail_delegate_node_prepare_(fail_node_prepare),
      fail_delegate_node_invoke_(fail_node_invoke) {
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* fp16_delegate = static_cast<FP16Delegate*>(delegate->data_);
    // FP16 graph partitioning.
    delegates::IsNodeSupportedFn node_supported_fn =
        [=](TfLiteContext* context, TfLiteNode* node,
            TfLiteRegistration* registration,
            std::string* unsupported_details) -> bool {
      return registration->builtin_code == kTfLiteBuiltinAdd;
    };
    delegates::FP16GraphPartitionHelper partition_helper(context,
                                                         node_supported_fn);
    TfLiteIntArray* nodes_to_separate = nullptr;
    if (partition_helper.Partition(nullptr) != kTfLiteOk) {
      nodes_to_separate = TfLiteIntArrayCreate(0);
    } else {
      std::vector<int> ops_to_replace =
          partition_helper.GetNodesOfFirstNLargestPartitions(
              fp16_delegate->num_delegated_subsets());
      nodes_to_separate = ConvertVectorToTfLiteIntArray(ops_to_replace);
    }

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, fp16_delegate->FakeFusedRegistration(), nodes_to_separate,
        delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyFromBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle buffer_handle,
         TfLiteTensor* output) -> TfLiteStatus { return kTfLiteOk; };
  delegate_.FreeBufferHandle = nullptr;
  delegate_.CopyToBufferHandle = nullptr;
  // Store a type-punned pointer to this FP16Delegate instance.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = kTfLiteDelegateFlagsNone;
}

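// Kernel used for the delegated FP16 partitions: Invoke sums the fp32 inputs,
// counting each fp16 constant input as 2.0, and Prepare resizes the output to
// the first input's shape (or fails, when fail_node_prepare is set).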
TfLiteRegistration TestFP16Delegation::FP16Delegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fp16_add_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      float output = 0;
      for (int i = 0; i < node->inputs->size; ++i) {
        const TfLiteTensor* input_tensor = GetInput(context, node, i);
        if (input_tensor->type == kTfLiteFloat32) {
          output += input_tensor->data.f[0];
        } else {
          // All constants are 2.
          output += 2;
        }
      }
      TfLiteTensor* out = GetOutput(context, node, 0);
      out->data.f[0] = output;
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size
      const TfLiteTensor* input = GetInput(context, node, 0);
      TfLiteTensor* output = GetOutput(context, node, 0);
      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

}  // namespace test_utils
}  // namespace delegates
}  // namespace tflite