xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/ctc_decoder_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/ctc_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include <limits>
21 
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/util/ctc/ctc_beam_search.h"
30 #include "tensorflow/core/util/sparse/sparse_tensor.h"
31 #include "tensorflow/core/util/work_sharder.h"
32 
33 namespace tensorflow {
34 
35 typedef Eigen::ThreadPoolDevice CPUDevice;
36 
37 template <typename T>
RowMax(const typename TTypes<T>::UnalignedConstMatrix & m,int r,int * c)38 inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
39                 int* c) {
40   *c = 0;
41   CHECK_LT(0, m.dimension(1));
42   auto p = m(r, 0);
43   for (int i = 1; i < m.dimension(1); ++i) {
44     if (m(r, i) > p) {
45       p = m(r, i);
46       *c = i;
47     }
48   }
49   return p;
50 }
51 
// Helper shared by the greedy and beam-search CTC decoder kernels.  It
// validates the inputs common to both ops, allocates the common outputs,
// and converts decoded label sequences into the SparseTensor components
// (indices/values/shape) the ops emit — one triple per requested top path.
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  // Validates "inputs" (a non-empty rank-3 [max_time, batch_size,
  // num_classes] tensor) and "sequence_length" (a length-batch_size vector
  // whose entries must not exceed max_time).  On success allocates the
  // "log_probability" output with shape [batch_size, top_paths_] and binds
  // the three decoded output lists; returns the first failing Status
  // otherwise.  Output pointers are only valid when OkStatus is returned.
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }
    // A zero num_elements also rules out any zero-sized dimension, so the
    // per-time-step matrices built by the kernels are guaranteed non-empty.
    if (inputs_shape.num_elements() == 0) {
      return errors::InvalidArgument("inputs must not be empty");
    }

    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size.  ",
          "len(sequence_length):  ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    // Each entry's sequence length must fit within the time dimension.
    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") <= ", max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return OkStatus();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  // Writes each path p as a sparse [batch_size, max_decoded_length(p)]
  // tensor: indices are (batch, position-in-sequence) pairs in row-major
  // order, values are the decoded labels, shape is the dense bound.
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int64_t batch_size = sequences.size();
    std::vector<int64_t> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      // Internal invariant: callers always produce exactly top_paths_
      // sequences per batch entry.
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64_t p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64_t>();
      auto values_t = p_values->vec<int64_t>();
      auto shape_t = p_shape->vec<int64_t>();

      int64_t max_decoded = 0;  // Longest decoded sequence for this path.
      int64_t offset = 0;       // Running write position into values/indices.

      for (int64_t b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64_t num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          // Guard against writing through a zero-sized allocation.
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        // Sparse indices: one (batch, position) row per decoded label.
        for (int64_t t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return OkStatus();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
184 
185 template <typename T>
186 class CTCGreedyDecoderOp : public OpKernel {
187  public:
CTCGreedyDecoderOp(OpKernelConstruction * ctx)188   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
189     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
190     OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
191   }
192 
Compute(OpKernelContext * ctx)193   void Compute(OpKernelContext* ctx) override {
194     const Tensor* inputs;
195     const Tensor* seq_len;
196     Tensor* log_prob = nullptr;
197     OpOutputList decoded_indices;
198     OpOutputList decoded_values;
199     OpOutputList decoded_shape;
200     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
201                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
202                             &decoded_values, &decoded_shape));
203 
204     const TensorShape& inputs_shape = inputs->shape();
205 
206     std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
207     const int64_t max_time = inputs_shape.dim_size(0);
208     const int64_t batch_size = inputs_shape.dim_size(1);
209     const int64_t num_classes_raw = inputs_shape.dim_size(2);
210     OP_REQUIRES(
211         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
212         errors::InvalidArgument("num_classes cannot exceed max int"));
213     const int num_classes = static_cast<const int>(num_classes_raw);
214 
215     auto inputs_t = inputs->tensor<T, 3>();
216 
217     input_list_t.reserve(max_time);
218     for (std::size_t t = 0; t < max_time; ++t) {
219       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
220                                 batch_size, num_classes);
221     }
222     auto seq_len_t = seq_len->vec<int32>();
223     auto log_prob_t = log_prob->matrix<T>();
224 
225     log_prob_t.setZero();
226 
227     int blank_index =
228         (blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
229     OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
230                 errors::InvalidArgument("blank_index expected to be between ",
231                                         -num_classes, " and ", num_classes - 1,
232                                         " but was ", blank_index_));
233 
234     // Perform best path decoding
235     std::vector<std::vector<std::vector<int> > > sequences(batch_size);
236     auto decode = [&](const int64_t begin, const int64_t end) {
237       for (int b = begin; b < end; ++b) {
238         sequences[b].resize(1);
239         auto &sequence = sequences[b][0];
240         int prev_indices = -1;
241         for (int t = 0; t < seq_len_t(b); ++t) {
242           int max_class_indices;
243           OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
244                       errors::InvalidArgument("Invalid input dimensions."));
245           log_prob_t(b, 0) +=
246               -RowMax<T>(input_list_t[t], b, &max_class_indices);
247           if (max_class_indices != blank_index &&
248               !(merge_repeated_ && max_class_indices == prev_indices)) {
249             sequence.push_back(max_class_indices);
250           }
251           prev_indices = max_class_indices;
252         }
253       }
254     };
255 
256     const int64_t kCostPerUnit = 50 * max_time * num_classes;
257     const int64_t total = batch_size;
258     const DeviceBase::CpuWorkerThreads& worker_threads =
259         *ctx->device()->tensorflow_cpu_worker_threads();
260     Shard(worker_threads.num_threads, worker_threads.workers, total,
261           kCostPerUnit, decode);
262 
263     OP_REQUIRES_OK(
264         ctx, decode_helper_.StoreAllDecodedSequences(
265                  sequences, &decoded_indices, &decoded_values, &decoded_shape));
266   }
267 
268  private:
269   CTCDecodeHelper decode_helper_;
270   bool merge_repeated_;
271   int blank_index_;
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
274 };
275 
// Register the greedy decoder kernel on CPU for float and double logits.
#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
285 
286 // CTC beam search
287 template <typename T>
288 class CTCBeamSearchDecoderOp : public OpKernel {
289  public:
CTCBeamSearchDecoderOp(OpKernelConstruction * ctx)290   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
291     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
292     OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
293     int top_paths;
294     OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
295     decode_helper_.SetTopPaths(top_paths);
296   }
297 
Compute(OpKernelContext * ctx)298   void Compute(OpKernelContext* ctx) override {
299     const Tensor* inputs;
300     const Tensor* seq_len;
301     Tensor* log_prob = nullptr;
302     OpOutputList decoded_indices;
303     OpOutputList decoded_values;
304     OpOutputList decoded_shape;
305     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
306                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
307                             &decoded_values, &decoded_shape));
308 
309     auto inputs_t = inputs->tensor<T, 3>();
310     auto seq_len_t = seq_len->vec<int32>();
311     auto log_prob_t = log_prob->matrix<T>();
312 
313     const TensorShape& inputs_shape = inputs->shape();
314 
315     const int64_t max_time = inputs_shape.dim_size(0);
316     const int64_t batch_size = inputs_shape.dim_size(1);
317     const int64_t num_classes_raw = inputs_shape.dim_size(2);
318     OP_REQUIRES(
319         ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
320         errors::InvalidArgument("num_classes cannot exceed max int"));
321     const int num_classes = static_cast<const int>(num_classes_raw);
322 
323     log_prob_t.setZero();
324 
325     std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
326 
327     input_list_t.reserve(max_time);
328     for (std::size_t t = 0; t < max_time; ++t) {
329       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
330                                 batch_size, num_classes);
331     }
332 
333     ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
334                                              &beam_scorer_, 1 /* batch_size */,
335                                              merge_repeated_);
336     Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
337     auto input_chip_t = input_chip.flat<T>();
338 
339     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
340     std::vector<T> log_probs;
341 
342     // Assumption: the blank index is num_classes - 1
343     for (int b = 0; b < batch_size; ++b) {
344       auto& best_paths_b = best_paths[b];
345       best_paths_b.resize(decode_helper_.GetTopPaths());
346       for (int t = 0; t < seq_len_t(b); ++t) {
347         input_chip_t = input_list_t[t].chip(b, 0);
348         auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
349             input_chip_t.data(), num_classes);
350         beam_search.Step(input_bi);
351       }
352       OP_REQUIRES_OK(
353           ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
354                                     &log_probs, merge_repeated_));
355 
356       beam_search.Reset();
357 
358       for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
359         log_prob_t(b, bp) = log_probs[bp];
360       }
361     }
362 
363     OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
364                             best_paths, &decoded_indices, &decoded_values,
365                             &decoded_shape));
366   }
367 
368  private:
369   CTCDecodeHelper decode_helper_;
370   typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
371   bool merge_repeated_;
372   int beam_width_;
373   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
374 };
375 
// Register the beam-search decoder kernel on CPU for float and double logits.
#define REGISTER_CPU(T)                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCBeamSearchDecoderOp<T>);

REGISTER_CPU(float);
REGISTER_CPU(double);

#undef REGISTER_CPU
385 
386 }  // end namespace tensorflow
387