/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/delegate_test_util.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace delegates {
namespace test_utils {

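// Returns a TfLiteRegistration for a custom "my_add" op used by these tests:
// Prepare() checks that both inputs share the same shape and resizes the
// output to match; Invoke() writes the element-wise sum of the two inputs.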
TfLiteRegistration AddOpRegistration() {
  TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};

  reg.custom_name = "my_add";
  reg.builtin_code = tflite::BuiltinOperator_CUSTOM;

  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* input1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
    const TfLiteTensor* input2;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

    // Verify that the two inputs have the same shape.
    TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
    for (int i = 0; i < input1->dims->size; ++i) {
      TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
    }

    // Set output shape to match input shape.
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(
        context, output, TfLiteIntArrayCopy(input1->dims)));
    return kTfLiteOk;
  };

  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    const TfLiteTensor* a0;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
    TF_LITE_ENSURE(context, a0);
    TF_LITE_ENSURE(context, a0->data.f);
    const TfLiteTensor* a1;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
    TF_LITE_ENSURE(context, a1);
    TF_LITE_ENSURE(context, a1->data.f);
    TfLiteTensor* out;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
    TF_LITE_ENSURE(context, out);
    TF_LITE_ENSURE(context, out->data.f);
    // Set output data to element-wise sum of input data.
    int num = a0->dims->data[0];
    for (int i = 0; i < num; i++) {
      out->data.f[i] = a0->data.f[i] + a1->data.f[i];
    }
    return kTfLiteOk;
  };
  return reg;
}

void TestDelegation::SetUpSubgraph(Subgraph* subgraph) {
  subgraph->AddTensors(5);
  subgraph->SetInputs({0, 1});
  subgraph->SetOutputs({3, 4});
  std::vector<int> dims({3});
  TfLiteQuantization quant{kTfLiteNoQuantization, nullptr};
  subgraph->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  subgraph->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", dims.size(),
                                         dims.data(), quant, false);
  TfLiteRegistration reg = AddOpRegistration();
  int node_index_ignored;
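  // Three "my_add" nodes: t2 = t0 + t0, t3 = t1 + t1, t4 = t2 + t1, with
  // graph inputs {t0, t1} and outputs {t3, t4}.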
  subgraph->AddNodeWithParameters({0, 0}, {2}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({1, 1}, {3}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
  subgraph->AddNodeWithParameters({2, 1}, {4}, {}, nullptr, 0, nullptr, &reg,
                                  &node_index_ignored);
}

void TestDelegation::AddSubgraphs(int subgraphs_to_add,
                                  int* first_new_subgraph_index) {
  interpreter_->AddSubgraphs(subgraphs_to_add, first_new_subgraph_index);
}

void TestDelegate::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  SetUpSubgraph(&interpreter_->primary_subgraph());
}

void TestDelegate::TearDown() {
  // The interpreter relies on the delegate to free its resources properly, so
  // the delegate must outlive the interpreter.
  interpreter_.reset();
  delegate_.reset();
  delegate2_.reset();
}

void TestTwoDelegates::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  SetUpSubgraph(&interpreter_->primary_subgraph());
}

void TestTwoDelegates::TearDown() {
  // The interpreter relies on the delegate to free its resources properly, so
  // the delegate must outlive the interpreter.
  interpreter_.reset();
  delegate_.reset();
  delegate2_.reset();
}

SimpleDelegate::SimpleDelegate(const std::vector<int>& nodes,
                               int64_t delegate_flags, bool fail_node_prepare,
                               int min_ops_per_subset, bool fail_node_invoke,
                               bool automatic_shape_propagation, bool custom_op,
                               bool set_output_tensor_dynamic)
    : nodes_(nodes),
      fail_delegate_node_prepare_(fail_node_prepare),
      min_ops_per_subset_(min_ops_per_subset),
      fail_delegate_node_invoke_(fail_node_invoke),
      automatic_shape_propagation_(automatic_shape_propagation),
      custom_op_(custom_op),
      set_output_tensor_dynamic_(set_output_tensor_dynamic) {
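  // Prepare() claims the ops listed in `nodes_`, optionally drops partitions
  // smaller than `min_ops_per_subset_`, and replaces the remaining nodes with
  // the FakeFusedRegistration() kernel.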
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* simple = static_cast<SimpleDelegate*>(delegate->data_);
    TfLiteIntArray* nodes_to_separate =
        TfLiteIntArrayCreate(simple->nodes_.size());
    // Record the node indices we want to delegate in the TfLiteIntArray.
    int index = 0;
    for (auto node_index : simple->nodes_) {
      nodes_to_separate->data[index++] = node_index;
      // Make sure the node was added.
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (simple->custom_op_) {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
        TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
      } else {
        TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
      }
    }
    // Check that all nodes are available.
    TfLiteIntArray* execution_plan;
    TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
    for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
      int node_index = execution_plan->data[exec_index];
      TfLiteNode* node;
      TfLiteRegistration* reg;
      context->GetNodeAndRegistration(context, node_index, &node, &reg);
      if (exec_index == node_index) {
        // Check op details only if it wasn't delegated already.
        if (simple->custom_op_) {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
          TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
        } else {
          TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_ADD);
        }
      }
    }

    // Get preview of delegate partitioning from the context.
    TfLiteDelegateParams* params_array;
    int num_partitions;
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    if (simple->min_ops_per_subset() > 0) {
      // Build a new vector of ops from subsets with at least the minimum
      // size.
      std::vector<int> allowed_ops;
      for (int idx = 0; idx < num_partitions; ++idx) {
        const auto* nodes_in_subset = params_array[idx].nodes_to_replace;
        if (nodes_in_subset->size < simple->min_ops_per_subset()) continue;
        allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data,
                           nodes_in_subset->data + nodes_in_subset->size);
      }

      // Free existing nodes_to_separate & initialize a new array with
      // allowed_ops.
      TfLiteIntArrayFree(nodes_to_separate);
      nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size());
      memcpy(nodes_to_separate->data, allowed_ops.data(),
             sizeof(int) * nodes_to_separate->size);
    }

    // Another call to PreviewDelegatePartitioning should be okay, since
    // partitioning memory is managed by context.
    TFLITE_CHECK_EQ(
        context->PreviewDelegatePartitioning(context, nodes_to_separate,
                                             &params_array, &num_partitions),
        kTfLiteOk);

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, simple->FakeFusedRegistration(), nodes_to_separate, delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyToBufferHandle = [](TfLiteContext* context,
                                    TfLiteDelegate* delegate,
                                    TfLiteBufferHandle buffer_handle,
                                    TfLiteTensor* tensor) -> TfLiteStatus {
    // TODO(b/156586986): Implement tests to test buffer copying logic.
    return kTfLiteOk;
  };
  delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
                                      TfLiteDelegate* delegate,
                                      TfLiteBufferHandle buffer_handle,
                                      TfLiteTensor* output) -> TfLiteStatus {
    TFLITE_CHECK_GE(buffer_handle, -1);
    TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle);
    const float floats[] = {6., 6., 6.};
    int num = output->dims->data[0];
    for (int i = 0; i < num; i++) {
      output->data.f[i] = floats[i];
    }
    return kTfLiteOk;
  };

  delegate_.FreeBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
  // Store a type-punned pointer to this SimpleDelegate instance.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = delegate_flags;
}

TfLiteRegistration SimpleDelegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fused_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      // Compute output data as elementwise sum of the two input arguments:
      //   func(x, y) = x + y
      // or for a single argument compute 2 * x:
      //   func(x) = x + x
      const TfLiteTensor* a0;
      const TfLiteTensor* a1;
      if (node->inputs->size == 2) {
        a0 = GetInput(context, node, 0);
        a1 = GetInput(context, node, 1);
      } else {
        a0 = GetInput(context, node, 0);
        a1 = a0;
      }
      TfLiteTensor* out;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
      int num = 1;
      for (int i = 0; i < a0->dims->size; ++i) {
        num *= a0->dims->data[i];
      }
      for (int i = 0; i < num; i++) {
        out->data.f[i] = a0->data.f[i] + a1->data.f[i];
      }
      if (out->buffer_handle != kTfLiteNullBufferHandle) {
        // Make the data stale so that CopyFromBufferHandle can be invoked.
        out->data_is_stale = true;
      }
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (automatic_shape_propagation_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Shapes should already be propagated by the runtime; just need to
      // check.
      const TfLiteTensor* input1;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
      const int input_dims_size = input1->dims->size;
      TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
      for (int i = 0; i < input_dims_size; ++i) {
        TF_LITE_ENSURE(context, output->dims->data[i] == input1->dims->data[i]);
      }
      return kTfLiteOk;
    };
  } else if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else if (set_output_tensor_dynamic_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
      SetTensorToDynamic(output);
      return kTfLiteOk;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size.
      const TfLiteTensor* input1;
      const TfLiteTensor* input2;
      if (node->inputs->size == 2) {
        input1 = GetInput(context, node, 0);
        input2 = GetInput(context, node, 1);
      } else {
        input1 = GetInput(context, node, 0);
        input2 = input1;
      }
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input1->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

std::unique_ptr<SimpleDelegate>
SimpleDelegate::DelegateWithRuntimeShapePropagation(
    const std::vector<int>& nodes, int64_t delegate_flags,
    int min_ops_per_subset) {
  return std::make_unique<SimpleDelegate>(
      nodes, delegate_flags, false /**fail_node_prepare**/,
      min_ops_per_subset /**min_ops_per_subset**/, false /**fail_node_invoke**/,
      true /**automatic_shape_propagation**/);
}

std::unique_ptr<SimpleDelegate> SimpleDelegate::DelegateWithDynamicOutput(
    const std::vector<int>& nodes) {
  // All params default except nodes & set_output_tensor_dynamic.
  return std::make_unique<SimpleDelegate>(
      nodes, kTfLiteDelegateFlagsAllowDynamicTensors,
      false /**fail_node_prepare**/, 0 /**min_ops_per_subset**/,
      false /**fail_node_invoke**/, false /**automatic_shape_propagation**/,
      true /**custom_op**/, true /**set_output_tensor_dynamic**/);
}

void TestFP16Delegation::SetUp() {
  interpreter_ = TestDelegation::NewInterpreterWithDefaultDelegates();
  interpreter_->AddTensors(13);
  interpreter_->SetInputs({0});
  interpreter_->SetOutputs({12});

  float16_const_ = Eigen::half(2.f);

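  // Graph: t0 -> Add0 -> Add1 -> Mul0 -> Add2 -> t12, where the second
  // operand of each binary op is the Dequantize of an fp16 constant (2.0).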
  // TENSORS.
  TfLiteQuantizationParams quant;
  // Input.
  interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add0 output.
  interpreter_->SetTensorParametersReadOnly(
      1, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add1 output.
  interpreter_->SetTensorParametersReadOnly(
      4, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Mul0 output.
  interpreter_->SetTensorParametersReadOnly(
      7, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(8, kTfLiteFloat32, "", {1}, quant);
  interpreter_->SetTensorParametersReadWrite(9, kTfLiteFloat32, "", {1}, quant);
  // fp16 constant, dequantize output, Add2 output.
  interpreter_->SetTensorParametersReadOnly(
      10, kTfLiteFloat16, "", {1}, quant,
      reinterpret_cast<const char*>(&float16_const_), sizeof(TfLiteFloat16));
  interpreter_->SetTensorParametersReadWrite(11, kTfLiteFloat32, "", {1},
                                             quant);
  interpreter_->SetTensorParametersReadWrite(12, kTfLiteFloat32, "", {1},
                                             quant);

  // NODES.
  auto* add_reg = ops::builtin::Register_ADD();
  auto* mul_reg = ops::builtin::Register_MUL();
  auto* deq_reg = ops::builtin::Register_DEQUANTIZE();
  add_reg->builtin_code = kTfLiteBuiltinAdd;
  deq_reg->builtin_code = kTfLiteBuiltinDequantize;
  mul_reg->builtin_code = kTfLiteBuiltinMul;
  TfLiteAddParams* builtin_data0 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteAddParams* builtin_data1 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  TfLiteMulParams* builtin_data2 =
      reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
  TfLiteAddParams* builtin_data3 =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data0->activation = kTfLiteActNone;
  builtin_data1->activation = kTfLiteActNone;
  builtin_data2->activation = kTfLiteActNone;
  builtin_data3->activation = kTfLiteActNone;
  interpreter_->AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({0, 2}, {3}, nullptr, 0, builtin_data0,
                                      add_reg);
  interpreter_->AddNodeWithParameters({4}, {5}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({3, 5}, {6}, nullptr, 0, builtin_data1,
                                      add_reg);
  interpreter_->AddNodeWithParameters({7}, {8}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({6, 8}, {9}, nullptr, 0, builtin_data2,
                                      mul_reg);
  interpreter_->AddNodeWithParameters({10}, {11}, nullptr, 0, nullptr, deq_reg);
  interpreter_->AddNodeWithParameters({9, 11}, {12}, nullptr, 0, builtin_data3,
                                      add_reg);
}

void TestFP16Delegation::VerifyInvoke() {
  std::vector<float> input = {3.0f};
  std::vector<float> expected_output = {16.0f};
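  // With input 3 and every fp16 constant equal to 2:
  //   ((3 + 2) + 2) * 2 + 2 = 16.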

  const int input_tensor_idx = interpreter_->inputs()[0];
  const int output_tensor_idx = interpreter_->outputs()[0];

  memcpy(interpreter_->typed_tensor<float>(input_tensor_idx), input.data(),
         sizeof(float));
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output_tensor = interpreter_->tensor(output_tensor_idx);
  for (int i = 0; i < 1; ++i) {
    EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
  }
}

TestFP16Delegation::FP16Delegate::FP16Delegate(int num_delegated_subsets,
                                               bool fail_node_prepare,
                                               bool fail_node_invoke)
    : num_delegated_subsets_(num_delegated_subsets),
      fail_delegate_node_prepare_(fail_node_prepare),
      fail_delegate_node_invoke_(fail_node_invoke) {
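  // Prepare() only claims ADD nodes (including those fed by fp16 Dequantize
  // ops), keeping at most `num_delegated_subsets_` of the largest partitions
  // found by FP16GraphPartitionHelper.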
  delegate_.Prepare = [](TfLiteContext* context,
                         TfLiteDelegate* delegate) -> TfLiteStatus {
    auto* fp16_delegate = static_cast<FP16Delegate*>(delegate->data_);
    // FP16 graph partitioning.
    delegates::IsNodeSupportedFn node_supported_fn =
        [=](TfLiteContext* context, TfLiteNode* node,
            TfLiteRegistration* registration,
            std::string* unsupported_details) -> bool {
      return registration->builtin_code == kTfLiteBuiltinAdd;
    };
    delegates::FP16GraphPartitionHelper partition_helper(context,
                                                         node_supported_fn);
    TfLiteIntArray* nodes_to_separate = nullptr;
    if (partition_helper.Partition(nullptr) != kTfLiteOk) {
      nodes_to_separate = TfLiteIntArrayCreate(0);
    } else {
      std::vector<int> ops_to_replace =
          partition_helper.GetNodesOfFirstNLargestPartitions(
              fp16_delegate->num_delegated_subsets());
      nodes_to_separate = ConvertVectorToTfLiteIntArray(ops_to_replace);
    }

    context->ReplaceNodeSubsetsWithDelegateKernels(
        context, fp16_delegate->FakeFusedRegistration(), nodes_to_separate,
        delegate);
    TfLiteIntArrayFree(nodes_to_separate);
    return kTfLiteOk;
  };
  delegate_.CopyFromBufferHandle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle buffer_handle,
         TfLiteTensor* output) -> TfLiteStatus { return kTfLiteOk; };
  delegate_.FreeBufferHandle = nullptr;
  delegate_.CopyToBufferHandle = nullptr;
  // Store a type-punned pointer to this FP16Delegate instance.
  delegate_.data_ = static_cast<void*>(this);
  delegate_.flags = kTfLiteDelegateFlagsNone;
}

TfLiteRegistration TestFP16Delegation::FP16Delegate::FakeFusedRegistration() {
  TfLiteRegistration reg = {nullptr};
  reg.custom_name = "fake_fp16_add_op";

  // Different flavors of the delegate kernel's Invoke(), dependent on
  // testing parameters.
  if (fail_delegate_node_invoke_) {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      return kTfLiteError;
    };
  } else {
    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
      float output = 0;
      for (int i = 0; i < node->inputs->size; ++i) {
        const TfLiteTensor* input_tensor = GetInput(context, node, i);
        if (input_tensor->type == kTfLiteFloat32) {
          output += input_tensor->data.f[0];
        } else {
          // All constants are 2.
          output += 2;
        }
      }
      TfLiteTensor* out = GetOutput(context, node, 0);
      out->data.f[0] = output;
      return kTfLiteOk;
    };
  }

  // Different flavors of the delegate kernel's Prepare(), dependent on
  // testing parameters.
  if (fail_delegate_node_prepare_) {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteError;
    };
  } else {
    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
      // Set output size to input size.
      const TfLiteTensor* input = GetInput(context, node, 0);
      TfLiteTensor* output = GetOutput(context, node, 0);
      TF_LITE_ENSURE_STATUS(context->ResizeTensor(
          context, output, TfLiteIntArrayCopy(input->dims)));
      return kTfLiteOk;
    };
  }

  return reg;
}

}  // namespace test_utils
}  // namespace delegates
}  // namespace tflite