xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/hashtable_lookup.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Op that looks up items from hashtable.
17 //
18 // Input:
19 //     Tensor[0]: Hash key to lookup, dim.size == 1, int32
20 //     Tensor[1]: Key of hashtable, dim.size == 1, int32
21 //                *MUST* be sorted in ascending order.
22 //     Tensor[2]: Value of hashtable, dim.size >= 1
23 //                Tensor[1].Dim[0] == Tensor[2].Dim[0]
24 //
25 // Output:
26 //   Output[0].dim[0] == Tensor[0].dim[0], num of lookups
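//   Each item in output is a raw byte copy of the row of Tensor[2] whose key
//   matches the corresponding item in Tensor[0].
//   When the key does not exist in the hashtable, the returned bytes are all
//   0s (for string values, an empty string is returned instead).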
29 //
30 //   Output[1].dim = { Tensor[0].dim[0] }, num of lookups
31 //   Each item indicates whether the corresponding lookup has a returned value.
32 //   0 for missing key, 1 for found key.
33 
34 #include <stdint.h>
35 
36 #include <cstdlib>
37 #include <cstring>
38 
39 #include "tensorflow/lite/c/common.h"
40 #include "tensorflow/lite/kernels/internal/compatibility.h"
41 #include "tensorflow/lite/kernels/kernel_util.h"
42 #include "tensorflow/lite/string_util.h"
43 
44 namespace tflite {
45 namespace ops {
46 namespace builtin {
47 
48 namespace {
49 
// Three-way comparator for bsearch() over int32 keys: returns a negative,
// zero, or positive value when *a is less than, equal to, or greater than *b.
// Despite its name this is an ordering function, not a "greater-than"
// predicate.
//
// The operands are compared rather than subtracted: `x - y` overflows
// (undefined behavior, and a wrong sign in practice) when the values are far
// apart, e.g. -2 and INT32_MAX, which would steer the binary search the wrong
// way and report existing keys as missing.
int greater(const void* a, const void* b) {
  const int32_t lhs = *static_cast<const int32_t*>(a);
  const int32_t rhs = *static_cast<const int32_t*>(b);
  return (lhs > rhs) - (lhs < rhs);
}
53 
Prepare(TfLiteContext * context,TfLiteNode * node)54 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
55   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
56   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
57 
58   const TfLiteTensor* lookup;
59   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
60   TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
61   TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
62 
63   const TfLiteTensor* key;
64   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
65   TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
66   TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
67 
68   const TfLiteTensor* value;
69   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
70   TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
71   TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
72                     SizeOfDimension(value, 0));
73   if (value->type == kTfLiteString) {
74     TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
75   }
76 
77   TfLiteTensor* hits;
78   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
79   TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
80   TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
81   hitSize->data[0] = SizeOfDimension(lookup, 0);
82 
83   TfLiteTensor* output;
84   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
85   TF_LITE_ENSURE_EQ(context, value->type, output->type);
86 
87   TfLiteStatus status = kTfLiteOk;
88   if (output->type != kTfLiteString) {
89     TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
90     outputSize->data[0] = SizeOfDimension(lookup, 0);
91     for (int i = 1; i < NumDimensions(value); i++) {
92       outputSize->data[i] = SizeOfDimension(value, i);
93     }
94     status = context->ResizeTensor(context, output, outputSize);
95   }
96   if (context->ResizeTensor(context, hits, hitSize) != kTfLiteOk) {
97     status = kTfLiteError;
98   }
99   return status;
100 }
101 
Eval(TfLiteContext * context,TfLiteNode * node)102 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
103   TfLiteTensor* output;
104   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
105   TfLiteTensor* hits;
106   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
107   const TfLiteTensor* lookup;
108   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
109   const TfLiteTensor* key;
110   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
111   const TfLiteTensor* value;
112   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
113 
114   const int num_rows = SizeOfDimension(value, 0);
115   TF_LITE_ENSURE(context, num_rows != 0);
116   const int row_bytes = value->bytes / num_rows;
117   void* pointer = nullptr;
118   DynamicBuffer buf;
119 
120   for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
121     int idx = -1;
122     pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows,
123                       sizeof(int32_t), greater);
124     if (pointer != nullptr) {
125       idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) /
126             sizeof(int32_t);
127     }
128 
129     if (idx >= num_rows || idx < 0) {
130       if (output->type == kTfLiteString) {
131         buf.AddString(nullptr, 0);
132       } else {
133         memset(output->data.raw + i * row_bytes, 0, row_bytes);
134       }
135       hits->data.uint8[i] = 0;
136     } else {
137       if (output->type == kTfLiteString) {
138         buf.AddString(GetString(value, idx));
139       } else {
140         memcpy(output->data.raw + i * row_bytes,
141                value->data.raw + idx * row_bytes, row_bytes);
142       }
143       hits->data.uint8[i] = 1;
144     }
145   }
146   if (output->type == kTfLiteString) {
147     buf.WriteToTensorAsVector(output);
148   }
149 
150   return kTfLiteOk;
151 }
152 }  // namespace
153 
Register_HASHTABLE_LOOKUP()154 TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
155   static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
156   return &r;
157 }
158 
159 }  // namespace builtin
160 }  // namespace ops
161 }  // namespace tflite
162