/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up items from a sparse tensor in an embedding matrix.
// The sparse lookup tensor is represented by three individual tensors: lookup,
// indices, and dense_shape. The representation assumes that the corresponding
// dense tensor would satisfy:
//   * dense.shape = dense_shape
//   * dense[tuple(indices[i])] = lookup[i]
//
// By convention, indices should be sorted.
//
// Options:
//   combiner: The reduction op (SUM, MEAN, SQRTN).
//     * SUM computes the weighted sum of the embedding results.
//     * MEAN is the weighted sum divided by the total weight.
//     * SQRTN is the weighted sum divided by the square root of the sum of the
//       squares of the weights.
//
// Input:
//     Tensor[0]: Ids to lookup, dim.size == 1, int32.
//     Tensor[1]: Indices, int32.
//     Tensor[2]: Dense shape, int32.
//     Tensor[3]: Weights to use for aggregation, float.
//     Tensor[4]: Params, a matrix of multi-dimensional items,
//                dim.size >= 2, float.
//
// Output:
//   A (dense) tensor representing the combined embeddings for the sparse ids.
//   For each row in the sparse tensor represented by (lookup, indices, shape)
//   the op looks up the embeddings for all ids in that row, multiplies them by
//   the corresponding weight, and combines these embeddings as specified in the
//   last dimension.
//
//   Output.dim = [l0, ... , ln-1, e1, ..., em]
//   Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em]
//
//   For instance, if params is a 10x20 matrix and ids, weights are:
//
//   [0, 0]: id 1, weight 2.0
//   [0, 1]: id 3, weight 0.5
//   [1, 0]: id 0, weight 1.0
//   [2, 3]: id 1, weight 3.0
//
//   with combiner=MEAN, then the output will be a (3, 20) tensor where:
//
//   output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
//   output[1, :] = (params[0, :] * 1.0) / 1.0
//   output[2, :] = (params[1, :] * 3.0) / 3.0
//
//   When indices are out of bounds, the op will not succeed.

#include <stdint.h>

#include <algorithm>
#include <cmath>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace ops {
namespace builtin {

namespace {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* ids;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
  TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
  TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);

  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
  TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
  TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);

  const TfLiteTensor* shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
  TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
  TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);

  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
  TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
                    SizeOfDimension(ids, 0));
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
                    SizeOfDimension(weights, 0));

  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

  // Mark the output as a dynamic tensor.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
  output->allocation_type = kTfLiteDynamic;

  return kTfLiteOk;
}

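// Finalizes one output bucket. For MEAN the accumulated weighted sum in
// `output` is divided by the total weight; for SQRTN it is divided by the
// square root of the sum of the squared weights. SUM leaves the values as-is.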
void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
                         float current_total_weight,
                         float current_squares_weight, int embedding_size,
                         float* output) {
  if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
    float multiplier = 1.0;
    switch (combiner) {
      case kTfLiteCombinerTypeMean:
        multiplier = current_total_weight;
        break;
      case kTfLiteCombinerTypeSqrtn:
        multiplier = std::sqrt(current_squares_weight);
        break;
      default:
        break;
    }
    for (int k = 0; k < embedding_size; k++) {
      output[k] /= multiplier;
    }
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* ids;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
  const TfLiteTensor* dense_shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
  const size_t values_size = NumElements(value);

  const int lookup_rank = SizeOfDimension(indices, 1);
  const int embedding_rank = NumDimensions(value);
  const int num_lookups = SizeOfDimension(ids, 0);
  const int num_rows = SizeOfDimension(value, 0);

  // The last dimension gets replaced by the embedding.
  const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);

  // Make sure that the actual dense shape of the sparse tensor represented by
  // (lookup, indices, dense_shape) is consistent.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);

  // Resize output tensor.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
  TF_LITE_ENSURE(context, output_shape != nullptr);
  int k = 0;
  size_t embedding_size = 1;
  size_t lookup_size = 1;
  for (int i = 0; i < lookup_rank - 1; i++, k++) {
    const size_t dim = dense_shape->data.i32[i];
    TF_LITE_ENSURE_MSG(
        context,
        MultiplyAndCheckOverflow(lookup_size, dim, &lookup_size) == kTfLiteOk,
        "Lookup size overflowed.");
    output_shape->data[k] = dim;
  }
  for (int i = 1; i < embedding_rank; i++, k++) {
    const size_t dim = SizeOfDimension(value, i);
    TF_LITE_ENSURE_MSG(context,
                       MultiplyAndCheckOverflow(embedding_size, dim,
                                                &embedding_size) == kTfLiteOk,
                       "Embedding size overflowed.");
    output_shape->data[k] = dim;
  }
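  // The output was marked dynamic in Prepare, so it is resized and its buffer
  // is (re)allocated here at eval time.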
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
  const size_t output_size = lookup_size * embedding_size;
  TfLiteTensorRealloc(output_size * sizeof(float), output);

  float* output_ptr = GetTensorData<float>(output);
  const float* weights_ptr = GetTensorData<float>(weights);
  const float* value_ptr = GetTensorData<float>(value);
  // Makes sure reallocation was successful.
  TF_LITE_ENSURE(context, output_ptr != nullptr);

  std::fill_n(output_ptr, output_size, 0.0f);

  // Keep track of the current bucket for aggregation/combination.
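  // Because indices are expected to be sorted (see the header comment), all
  // entries that map to the same output bucket arrive contiguously, so one
  // running accumulator per bucket is sufficient.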
  int current_output_offset = 0;
  float current_total_weight = 0.0;
  float current_squares_weight = 0.0;
  int num_elements = 0;

  for (int i = 0; i < num_lookups; i++) {
    int idx = ids->data.i32[i];
    if (idx >= num_rows || idx < 0) {
      TF_LITE_KERNEL_LOG(context,
                         "Embedding Lookup Sparse: index out of bounds. "
                         "Got %d, and bounds are [0, %d]",
                         idx, num_rows - 1);
      return kTfLiteError;
    }

    // Check where we need to aggregate.
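    // The bucket is the row-major flattened index of the first
    // (lookup_rank - 1) coordinates of this entry within dense_shape.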
    const int example_indices_offset = i * lookup_rank;
    int output_bucket = 0;
    int stride = 1;
    for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
      output_bucket += indices->data.i32[example_indices_offset + k] * stride;
      stride *= dense_shape->data.i32[k];
    }
    const int output_offset = output_bucket * embedding_size;

    // If we are in a new aggregation bucket and the combiner is not the sum,
    // go back and finalize the result of the previous bucket.
    if (output_offset != current_output_offset) {
      FinalizeAggregation(params->combiner, num_elements, current_total_weight,
                          current_squares_weight, embedding_size,
                          &output_ptr[current_output_offset]);

      // Track next bucket.
      num_elements = 0;
      current_total_weight = 0.0;
      current_squares_weight = 0.0;
      current_output_offset = output_offset;
    }

    // Add element to aggregation.
    ++num_elements;
    const int example_embedding_offset = idx * embedding_size;
    const float w = weights_ptr[i];
    current_squares_weight += w * w;
    current_total_weight += w;
    for (int k = 0; k < embedding_size; k++) {
      // Only accumulate when both the output and embedding offsets are in
      // bounds.
      if (current_output_offset + k < 0) continue;
      if (current_output_offset + k >= output_size) continue;
      if (example_embedding_offset + k < 0) continue;
      if (example_embedding_offset + k >= values_size) continue;
      output_ptr[current_output_offset + k] +=
          value_ptr[example_embedding_offset + k] * w;
    }
  }

  // Finalize last bucket.
  FinalizeAggregation(params->combiner, num_elements, current_total_weight,
                      current_squares_weight, embedding_size,
                      &GetTensorData<float>(output)[current_output_offset]);

  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite