xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/embedding_lookup.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Op that looks up rows from a matrix.
//
// Input:
//     Tensor[0]: Row numbers to look up, dim.size == 1, int32.
//     Tensor[1]: 2-dimensional matrix of multi-dimensional items,
//                dim.size >= 2; float32, uint8 or int8 (other types are
//                rejected at evaluation time).
//                First dimension is row, second dimension is column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], number of lookups
//   Output.dim[1] == Tensor[1].dim[1], number of items per row
//   Each item in the output is a raw byte copy of the corresponding item in
//   the input, or a dequantized value in the case of a uint8/int8 input with
//   a float32 output.
//   When an index is out of bounds, the op fails with an error.
//
31 
32 #include <stdint.h>
33 
34 #include <cstring>
35 
36 #include "tensorflow/lite/c/common.h"
37 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
38 #include "tensorflow/lite/kernels/kernel_util.h"
39 
40 namespace tflite {
41 namespace ops {
42 namespace builtin {
43 namespace embedding_lookup {
44 
Prepare(TfLiteContext * context,TfLiteNode * node)45 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
46   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
47   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
48 
49   const TfLiteTensor* lookup;
50   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
51   TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
52   TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
53 
54   const TfLiteTensor* value;
55   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
56   TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
57 
58   TfLiteTensor* output;
59   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
60   TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
61 
62   outputSize->data[0] = SizeOfDimension(lookup, 0);
63   outputSize->data[1] = SizeOfDimension(value, 1);
64   for (int i = 2; i < NumDimensions(value); i++) {
65     outputSize->data[i] = SizeOfDimension(value, i);
66   }
67   return context->ResizeTensor(context, output, outputSize);
68 }
69 
EvalSimple(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)70 TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
71                         const TfLiteTensor* lookup, const TfLiteTensor* value,
72                         TfLiteTensor* output) {
73   const int row_size = SizeOfDimension(value, 0);
74   if (row_size == 0) {
75     // Propagate empty tensor if input is empty
76     return kTfLiteOk;
77   }
78   const int row_bytes = value->bytes / row_size;
79 
80   char* output_raw = GetTensorData<char>(output);
81   const char* value_raw = GetTensorData<char>(value);
82   const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
83   for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
84     int idx = lookup_data[i];
85     if (idx >= row_size || idx < 0) {
86       TF_LITE_KERNEL_LOG(context,
87                          "Embedding Lookup: index out of bounds. "
88                          "Got %d, and bounds are [0, %d]",
89                          idx, row_size - 1);
90       return kTfLiteError;
91     } else {
92       std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
93                   row_bytes);
94     }
95   }
96 
97   return kTfLiteOk;
98 }
99 
EvalHybrid(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)100 TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
101                         const TfLiteTensor* lookup, const TfLiteTensor* value,
102                         TfLiteTensor* output) {
103   const int row_size = SizeOfDimension(value, 0);
104   const double scaling_factor = value->params.scale;
105 
106   // col_size after we flatten tensor into 2D.
107   int col_size = 1;
108   for (int i = 1; i < NumDimensions(value); i++) {
109     col_size *= SizeOfDimension(value, i);
110   }
111 
112   float* output_ptr = GetTensorData<float>(output);
113   const int8_t* value_ptr = GetTensorData<int8_t>(value);
114   const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
115 
116   for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
117     int idx = lookup_data[i];
118     if (idx >= row_size || idx < 0) {
119       TF_LITE_KERNEL_LOG(context,
120                          "Embedding Lookup: index out of bounds. "
121                          "Got %d, and bounds are [0, %d]",
122                          idx, row_size - 1);
123       return kTfLiteError;
124     } else {
125       // Dequantize embedding values.
126       // TODO(alanchiao): refactor scalar multiply into separate function
127       // for ease of adding a neon equivalent if ever necessary.
128       for (int j = 0; j < col_size; j++) {
129         output_ptr[j + i * col_size] =
130             value_ptr[j + idx * col_size] * scaling_factor;
131       }
132     }
133   }
134 
135   return kTfLiteOk;
136 }
137 
Eval(TfLiteContext * context,TfLiteNode * node)138 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
139   const TfLiteTensor* lookup;
140   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
141   const TfLiteTensor* value;
142   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
143   TfLiteTensor* output;
144   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
145   switch (value->type) {
146     case kTfLiteFloat32:
147       return EvalSimple(context, node, lookup, value, output);
148     case kTfLiteUInt8:
149     case kTfLiteInt8:
150       if (output->type == kTfLiteFloat32) {
151         return EvalHybrid(context, node, lookup, value, output);
152       } else {
153         return EvalSimple(context, node, lookup, value, output);
154       }
155     default:
156       TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
157       return kTfLiteError;
158   }
159 }
160 
161 }  // namespace embedding_lookup
162 
Register_EMBEDDING_LOOKUP()163 TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
164   static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
165                                  embedding_lookup::Eval};
166   return &r;
167 }
168 
169 }  // namespace builtin
170 }  // namespace ops
171 }  // namespace tflite
172