/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_

#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace ctc {

template <class T>
class CTCLossCalculator {
  // Connectionist Temporal Classification Loss
  //
  // Implementation by kanishkarao@, posenhuang@, and ebrevdo@.
  //
  // The CTC Loss layer learns a *transition* probability value for each
  // input time step.  The transitions are on the class alphabet
  //   {0, 1, ..., N-2}
  // where N is the depth of the input layer (the size of the alphabet is N-1).
  // Note: The token N-1 is reserved for the "no transition" output, so
  // make sure that your input layer has a depth that's one larger than
  // the set of classes you're training on.  Also make sure that your
  // training labels do not have a class value of N-1, as training will skip
  // these examples.
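  // (For example, with an input depth of N = 5, the usable label values are
  // {0, 1, 2, 3}, and index 4 is the blank / "no transition" token.)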
  //
  // Reference materials:
  //  GravesTh: Alex Graves, "Supervised Sequence Labeling with Recurrent
  //    Neural Networks" (PhD Thesis), Technische Universität München.
 public:
  typedef std::vector<std::vector<int>> LabelSequences;
  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
  // typedef Eigen::MatrixXd Matrix;
  using Array = Eigen::Array<T, Eigen::Dynamic, 1>;
  // typedef Eigen::ArrayXd Array;
  using InputMap = Eigen::Map<const Matrix>;
  // typedef Eigen::Map<const Eigen::MatrixXd> InputMap;
  using OutputMap = Eigen::Map<Matrix>;
  // typedef Eigen::Map<Eigen::MatrixXd> OutputMap;

  CTCLossCalculator(int blank_index, int output_delay)
      : blank_index_(blank_index), output_delay_(output_delay) {}

  template <typename VectorIn, typename VectorOut, typename MatrixIn,
            typename MatrixOut>
  Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
                       const std::vector<MatrixIn>& inputs,
                       bool preprocess_collapse_repeated,
                       bool ctc_merge_repeated,
                       bool ignore_longer_outputs_than_inputs, VectorOut* loss,
                       std::vector<MatrixOut>* gradients,
                       DeviceBase::CpuWorkerThreads* workers = nullptr) const;

 private:
  void CalculateForwardVariables(const std::vector<int>& l_prime,
                                 const Matrix& y, bool ctc_merge_repeated,
                                 Matrix* log_alpha) const;

  void CalculateBackwardVariables(const std::vector<int>& l_prime,
                                  const Matrix& y, bool ctc_merge_repeated,
                                  Matrix* log_beta) const;

  void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
                         const Matrix& log_alpha, const Matrix& log_beta,
                         T log_p_z_x, Matrix* dy) const;

  void GetLPrimeIndices(const std::vector<int>& l,
                        std::vector<int>* l_prime) const;

  // Helper function that calculates the l_prime indices for all
  // batches at the same time, and identifies errors for any given
  // batch.  Return value:
  //    max_{b in batch_size} l_primes[b].size()
  template <typename Vector>
  Status PopulateLPrimes(bool preprocess_collapse_repeated,
                         bool ignore_longer_outputs_than_inputs, int batch_size,
                         int num_classes, const Vector& seq_len,
                         const LabelSequences& labels, size_t* max_u_prime,
                         LabelSequences* l_primes) const;

  // Utility indices for the CTC algorithm.
  int blank_index_;

  // Delay for target labels in time steps.
  // The delay in time steps before the output sequence.
  const int output_delay_;
};
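
// A minimal usage sketch (illustrative only: the sizes, the concrete Eigen
// vector types, and the choice of blank index below are assumptions, not
// mandated by this header):
//
//   const int batch_size = 2, num_classes = 5, num_time_steps = 10;
//   using Calc = CTCLossCalculator<float>;
//   // One (batch_size x num_classes) activation matrix per time step.
//   std::vector<Calc::Matrix> inputs(
//       num_time_steps, Calc::Matrix::Zero(batch_size, num_classes));
//   Calc::LabelSequences labels(batch_size);
//   labels[0] = {0, 1};
//   labels[1] = {2, 3};  // Values in [0, num_classes - 2]; blank is implicit.
//   Eigen::VectorXi seq_len = Eigen::VectorXi::Constant(batch_size, 8);
//   Eigen::VectorXf loss(batch_size);
//   std::vector<Calc::Matrix> gradients(
//       num_time_steps, Calc::Matrix::Zero(batch_size, num_classes));
//
//   Calc calc(/*blank_index=*/num_classes - 1, /*output_delay=*/0);
//   Status s = calc.CalculateLoss(seq_len, labels, inputs,
//                                 /*preprocess_collapse_repeated=*/false,
//                                 /*ctc_merge_repeated=*/true,
//                                 /*ignore_longer_outputs_than_inputs=*/false,
//                                 &loss, &gradients);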

template <class T>
template <typename VectorIn, typename VectorOut, typename MatrixIn,
          typename MatrixOut>
Status CTCLossCalculator<T>::CalculateLoss(
    const VectorIn& seq_len, const LabelSequences& labels,
    const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
    bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
    VectorOut* loss, std::vector<MatrixOut>* gradients,
    DeviceBase::CpuWorkerThreads* workers) const {
  using Eigen::numext::log;

  auto num_time_steps = inputs.size();

  if (loss == nullptr) {
    return errors::InvalidArgument("loss == nullptr");
  }

  bool requires_backprop = (gradients != nullptr);

  auto batch_size = inputs[0].rows();
  auto num_classes = inputs[0].cols();

  if (loss->size() != batch_size) {
    return errors::InvalidArgument("loss.size() != batch_size");
  }
  loss->setZero();

  for (int t = 1; t < num_time_steps; ++t) {
    if (inputs[t].rows() != batch_size) {
      return errors::InvalidArgument("Expected batch size at t: ", t,
                                     " to be: ", batch_size,
                                     " but got: ", inputs[t].rows());
    }
    if (inputs[t].cols() != num_classes) {
      return errors::InvalidArgument("Expected class count at t: ", t,
                                     " to be: ", num_classes,
                                     " but got: ", inputs[t].cols());
    }
  }

  // Check validity of sequence_length array values.
  auto max_seq_len = seq_len(0);
  for (int b = 0; b < batch_size; b++) {
    if (seq_len(b) < 0) {
      return errors::InvalidArgument("seq_len(", b, ") < 0");
    }
    if (seq_len(b) > num_time_steps) {
      return errors::InvalidArgument("seq_len(", b, ") > num_time_steps");
    }
    max_seq_len = std::max(seq_len(b), max_seq_len);
  }

  // Calculate the modified label sequence l' for each batch element,
  // and calculate the maximum necessary allocation size.
  LabelSequences l_primes(batch_size);
  size_t max_u_prime = 0;
  Status l_p_ret = PopulateLPrimes(
      preprocess_collapse_repeated, ignore_longer_outputs_than_inputs,
      batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes);
  if (!l_p_ret.ok()) {
    return l_p_ret;
  }

  // Process each item in a batch in parallel, using at most kMaxThreads.
  auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes,
                                  &seq_len, &inputs, requires_backprop,
                                  ctc_merge_repeated,
                                  ignore_longer_outputs_than_inputs, &loss,
                                  &gradients](int64_t start_row,
                                              int64_t limit_row) {
    for (int b = start_row; b < limit_row; b++) {
      // Return zero gradient for empty sequences or sequences with labels
      // longer than input, which is not supported by CTC.
      if (seq_len(b) == 0 ||
          (ignore_longer_outputs_than_inputs &&
           labels[b].size() > seq_len(b) - this->output_delay_)) {
        VLOG(1) << "The sequence length is either zero or shorter than the "
                   "target output (CTC works only when the target sequence is "
                   "no longer than the input sequence). You can turn this "
                   "error into a warning by using the flag "
                   "ignore_longer_outputs_than_inputs - "
                << b << ": " << str_util::Join(labels[b], " ");
        continue;
      }

      // For each batch element, log(alpha) and log(beta).
      //   row size is: u_prime == l_prime.size()
      //   col size is: seq_len[b] - output_delay_
      const std::vector<int>& l_prime = l_primes[b];

      Matrix log_alpha_b(l_prime.size(), seq_len(b) - this->output_delay_);
      Matrix log_beta_b(l_prime.size(), seq_len(b) - this->output_delay_);

      // Work matrices, pre-allocated to the size required by this batch item.
      Matrix y(num_classes, seq_len(b));
      Matrix dy;
      if (requires_backprop) {
        dy = Matrix::Zero(y.rows(), y.cols());
      }

      // For this batch, we'll only work with this shortened sequence_length.
      Matrix y_b = y.leftCols(seq_len(b));

      // Convert label from DistBelief
      // y, prob are in num_classes x seq_len(b)
      // Output activations.
      Array y_b_col;
      for (int t = 0; t < seq_len(b); t++) {
        // Calculate the softmax of y_b.  Use original precision
        // arithmetic for the sum.
        T max_coeff = inputs[t].row(b).maxCoeff();
        y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
        y_b.col(t) = y_b_col / y_b_col.sum();
      }

      // Compute forward, backward.
      // Forward variables.
      CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b);
      // Backward variables.
      CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b);

      // The loss is the negative log probability of the target labeling given
      // the input, -log(p(z|x)). Do lazy evaluation of log_prob here.
      T log_p_z_x = kLogZero<T>();
      for (int u = 0; u < l_prime.size(); ++u) {
        // (GravesTh) Eq 7.26, sum over all paths for t = 0.
        log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
      }

      (*loss)(b) = -log_p_z_x;  // Use negative log loss for display.

      // We compute the derivative if needed.
      if (requires_backprop) {
        // Gradients with respect to input activations.
        // Calculate gradient.
        dy.setZero();
        CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x,
                          &dy);

        // Convert gradient for current sample to DistBelief.
        for (int t = 0; t < seq_len(b); t++) {
          (*gradients)[t].row(b).array() = dy.col(t);
        }
      }
    }  // for (int b = ...
  };
  if (workers) {
    // *Rough* estimate of the cost for one item in the batch.
    // Forward, Backward: O(T * U (= 2L + 1)), Gradients: O(T * (U + L)).
    //
    // softmax: T * L * (Cost(Exp) + Cost(Div)) +
    // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
    // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add))).
    const int64_t cost_exp = Eigen::internal::functor_traits<
        Eigen::internal::scalar_exp_op<T>>::Cost;
    const int64_t cost_log = Eigen::internal::functor_traits<
        Eigen::internal::scalar_log_op<T>>::Cost;
    const int64_t cost_log_sum_exp =
        Eigen::TensorOpCost::AddCost<T>() + cost_exp + cost_log;
    const int64_t cost =
        max_seq_len * num_classes *
            (cost_exp + Eigen::TensorOpCost::DivCost<T>()) +
        max_seq_len * 2 * (2 * num_classes + 1) *
            (cost_log_sum_exp + cost_log) +
        max_seq_len *
            ((2 * num_classes + 1) * cost_log_sum_exp +
             num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<T>()));
    Shard(workers->num_threads, workers->workers, batch_size, cost,
          ComputeLossAndGradients);
  } else {
    ComputeLossAndGradients(0, batch_size);
  }
  return OkStatus();
}

template <class T>
template <typename Vector>
Status CTCLossCalculator<T>::PopulateLPrimes(
    bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
    int batch_size, int num_classes, const Vector& seq_len,
    const LabelSequences& labels, size_t* max_u_prime,
    LabelSequences* l_primes) const {
  // labels is a Label array of size batch_size
  if (labels.size() != batch_size) {
    return errors::InvalidArgument(
        "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size);
  }

  *max_u_prime = 0;  // keep track of longest l' modified label sequence.
  for (int b = 0; b < batch_size; b++) {
    // Assume label is in Label proto
    const std::vector<int>& label = labels[b];
    if (label.size() == 0) {
      return errors::InvalidArgument("Labels length is zero in batch ", b);
    }

    // If debugging: output the labels coming into training.
    //
    VLOG(2) << "label for batch: " << b << ": " << str_util::Join(label, " ");

    // Target indices, length = U.
    std::vector<int> l;

    // Convert label from DistBelief
    bool finished_sequence = false;
    for (int i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) {
        if (label[i] >= num_classes - 1) {
          finished_sequence = true;
        } else {
          if (finished_sequence) {
            // Saw an invalid sequence with non-null following null
            // labels.
            return errors::InvalidArgument(
                "Saw a non-null label following a null label "
                "(index >= num_classes - 1), batch: ", b,
                " num_classes: ", num_classes,
                " labels: ", str_util::Join(label, ","),
                " labels seen so far: ", str_util::Join(l, ","));
          }
          l.push_back(label[i]);
        }
      }
    }

    for (int l_i : l) {
      if (l_i < 0) {
        return errors::InvalidArgument(
            "All labels must be nonnegative integers, batch: ", b,
            " labels: ", str_util::Join(l, ","));
      } else if (l_i >= num_classes) {
        return errors::InvalidArgument(
            "No label may be greater than or equal to num_classes. ",
            "num_classes: ", num_classes, ", batch: ", b,
            " labels: ", str_util::Join(l, ","));
      }
    }
    if (!ignore_longer_outputs_than_inputs) {
      // Make sure there is enough time to output the target indices.
      int time = seq_len(b) - output_delay_;
      int required_time = label.size();
      if (required_time > time) {
        return errors::InvalidArgument(
            "Not enough time for target transition sequence ("
            "required: ",
            required_time, ", available: ", time, "), batch: ", b,
            ". You can turn this error into a warning by using the flag "
            "ignore_longer_outputs_than_inputs");
      }
    }
    // Target indices with blanks before each index and a blank at the end.
    // Length U' = 2U + 1.
    // Convert l to l_prime
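    // For example, l = {1, 2, 2} becomes
    // l' = {blank, 1, blank, 2, blank, 2, blank}.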
    GetLPrimeIndices(l, &l_primes->at(b));
    *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size());
  }
  return OkStatus();
}

// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
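// In log space, the recursion implemented below is (loosely following
// (GravesTh) Eq 7.9 with this file's 0-based time index):
//   log_alpha(u, t) = log y(l'_u, t + output_delay_) + LogSumExp of
//     log_alpha(u, t - 1)      [kept only if ctc_merge_repeated or
//                               l'_u is the blank]
//     log_alpha(u - 1, t - 1)  [if u > 0]
//     log_alpha(u - 2, t - 1)  [if u > 1, l'_u is not the blank, and not
//                               (ctc_merge_repeated and l'_u == l'_{u-2})]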
template <typename TT>
void CTCLossCalculator<TT>::CalculateForwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_alpha) const {
  using Eigen::numext::log;

  // Number of cols is the number of time steps = number of cols in target
  // after the output delay.
  log_alpha->setConstant(kLogZero<TT>());

  int U = l_prime.size();
  int T = log_alpha->cols();

  CHECK_EQ(U, log_alpha->rows());

  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
  log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
  // Below, l_prime[1] == labels[0]
  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
  log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));

  for (int t = 1; t < T; ++t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_alpha(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.9
      // Add in the u, t - 1 term.
      auto sum_log_alpha = kLogZero<TT>();
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        sum_log_alpha = log_alpha->coeff(u, t - 1);
      }

      // Add in the u - 1, t - 1 term.
      if (u > 0) {
        sum_log_alpha =
            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
      }

      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
      if (u > 1) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha =
              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
        }
      }
      // Multiply the summed alphas with the activation log probability.
      log_alpha->coeffRef(u, t) =
          log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
    }  // End (GravesTh) Eq 7.9.
  }
}

// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
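// In log space, the recursion implemented below is (loosely following
// (GravesTh) Eq 7.15 with this file's 0-based time index):
//   log_beta(u, t) = LogSumExp of
//     log_beta(u, t + 1)     + log y(l'_u, t + 1 + output_delay_)
//                              [kept only if ctc_merge_repeated or
//                               l'_u is the blank]
//     log_beta(u + 1, t + 1) + log y(l'_{u+1}, t + 1 + output_delay_)
//                              [if u + 1 < U']
//     log_beta(u + 2, t + 1) + log y(l'_{u+2}, t + 1 + output_delay_)
//                              [if u + 2 < U', l'_u is not the blank, and not
//                               (ctc_merge_repeated and l'_u == l'_{u+2})]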
template <class TT>
void CTCLossCalculator<TT>::CalculateBackwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_beta) const {
  // Number of cols is the number of time steps = number of cols in target.
  // Matrix log_beta =
  //    Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
  // kLogZero);
  using Eigen::numext::log;

  log_beta->setConstant(kLogZero<TT>());
  int T = log_beta->cols();
  int U = l_prime.size();
  CHECK_EQ(U, log_beta->rows());

  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;

  for (int t = T - 1 - 1; t >= 0; --t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_beta(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.15
      // Add in the u, t + 1 term.
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u, t + 1) +
                          log(y(l_prime[u], output_delay_ + t + 1)));
      }

      // Add in the u + 1, t + 1 term.
      if (u + 1 < U) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u + 1, t + 1) +
                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
      }

      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
      if (u + 2 < U) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          // Add in u + 2 term.
          log_beta->coeffRef(u, t) =
              LogSumExp(log_beta->coeff(u, t),
                        log_beta->coeff(u + 2, t + 1) +
                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
        }
      }  // End (GravesTh) Eq. 7.15
    }
  }
}

// Using (GravesTh) Eq 7.26 & 7.34.
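// In probability space, the update implemented below is (loosely following
// (GravesTh) Eq 7.28):
//   dy(l, t) = y(l, t) - (1 / p(z|x)) *
//              sum over {u : l'_u == l} of alpha(u, t) * beta(u, t),
// i.e. the gradient of the loss with respect to the pre-softmax activations;
// the sum is accumulated in log space and exponentiated at the end.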
template <typename TT>
void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
                                              const Matrix& y,
                                              const Matrix& log_alpha,
                                              const Matrix& log_beta,
                                              TT log_p_z_x, Matrix* dy) const {
  // Only working with the leftmost part of dy for this batch element.
  auto dy_b = dy->leftCols(y.cols());

  // It is possible that no valid path is found if the activations for the
  // targets are zero.
  if (log_p_z_x == kLogZero<TT>()) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
  }

  int L = y.rows();
  int T = y.cols();
  int U = l_prime.size();

  for (int t = 0; t < T - output_delay_; ++t) {
    Array prob_sum(L);
    prob_sum.setConstant(kLogZero<TT>());

    for (int u = 0; u < U; ++u) {
      int l = l_prime[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
    }

    for (int l = 0; l < L; ++l) {
      // Negative term in (GravesTh) Eq 7.28.
      auto negative_term = expf(prob_sum[l] - log_p_z_x);

      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
    }
  }
}

template <class TT>
void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
                                             std::vector<int>* l_prime) const {
  // Assumption is that l_prime is empty.
  l_prime->reserve(2 * l.size() + 1);

  for (auto label : l) {
    l_prime->push_back(blank_index_);
    l_prime->push_back(label);
  }
  // Add final blank to l'.
  l_prime->push_back(blank_index_);
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_