/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <algorithm>
16 #include <vector>
17
18 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/ctc/ctc_beam_search.h"
21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/tensor.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace custom {
29 namespace ctc_beam_search_decoder {
30
// Input tensor indices for this op.
constexpr int kInputsTensor = 0;
constexpr int kSequenceLengthTensor = 1;

// Custom-op options, parsed from the flexbuffer blob in Init() and owned by
// the node's user_data (released in Free()).
typedef struct {
  int beam_width;       // Beams kept during search; Prepare enforces >= top_paths.
  int top_paths;        // Number of decoded paths written to the outputs.
  bool merge_repeated;  // Forwarded to the beam-search decoder and TopPaths().
} CTCBeamSearchDecoderParams;
39
Init(TfLiteContext * context,const char * buffer,size_t length)40 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
41 TFLITE_CHECK(buffer != nullptr);
42 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
43 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
44
45 CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
46 option->beam_width = m["beam_width"].AsInt32();
47 option->top_paths = m["top_paths"].AsInt32();
48 option->merge_repeated = m["merge_repeated"].AsBool();
49
50 return option;
51 }
52
Free(TfLiteContext * context,void * buffer)53 void Free(TfLiteContext* context, void* buffer) {
54 delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
55 }
56
Prepare(TfLiteContext * context,TfLiteNode * node)57 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
58 const CTCBeamSearchDecoderParams* option =
59 reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
60 const int top_paths = option->top_paths;
61 TF_LITE_ENSURE(context, option->beam_width >= top_paths);
62 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
63 // The outputs should be top_paths * 3 + 1.
64 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
65
66 const TfLiteTensor* inputs;
67 TF_LITE_ENSURE_OK(context,
68 GetInputSafe(context, node, kInputsTensor, &inputs));
69 TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
70 // TensorFlow only supports float.
71 TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
72 const int batch_size = SizeOfDimension(inputs, 1);
73
74 const TfLiteTensor* sequence_length;
75 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
76 &sequence_length));
77 TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
78 TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
79 // TensorFlow only supports int32.
80 TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);
81
82 // Resize decoded outputs.
83 // Do not resize indices & values cause we don't know the values yet.
84 for (int i = 0; i < top_paths; ++i) {
85 TfLiteTensor* indices;
86 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
87 SetTensorToDynamic(indices);
88 TfLiteTensor* values;
89 TF_LITE_ENSURE_OK(context,
90 GetOutputSafe(context, node, i + top_paths, &values));
91 SetTensorToDynamic(values);
92 TfLiteTensor* output_shape;
93 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
94 &output_shape));
95 SetTensorToDynamic(output_shape);
96 }
97
98 // Resize log probability outputs.
99 TfLiteTensor* log_probability_output;
100 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
101 &log_probability_output));
102 TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
103 log_probability_output_shape_array->data[0] = batch_size;
104 log_probability_output_shape_array->data[1] = top_paths;
105 return context->ResizeTensor(context, log_probability_output,
106 log_probability_output_shape_array);
107 }
108
Resize(TfLiteContext * context,std::initializer_list<int32_t> output_shape,TfLiteTensor * output)109 TfLiteStatus Resize(TfLiteContext* context,
110 std::initializer_list<int32_t> output_shape,
111 TfLiteTensor* output) {
112 const int dimensions = output_shape.size();
113 TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
114 int i = 0;
115 for (const int v : output_shape) {
116 output_shape_array->data[i++] = v;
117 }
118 return context->ResizeTensor(context, output, output_shape_array);
119 }
120
// Writes the decoded sequences (indexed as sequences[batch][path]) into the
// node's sparse outputs. For each path p: output p holds the [p_num, 2]
// (batch, time) indices, output p + top_paths holds the p_num flattened
// values, and output p + 2 * top_paths holds the dense shape
// [batch_size, max_decoded].
TfLiteStatus StoreAllDecodedSequences(
    TfLiteContext* context,
    const std::vector<std::vector<std::vector<int>>>& sequences,
    TfLiteNode* node, int top_paths) {
  const int32_t batch_size = sequences.size();
  std::vector<int32_t> num_entries(top_paths, 0);

  // Calculate num_entries per path (total decoded tokens across all batches).
  for (const auto& batch_s : sequences) {
    TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
    for (int p = 0; p < top_paths; ++p) {
      num_entries[p] += batch_s[p].size();
    }
  }

  for (int p = 0; p < top_paths; ++p) {
    const int32_t p_num = num_entries[p];

    // Resize the decoded outputs.
    TfLiteTensor* indices;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));

    TfLiteTensor* values;
    TF_LITE_ENSURE_OK(context,
                      GetOutputSafe(context, node, p + top_paths, &values));
    TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));

    TfLiteTensor* decoded_shape;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
                                             &decoded_shape));
    TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));

    int32_t max_decoded = 0;
    int32_t offset = 0;

    int32_t* indices_data = GetTensorData<int32_t>(indices);
    int32_t* values_data = GetTensorData<int32_t>(values);
    int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
    // Concatenate every batch's entries for path p; `offset` is the running
    // write position in the flattened values/indices buffers.
    for (int b = 0; b < batch_size; ++b) {
      auto& p_batch = sequences[b][p];
      int32_t num_decoded = p_batch.size();
      max_decoded = std::max(max_decoded, num_decoded);

      std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
      for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
        // Sparse coordinate (b, t): batch entry b, position t in its sequence.
        indices_data[offset * 2] = b;
        indices_data[offset * 2 + 1] = t;
      }
    }

    // Dense shape of the sparse tensor: [batch_size, longest sequence].
    decoded_shape_data[0] = batch_size;
    decoded_shape_data[1] = max_decoded;
  }
  return kTfLiteOk;
}
177
Eval(TfLiteContext * context,TfLiteNode * node)178 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
179 const TfLiteTensor* inputs;
180 TF_LITE_ENSURE_OK(context,
181 GetInputSafe(context, node, kInputsTensor, &inputs));
182 const TfLiteTensor* sequence_length;
183 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
184 &sequence_length));
185 const CTCBeamSearchDecoderParams* option =
186 reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
187
188 const int max_time = SizeOfDimension(inputs, 0);
189 const int batch_size = SizeOfDimension(inputs, 1);
190 const int num_classes = SizeOfDimension(inputs, 2);
191
192 const int beam_width = option->beam_width;
193 const int top_paths = option->top_paths;
194 const bool merge_repeated = option->merge_repeated;
195
196 // Validate sequence length is less or equal than max time.
197 for (int i = 0; i < batch_size; ++i) {
198 TF_LITE_ENSURE(context,
199 max_time >= GetTensorData<int32_t>(sequence_length)[i]);
200 }
201
202 // The following logic is implemented like
203 // tensorflow/core/kernels/ctc_decoder_ops.cc
204 std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;
205
206 for (std::size_t t = 0; t < max_time; ++t) {
207 input_list_t.emplace_back(
208 GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
209 num_classes);
210 }
211
212 ::tflite::custom::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer;
213 ::tflite::custom::ctc::CTCBeamSearchDecoder<> beam_search(
214 num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
215 merge_repeated);
216
217 // Allocate temporary memory for holding chip operation data.
218 float* input_chip_t_data =
219 static_cast<float*>(malloc(num_classes * sizeof(float)));
220 Eigen::array<Eigen::DenseIndex, 1> dims;
221 dims[0] = num_classes;
222 optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);
223
224 std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
225 std::vector<float> log_probs;
226
227 TfLiteTensor* log_probabilities;
228 TF_LITE_ENSURE_OK(
229 context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
230 float* log_probabilities_output = GetTensorData<float>(log_probabilities);
231
232 // Assumption: the blank index is num_classes - 1
233 for (int b = 0; b < batch_size; ++b) {
234 auto& best_paths_b = best_paths[b];
235 best_paths_b.resize(top_paths);
236 for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
237 input_chip_t = input_list_t[t].chip(b, 0);
238 auto input_bi =
239 Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
240 beam_search.Step(input_bi);
241 }
242 TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
243 &log_probs, merge_repeated));
244 beam_search.Reset();
245
246 // Fill in log_probabilities output.
247 for (int bp = 0; bp < top_paths; ++bp) {
248 log_probabilities_output[b * top_paths + bp] = log_probs[bp];
249 }
250 }
251
252 free(input_chip_t_data);
253 return StoreAllDecodedSequences(context, best_paths, node, top_paths);
254 }
255
256 } // namespace ctc_beam_search_decoder
257
Register_CTC_BEAM_SEARCH_DECODER()258 TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
259 static TfLiteRegistration r = {
260 ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
261 ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
262 return &r;
263 }
264
265 } // namespace custom
266 } // namespace ops
267 } // namespace tflite
268