xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/blacklist_base.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/blacklist_base.h"
18 
19 #include <cstdint>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/lite/context.h"
23 #include "tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 namespace libtextclassifier3 {
28 namespace blacklist {
29 
// Position within `node->outputs` of the tensor that receives the
// per-category values written by this op.
static constexpr int kOutputCategories = 0;
31 
Free(TfLiteContext * context,void * buffer)32 void Free(TfLiteContext* context, void* buffer) {
33   delete reinterpret_cast<BlacklistOpBase*>(buffer);
34 }
35 
Resize(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
37   auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
38 
39   TfLiteIntArray* input_dims = op->GetInputShape(context, node);
40   TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims->size + 1);
41   for (int i = 0; i < input_dims->size; i++) {
42     output_dims->data[i] = input_dims->data[i];
43   }
44   output_dims->data[input_dims->size] = op->categories();
45   return context->ResizeTensor(
46       context, &context->tensors[node->outputs->data[kOutputCategories]],
47       output_dims);
48 }
49 
Eval(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
51   auto* op = reinterpret_cast<BlacklistOpBase*>(node->user_data);
52 
53   TfLiteTensor* output_categories =
54       &context->tensors[node->outputs->data[kOutputCategories]];
55 
56   TfLiteIntArray* input_dims = op->GetInputShape(context, node);
57   int input_size = 1;
58   for (int i = 0; i < input_dims->size; i++) {
59     input_size *= input_dims->data[i];
60   }
61   const int n_categories = op->categories();
62 
63   TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
64   if (output_categories->type == kTfLiteFloat32) {
65     for (int i = 0; i < input_size; i++) {
66       absl::flat_hash_set<int> categories = op->GetCategories(i);
67       if (categories.empty()) {
68         for (int j = 0; j < n_categories; j++) {
69           output_categories->data.f[i * n_categories + j] =
70               (j < op->negative_categories()) ? 1.0 : 0.0;
71         }
72       } else {
73         for (int j = 0; j < n_categories; j++) {
74           output_categories->data.f[i * n_categories + j] =
75               (categories.find(j) != categories.end()) ? 1.0 : 0.0;
76         }
77       }
78     }
79   } else if (output_categories->type == kTfLiteUInt8) {
80     const uint8_t one =
81         ::seq_flow_lite::PodQuantize(1.0, output_categories->params.zero_point,
82                                      1.0 / output_categories->params.scale);
83     const uint8_t zero =
84         ::seq_flow_lite::PodQuantize(0.0, output_categories->params.zero_point,
85                                      1.0 / output_categories->params.scale);
86     for (int i = 0; i < input_size; i++) {
87       absl::flat_hash_set<int> categories = op->GetCategories(i);
88       if (categories.empty()) {
89         for (int j = 0; j < n_categories; j++) {
90           output_categories->data.uint8[i * n_categories + j] =
91               (j < op->negative_categories()) ? one : zero;
92         }
93       } else {
94         for (int j = 0; j < n_categories; j++) {
95           output_categories->data.uint8[i * n_categories + j] =
96               (categories.find(j) != categories.end()) ? one : zero;
97         }
98       }
99     }
100   }
101   op->FinalizeInput();
102   return kTfLiteOk;
103 }
104 
105 }  // namespace blacklist
106 }  // namespace libtextclassifier3
107 }  // namespace custom
108 }  // namespace ops
109 }  // namespace tflite
110