1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Op that looks up rows of a tensor by index (an embedding lookup).
//
// Input:
//   Tensor[0]: Row indices to look up; dim.size == 1, int32.
//   Tensor[1]: Tensor of items to look up from; dim.size >= 2, float32, uint8
//              or int8. The first dimension is the row and the second is the
//              column; any further dimensions make each item multi-dimensional.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[i] == Tensor[1].dim[i] for every i >= 1.
// Each item in the output is a raw byte copy of the corresponding item in the
// input, or a dequantized float value when the input is uint8/int8 and the
// output is float32.
// If any index is out of bounds, the op fails.
//
31
32 #include <stdint.h>
33
34 #include <cstring>
35
36 #include "tensorflow/lite/c/common.h"
37 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
38 #include "tensorflow/lite/kernels/kernel_util.h"
39
40 namespace tflite {
41 namespace ops {
42 namespace builtin {
43 namespace embedding_lookup {
44
Prepare(TfLiteContext * context,TfLiteNode * node)45 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
46 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
47 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
48
49 const TfLiteTensor* lookup;
50 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
51 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
52 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
53
54 const TfLiteTensor* value;
55 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
56 TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
57
58 TfLiteTensor* output;
59 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
60 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
61
62 outputSize->data[0] = SizeOfDimension(lookup, 0);
63 outputSize->data[1] = SizeOfDimension(value, 1);
64 for (int i = 2; i < NumDimensions(value); i++) {
65 outputSize->data[i] = SizeOfDimension(value, i);
66 }
67 return context->ResizeTensor(context, output, outputSize);
68 }
69
EvalSimple(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)70 TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
71 const TfLiteTensor* lookup, const TfLiteTensor* value,
72 TfLiteTensor* output) {
73 const int row_size = SizeOfDimension(value, 0);
74 if (row_size == 0) {
75 // Propagate empty tensor if input is empty
76 return kTfLiteOk;
77 }
78 const int row_bytes = value->bytes / row_size;
79
80 char* output_raw = GetTensorData<char>(output);
81 const char* value_raw = GetTensorData<char>(value);
82 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
83 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
84 int idx = lookup_data[i];
85 if (idx >= row_size || idx < 0) {
86 TF_LITE_KERNEL_LOG(context,
87 "Embedding Lookup: index out of bounds. "
88 "Got %d, and bounds are [0, %d]",
89 idx, row_size - 1);
90 return kTfLiteError;
91 } else {
92 std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
93 row_bytes);
94 }
95 }
96
97 return kTfLiteOk;
98 }
99
EvalHybrid(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)100 TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
101 const TfLiteTensor* lookup, const TfLiteTensor* value,
102 TfLiteTensor* output) {
103 const int row_size = SizeOfDimension(value, 0);
104 const double scaling_factor = value->params.scale;
105
106 // col_size after we flatten tensor into 2D.
107 int col_size = 1;
108 for (int i = 1; i < NumDimensions(value); i++) {
109 col_size *= SizeOfDimension(value, i);
110 }
111
112 float* output_ptr = GetTensorData<float>(output);
113 const int8_t* value_ptr = GetTensorData<int8_t>(value);
114 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
115
116 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
117 int idx = lookup_data[i];
118 if (idx >= row_size || idx < 0) {
119 TF_LITE_KERNEL_LOG(context,
120 "Embedding Lookup: index out of bounds. "
121 "Got %d, and bounds are [0, %d]",
122 idx, row_size - 1);
123 return kTfLiteError;
124 } else {
125 // Dequantize embedding values.
126 // TODO(alanchiao): refactor scalar multiply into separate function
127 // for ease of adding a neon equivalent if ever necessary.
128 for (int j = 0; j < col_size; j++) {
129 output_ptr[j + i * col_size] =
130 value_ptr[j + idx * col_size] * scaling_factor;
131 }
132 }
133 }
134
135 return kTfLiteOk;
136 }
137
Eval(TfLiteContext * context,TfLiteNode * node)138 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
139 const TfLiteTensor* lookup;
140 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
141 const TfLiteTensor* value;
142 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
143 TfLiteTensor* output;
144 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
145 switch (value->type) {
146 case kTfLiteFloat32:
147 return EvalSimple(context, node, lookup, value, output);
148 case kTfLiteUInt8:
149 case kTfLiteInt8:
150 if (output->type == kTfLiteFloat32) {
151 return EvalHybrid(context, node, lookup, value, output);
152 } else {
153 return EvalSimple(context, node, lookup, value, output);
154 }
155 default:
156 TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
157 return kTfLiteError;
158 }
159 }
160
161 } // namespace embedding_lookup
162
Register_EMBEDDING_LOOKUP()163 TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
164 static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
165 embedding_lookup::Eval};
166 return &r;
167 }
168
169 } // namespace builtin
170 } // namespace ops
171 } // namespace tflite
172