1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/blacklist_base.h"
18
19 #include <cstdint>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/lite/context.h"
23 #include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 namespace libtextclassifier3 {
28 namespace blacklist {
29
// Index (within node->outputs) of the tensor that receives the per-category
// indicator values.
static const int kOutputCategories = 0;
31
Free(TfLiteContext * context,void * buffer)32 void Free(TfLiteContext* context, void* buffer) {
33 delete reinterpret_cast<BlacklistOpBase*>(buffer);
34 }
35
Resize(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
37 auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
38
39 TfLiteIntArray* input_dims = op->GetInputShape(context, node);
40 TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims->size + 1);
41 for (int i = 0; i < input_dims->size; i++) {
42 output_dims->data[i] = input_dims->data[i];
43 }
44 output_dims->data[input_dims->size] = op->categories();
45 return context->ResizeTensor(
46 context, &context->tensors[node->outputs->data[kOutputCategories]],
47 output_dims);
48 }
49
Eval(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
51 auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
52
53 TfLiteTensor* output_categories =
54 &context->tensors[node->outputs->data[kOutputCategories]];
55
56 TfLiteIntArray* input_dims = op->GetInputShape(context, node);
57 int input_size = 1;
58 for (int i = 0; i < input_dims->size; i++) {
59 input_size *= input_dims->data[i];
60 }
61 const int n_categories = op->categories();
62
63 TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
64 if (output_categories->type == kTfLiteFloat32) {
65 for (int i = 0; i < input_size; i++) {
66 absl::flat_hash_set<int> categories = op->GetCategories(i);
67 if (categories.empty()) {
68 for (int j = 0; j < n_categories; j++) {
69 output_categories->data.f[i * n_categories + j] =
70 (j < op->negative_categories()) ? 1.0 : 0.0;
71 }
72 } else {
73 for (int j = 0; j < n_categories; j++) {
74 output_categories->data.f[i * n_categories + j] =
75 (categories.find(j) != categories.end()) ? 1.0 : 0.0;
76 }
77 }
78 }
79 } else if (output_categories->type == kTfLiteUInt8) {
80 const uint8_t one =
81 ::seq_flow_lite::PodQuantize(1.0, output_categories->params.zero_point,
82 1.0 / output_categories->params.scale);
83 const uint8_t zero =
84 ::seq_flow_lite::PodQuantize(0.0, output_categories->params.zero_point,
85 1.0 / output_categories->params.scale);
86 for (int i = 0; i < input_size; i++) {
87 absl::flat_hash_set<int> categories = op->GetCategories(i);
88 if (categories.empty()) {
89 for (int j = 0; j < n_categories; j++) {
90 output_categories->data.uint8[i * n_categories + j] =
91 (j < op->negative_categories()) ? one : zero;
92 }
93 } else {
94 for (int j = 0; j < n_categories; j++) {
95 output_categories->data.uint8[i * n_categories + j] =
96 (categories.find(j) != categories.end()) ? one : zero;
97 }
98 }
99 }
100 }
101 op->FinalizeInput();
102 return kTfLiteOk;
103 }
104
105 } // namespace blacklist
106 } // namespace libtextclassifier3
107 } // namespace custom
108 } // namespace ops
109 } // namespace tflite
110