/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/ctc/ctc_beam_search.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace custom {
namespace ctc_beam_search_decoder {

constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;

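// Decoder options parsed from the flexbuffer-encoded custom options; they
// mirror the attributes of the TensorFlow CTCBeamSearchDecoder op.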
typedef struct {
  int beam_width;
  int top_paths;
  bool merge_repeated;
} CTCBeamSearchDecoderParams;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_CHECK(buffer != nullptr);
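  // The custom options are serialized as a flexbuffer map.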
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

  CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
  option->beam_width = m["beam_width"].AsInt32();
  option->top_paths = m["top_paths"].AsInt32();
  option->merge_repeated = m["merge_repeated"].AsBool();

  return option;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
  const int top_paths = option->top_paths;
  TF_LITE_ENSURE(context, option->beam_width >= top_paths);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  // There are 3 * top_paths + 1 outputs: an indices, a values, and a shape
  // tensor for each of the top_paths decoded sequences, plus one
  // log-probability tensor.
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);

  const TfLiteTensor* inputs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputsTensor, &inputs));
  TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
  // TensorFlow only supports float.
  TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
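  // `inputs` is laid out as [max_time, batch_size, num_classes].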
  const int batch_size = SizeOfDimension(inputs, 1);

  const TfLiteTensor* sequence_length;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
                                          &sequence_length));
  TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
  // TensorFlow only supports int32.
  TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);

  // Resize decoded outputs.
  // The indices and values tensors are left dynamic because their sizes are
  // not known until Eval.
  for (int i = 0; i < top_paths; ++i) {
    TfLiteTensor* indices;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
    SetTensorToDynamic(indices);
    TfLiteTensor* values;
    TF_LITE_ENSURE_OK(context,
                      GetOutputSafe(context, node, i + top_paths, &values));
    SetTensorToDynamic(values);
    TfLiteTensor* output_shape;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
                                             &output_shape));
    SetTensorToDynamic(output_shape);
  }

  // Resize log probability outputs.
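  // The log-probability output has shape [batch_size, top_paths].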
  TfLiteTensor* log_probability_output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
                                           &log_probability_output));
  TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
  log_probability_output_shape_array->data[0] = batch_size;
  log_probability_output_shape_array->data[1] = top_paths;
  return context->ResizeTensor(context, log_probability_output,
                               log_probability_output_shape_array);
}

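// Resizes `output` to `output_shape`. Ownership of the newly created
// TfLiteIntArray is taken over by ResizeTensor.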
TfLiteStatus Resize(TfLiteContext* context,
                    std::initializer_list<int32_t> output_shape,
                    TfLiteTensor* output) {
  const int dimensions = output_shape.size();
  TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
  int i = 0;
  for (const int v : output_shape) {
    output_shape_array->data[i++] = v;
  }
  return context->ResizeTensor(context, output, output_shape_array);
}

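// Writes the decoded sequences into the output tensors as one sparse tensor
// (indices, values, dense shape) per path, matching the outputs of the
// TensorFlow CTCBeamSearchDecoder op.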
TfLiteStatus StoreAllDecodedSequences(
    TfLiteContext* context,
    const std::vector<std::vector<std::vector<int>>>& sequences,
    TfLiteNode* node, int top_paths) {
  const int32_t batch_size = sequences.size();
  std::vector<int32_t> num_entries(top_paths, 0);

  // Calculate num_entries per path
  for (const auto& batch_s : sequences) {
    TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
    for (int p = 0; p < top_paths; ++p) {
      num_entries[p] += batch_s[p].size();
    }
  }

  for (int p = 0; p < top_paths; ++p) {
    const int32_t p_num = num_entries[p];

    // Resize the decoded outputs.
    TfLiteTensor* indices;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));

    TfLiteTensor* values;
    TF_LITE_ENSURE_OK(context,
                      GetOutputSafe(context, node, p + top_paths, &values));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));

    TfLiteTensor* decoded_shape;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
                                             &decoded_shape));
    TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));

    int32_t max_decoded = 0;
    int32_t offset = 0;

    int32_t* indices_data = GetTensorData<int32_t>(indices);
    int32_t* values_data = GetTensorData<int32_t>(values);
    int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
    for (int b = 0; b < batch_size; ++b) {
      auto& p_batch = sequences[b][p];
      int32_t num_decoded = p_batch.size();
      max_decoded = std::max(max_decoded, num_decoded);

      std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
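      // Record the (batch, time) position of each decoded element in the
      // sparse indices tensor.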
      for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
        indices_data[offset * 2] = b;
        indices_data[offset * 2 + 1] = t;
      }
    }

    decoded_shape_data[0] = batch_size;
    decoded_shape_data[1] = max_decoded;
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* inputs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputsTensor, &inputs));
  const TfLiteTensor* sequence_length;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
                                          &sequence_length));
  const CTCBeamSearchDecoderParams* option =
      reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);

  const int max_time = SizeOfDimension(inputs, 0);
  const int batch_size = SizeOfDimension(inputs, 1);
  const int num_classes = SizeOfDimension(inputs, 2);

  const int beam_width = option->beam_width;
  const int top_paths = option->top_paths;
  const bool merge_repeated = option->merge_repeated;

  // Validate that each sequence length is less than or equal to max_time.
  for (int i = 0; i < batch_size; ++i) {
    TF_LITE_ENSURE(context,
                   max_time >= GetTensorData<int32_t>(sequence_length)[i]);
  }

  // The following logic mirrors tensorflow/core/kernels/ctc_decoder_ops.cc.
  std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;

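  // Create one [batch_size, num_classes] view per time step into the logits.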
  for (int t = 0; t < max_time; ++t) {
    input_list_t.emplace_back(
        GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
        num_classes);
  }

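  // The decoder is constructed for a single batch element and reused across
  // the batch; Reset() below clears its state between elements.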
  ::tflite::custom::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer;
  ::tflite::custom::ctc::CTCBeamSearchDecoder<> beam_search(
      num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
      merge_repeated);

  // Allocate temporary memory for holding chip operation data. A std::vector
  // is used so the buffer is also released on early error returns.
  std::vector<float> input_chip_t_storage(num_classes);
  Eigen::array<Eigen::DenseIndex, 1> dims;
  dims[0] = num_classes;
  optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_storage.data(),
                                                  dims);

  std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
  std::vector<float> log_probs;

  TfLiteTensor* log_probabilities;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
  float* log_probabilities_output = GetTensorData<float>(log_probabilities);

  // Assumption: the blank index is num_classes - 1.
  for (int b = 0; b < batch_size; ++b) {
    auto& best_paths_b = best_paths[b];
    best_paths_b.resize(top_paths);
    for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
      input_chip_t = input_list_t[t].chip(b, 0);
      auto input_bi =
          Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
      beam_search.Step(input_bi);
    }
    TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
                                                 &log_probs, merge_repeated));
    beam_search.Reset();

    // Fill in log_probabilities output.
    for (int bp = 0; bp < top_paths; ++bp) {
      log_probabilities_output[b * top_paths + bp] = log_probs[bp];
    }
  }

  return StoreAllDecodedSequences(context, best_paths, node, top_paths);
}

}  // namespace ctc_beam_search_decoder

TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
  static TfLiteRegistration r = {
      ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
      ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite