xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/lsh_projection.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // LSH Projection projects an input to a bit vector via locality sensitive
17 // hashing.
18 //
19 // Options:
20 //   Sparse:
21 //     Computed bit vector is considered to be sparse.
22 //     Each output element is an int32 made up by multiple bits computed from
23 // hash functions.
24 //
25 //   Dense:
26 //     Computed bit vector is considered to be dense. Each output element is
27 // either 0 or 1 that represents a bit.
28 //
29 // Input:
30 //   Tensor[0]: Hash functions. Dim.size == 2, DataType: Float.
31 //              Tensor[0].Dim[0]: Num of hash functions. Must be at least 1.
32 //              Tensor[0].Dim[1]: Num of projected output bits generated by
33 //                                each hash function.
34 //   In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32.
35 //
36 //   Tensor[1]: Input. Dim.size >= 1, No restriction on DataType.
//   Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float.
//              If not set, each element of input is considered to have a
//              weight of 1.0. If set, Tensor[1].Dim[0] == Tensor[2].Dim[0].
40 //
41 // Output:
42 //   Sparse:
43 //     Output.Dim == { Tensor[0].Dim[0] }
44 //     A tensor of int32 that represents hash signatures,
45 //
46 //     NOTE: To avoid collisions across hash functions, an offset value of
47 //     k * (1 << Tensor[0].Dim[1]) will be added to each signature,
48 //     k is the index of the hash function.
49 //   Dense:
50 //     Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
51 //     A flattened tensor represents projected bit vectors.
52 
53 #include <stddef.h>
54 #include <stdint.h>
55 
56 #include <cstring>
57 #include <memory>
58 
59 #include "tensorflow/lite/c/builtin_op_data.h"
60 #include "tensorflow/lite/c/common.h"
61 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
62 #include "tensorflow/lite/kernels/kernel_util.h"
63 #include "utils/hash/farmhash.h"
64 
65 namespace tflite {
66 namespace ops {
67 namespace builtin {
68 namespace lsh_projection {
69 
Resize(TfLiteContext * context,TfLiteNode * node)70 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
71   auto* params =
72       reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
73   TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
74   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
75 
76   const TfLiteTensor* hash;
77   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
78   TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
79   // Support up to 32 bits.
80   TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
81 
82   const TfLiteTensor* input;
83   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
84   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
85   TF_LITE_ENSURE(context, SizeOfDimension(input, 0) >= 1);
86 
87   if (NumInputs(node) == 3) {
88     const TfLiteTensor* weight;
89     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight));
90     TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
91     TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
92                       SizeOfDimension(input, 0));
93   }
94 
95   TfLiteTensor* output;
96   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
97   TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
98   switch (params->type) {
99     case kTfLiteLshProjectionSparse:
100       outputSize->data[0] = SizeOfDimension(hash, 0);
101       break;
102     case kTfLiteLshProjectionDense:
103       outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1);
104       break;
105     default:
106       return kTfLiteError;
107   }
108   return context->ResizeTensor(context, output, outputSize);
109 }
110 
111 // Compute sign bit of dot product of hash(seed, input) and weight.
112 // NOTE: use float as seed, and convert it to double as a temporary solution
113 //       to match the trained model. This is going to be changed once the new
114 //       model is trained in an optimized method.
115 //
RunningSignBit(const TfLiteTensor * input,const TfLiteTensor * weight,float seed)116 int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight,
117                    float seed) {
118   double score = 0.0;
119   int input_item_bytes = input->bytes / SizeOfDimension(input, 0);
120   char* input_ptr = input->data.raw;
121 
122   const size_t seed_size = sizeof(float);
123   const size_t key_bytes = sizeof(float) + input_item_bytes;
124   std::unique_ptr<char[]> key(new char[key_bytes]);
125 
126   const float* weight_ptr = GetTensorData<float>(weight);
127 
128   for (int i = 0; i < SizeOfDimension(input, 0); ++i) {
129     // Create running hash id and value for current dimension.
130     memcpy(key.get(), &seed, seed_size);
131     memcpy(key.get() + seed_size, input_ptr, input_item_bytes);
132 
133     int64_t hash_signature = farmhash::Fingerprint64(key.get(), key_bytes);
134     double running_value = static_cast<double>(hash_signature);
135     input_ptr += input_item_bytes;
136     if (weight_ptr == nullptr) {
137       score += running_value;
138     } else {
139       score += weight_ptr[i] * running_value;
140     }
141   }
142 
143   return (score > 0) ? 1 : 0;
144 }
145 
SparseLshProjection(const TfLiteTensor * hash,const TfLiteTensor * input,const TfLiteTensor * weight,int32_t * out_buf)146 void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
147                          const TfLiteTensor* weight, int32_t* out_buf) {
148   int num_hash = SizeOfDimension(hash, 0);
149   int num_bits = SizeOfDimension(hash, 1);
150   for (int i = 0; i < num_hash; i++) {
151     int32_t hash_signature = 0;
152     for (int j = 0; j < num_bits; j++) {
153       float seed = GetTensorData<float>(hash)[i * num_bits + j];
154       int bit = RunningSignBit(input, weight, seed);
155       hash_signature = (hash_signature << 1) | bit;
156     }
157     *out_buf++ = hash_signature + i * (1 << num_bits);
158   }
159 }
160 
DenseLshProjection(const TfLiteTensor * hash,const TfLiteTensor * input,const TfLiteTensor * weight,int32_t * out_buf)161 void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
162                         const TfLiteTensor* weight, int32_t* out_buf) {
163   int num_hash = SizeOfDimension(hash, 0);
164   int num_bits = SizeOfDimension(hash, 1);
165   for (int i = 0; i < num_hash; i++) {
166     for (int j = 0; j < num_bits; j++) {
167       float seed = GetTensorData<float>(hash)[i * num_bits + j];
168       int bit = RunningSignBit(input, weight, seed);
169       *out_buf++ = bit;
170     }
171   }
172 }
173 
Eval(TfLiteContext * context,TfLiteNode * node)174 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
175   auto* params =
176       reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
177 
178   TfLiteTensor* out_tensor;
179   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
180   int32_t* out_buf = out_tensor->data.i32;
181   const TfLiteTensor* hash;
182   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
183   const TfLiteTensor* input;
184   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
185   const TfLiteTensor* weight =
186       NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
187 
188   switch (params->type) {
189     case kTfLiteLshProjectionDense:
190       DenseLshProjection(hash, input, weight, out_buf);
191       break;
192     case kTfLiteLshProjectionSparse:
193       SparseLshProjection(hash, input, weight, out_buf);
194       break;
195     default:
196       return kTfLiteError;
197   }
198 
199   return kTfLiteOk;
200 }
201 }  // namespace lsh_projection
202 
Register_LSH_PROJECTION()203 TfLiteRegistration* Register_LSH_PROJECTION() {
204   static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize,
205                                  lsh_projection::Eval};
206   return &r;
207 }
208 
209 }  // namespace builtin
210 }  // namespace ops
211 }  // namespace tflite
212