/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_

#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace ctc {

template <class T>
class CTCLossCalculator {
  // Connectionist Temporal Classification Loss
  //
  // Implementation by kanishkarao@, posenhuang@, and ebrevdo@.
  //
  // The CTC Loss layer learns a *transition* probability value for each
  // input time step.  The transitions are on the class alphabet
  //   {0, 1, ..., N-2}
  // where N is the depth of the input layer (the size of the alphabet is N-1).
  // Note: The token N-1 is reserved for the "no transition" output, so
  // make sure that your input layer has a depth that's one larger than
  // the set of classes you're training on.  Also make sure that your
  // training labels do not have a class value of N-1, as training will skip
  // these examples.
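  //
  // For example (illustrative numbers only): with an input depth of N = 29,
  // classes 0..27 are usable training labels and class 28 (= N - 1) is the
  // reserved blank / "no transition" token.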
  //
  // Reference materials:
  //  GravesTh: Alex Graves, "Supervised Sequence Labelling with Recurrent
  //    Neural Networks" (PhD Thesis), Technische Universität München.
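  //
  // A minimal usage sketch (illustrative only; `activations`, `label_batch`,
  // and `seq_lengths` are hypothetical caller-side names, and plain Eigen
  // types are assumed to satisfy the template parameters):
  //
  //   // activations[t]: batch_size x num_classes matrix of scores at step t.
  //   // label_batch:    one std::vector<int> of class ids per batch element.
  //   // seq_lengths:    Eigen::VectorXi of per-element sequence lengths.
  //   CTCLossCalculator<float> ctc(num_classes - 1, /*output_delay=*/0);
  //   Eigen::VectorXf loss(batch_size);
  //   std::vector<Eigen::MatrixXf> grads(
  //       num_time_steps, Eigen::MatrixXf::Zero(batch_size, num_classes));
  //   TF_RETURN_IF_ERROR(ctc.CalculateLoss(
  //       seq_lengths, label_batch, activations,
  //       /*preprocess_collapse_repeated=*/false,
  //       /*ctc_merge_repeated=*/true,
  //       /*ignore_longer_outputs_than_inputs=*/false, &loss, &grads));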
 public:
  typedef std::vector<std::vector<int>> LabelSequences;
  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
  // typedef Eigen::MatrixXd Matrix;
  using Array = Eigen::Array<T, Eigen::Dynamic, 1>;
  // typedef Eigen::ArrayXd Array;
  using InputMap = Eigen::Map<const Matrix>;
  // typedef Eigen::Map<const Eigen::MatrixXd> InputMap;
  using OutputMap = Eigen::Map<Matrix>;
  // typedef Eigen::Map<Eigen::MatrixXd> OutputMap;

  CTCLossCalculator(int blank_index, int output_delay)
      : blank_index_(blank_index), output_delay_(output_delay) {}

  template <typename VectorIn, typename VectorOut, typename MatrixIn,
            typename MatrixOut>
  Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
                       const std::vector<MatrixIn>& inputs,
                       bool preprocess_collapse_repeated,
                       bool ctc_merge_repeated,
                       bool ignore_longer_outputs_than_inputs, VectorOut* loss,
                       std::vector<MatrixOut>* gradients,
                       DeviceBase::CpuWorkerThreads* workers = nullptr) const;

 private:
  void CalculateForwardVariables(const std::vector<int>& l_prime,
                                 const Matrix& y, bool ctc_merge_repeated,
                                 Matrix* log_alpha) const;

  void CalculateBackwardVariables(const std::vector<int>& l_prime,
                                  const Matrix& y, bool ctc_merge_repeated,
                                  Matrix* log_beta) const;

  void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
                         const Matrix& log_alpha, const Matrix& log_beta,
                         T log_p_z_x, Matrix* dy) const;

  void GetLPrimeIndices(const std::vector<int>& l,
                        std::vector<int>* l_prime) const;

  // Helper function that calculates the l_prime indices for all
  // batches at the same time, and identifies errors for any given
  // batch.  On success, sets *max_u_prime to:
  //    max_{b in batch_size} l_primes[b].size()
  template <typename Vector>
  Status PopulateLPrimes(bool preprocess_collapse_repeated,
                         bool ignore_longer_outputs_than_inputs, int batch_size,
                         int num_classes, const Vector& seq_len,
                         const LabelSequences& labels, size_t* max_u_prime,
                         LabelSequences* l_primes) const;

  // Utility indices for the CTC algorithm.
  int blank_index_;

  // Delay for target labels in time steps.
  // The delay in time steps before the output sequence.
  const int output_delay_;
};

template <class T>
template <typename VectorIn, typename VectorOut, typename MatrixIn,
          typename MatrixOut>
Status CTCLossCalculator<T>::CalculateLoss(
    const VectorIn& seq_len, const LabelSequences& labels,
    const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
    bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
    VectorOut* loss, std::vector<MatrixOut>* gradients,
    DeviceBase::CpuWorkerThreads* workers) const {
  using Eigen::numext::log;

  auto num_time_steps = inputs.size();

  if (loss == nullptr) {
    return errors::InvalidArgument("loss == nullptr");
  }

  bool requires_backprop = (gradients != nullptr);

  auto batch_size = inputs[0].rows();
  auto num_classes = inputs[0].cols();

  if (loss->size() != batch_size) {
    return errors::InvalidArgument("loss.size() != batch_size");
  }
  loss->setZero();

  for (int t = 1; t < num_time_steps; ++t) {
    if (inputs[t].rows() != batch_size) {
      return errors::InvalidArgument("Expected batch size at t: ", t,
                                     " to be: ", batch_size,
                                     " but got: ", inputs[t].rows());
    }
    if (inputs[t].cols() != num_classes) {
      return errors::InvalidArgument("Expected class count at t: ", t,
                                     " to be: ", num_classes,
                                     " but got: ", inputs[t].cols());
    }
  }

  // Check validity of sequence_length array values.
  auto max_seq_len = seq_len(0);
  for (int b = 0; b < batch_size; b++) {
    if (seq_len(b) < 0) {
      return errors::InvalidArgument("seq_len(", b, ") < 0");
    }
    if (seq_len(b) > num_time_steps) {
      return errors::InvalidArgument("seq_len(", b, ") > num_time_steps");
    }
    max_seq_len = std::max(seq_len(b), max_seq_len);
  }

  // Calculate the modified label sequence l' for each batch element,
  // and calculate the maximum necessary allocation size.
  LabelSequences l_primes(batch_size);
  size_t max_u_prime = 0;
  Status l_p_ret = PopulateLPrimes(
      preprocess_collapse_repeated, ignore_longer_outputs_than_inputs,
      batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes);
  if (!l_p_ret.ok()) {
    return l_p_ret;
  }

  // Process each item in a batch in parallel, using at most kMaxThreads.
  auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes,
                                  &seq_len, &inputs, requires_backprop,
                                  ctc_merge_repeated,
                                  ignore_longer_outputs_than_inputs, &loss,
                                  &gradients](int64_t start_row,
                                              int64_t limit_row) {
    for (int b = start_row; b < limit_row; b++) {
      // Return zero gradient for empty sequences or sequences with labels
      // longer than input, which is not supported by CTC.
      if (seq_len(b) == 0 ||
          (ignore_longer_outputs_than_inputs &&
           labels[b].size() > seq_len(b) - this->output_delay_)) {
        VLOG(1) << "The sequence length is either zero or shorter than the "
                   "target output (CTC works only with shorter target sequence "
                   "than input sequence). You can turn this into a warning by "
                   "using the flag ignore_longer_outputs_than_inputs - "
                << b << ": " << str_util::Join(labels[b], " ");
        continue;
      }

      // For each batch element, log(alpha) and log(beta).
      //   row size is: u_prime == l_prime.size()
      //   col size is: seq_len[b] - output_delay_
      const std::vector<int>& l_prime = l_primes[b];

      Matrix log_alpha_b(l_prime.size(), seq_len(b) - this->output_delay_);
      Matrix log_beta_b(l_prime.size(), seq_len(b) - this->output_delay_);

      // Work matrices, pre-allocated to the size required by this batch item.
      Matrix y(num_classes, seq_len(b));
      Matrix dy;
      if (requires_backprop) {
        dy = Matrix::Zero(y.rows(), y.cols());
      }

      // For this batch, we'll only work with this shortened sequence_length.
      Matrix y_b = y.leftCols(seq_len(b));

      // Convert label from DistBelief
      // y, prob are in num_classes x seq_len(b)
      // Output activations.
      Array y_b_col;
      for (int t = 0; t < seq_len(b); t++) {
        // Calculate the softmax of y_b.  Use original precision
        // arithmetic for the sum.
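        // Subtracting the row max before exponentiating keeps exp() from
        // overflowing; softmax is invariant to this shift, so the result is
        // unchanged.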
        T max_coeff = inputs[t].row(b).maxCoeff();
        y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
        y_b.col(t) = y_b_col / y_b_col.sum();
      }

      // Compute forward, backward.
      // Forward variables.
      CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b);
      // Backward variables.
      CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b);

      // The loss is computed as the log(p(z|x)) between the target and
      // prediction. Do lazy evaluation of log_prob here.
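      // kLogZero<T>() plays the role of log(0) here, and LogSumExp accumulates
      // the per-state path probabilities in log space to avoid underflow.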
      T log_p_z_x = kLogZero<T>();
      for (int u = 0; u < l_prime.size(); ++u) {
        // (GravesTh) Eq 7.26, sum over all paths for t = 0.
        log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
      }

      (*loss)(b) = -log_p_z_x;  // Use negative log loss for display.

      // We compute the derivative if needed.
      if (requires_backprop) {
        // Gradients with respect to input activations.
        // Calculate gradient.
        dy.setZero();
        CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x,
                          &dy);

        // Convert gradient for current sample to DistBelief.
        for (int t = 0; t < seq_len(b); t++) {
          (*gradients)[t].row(b).array() = dy.col(t);
        }
      }
    }  // for (int b = ...
  };
  if (workers) {
    // *Rough* estimate of the cost for one item in the batch.
    // Forward, Backward: O(T * U (= 2L + 1)), Gradients: O(T * (U + L)).
    //
    // softmax: T * L * (Cost(Exp) + Cost(Div)) +
    // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
    // grad:    T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add))).
    const int64_t cost_exp = Eigen::internal::functor_traits<
        Eigen::internal::scalar_exp_op<T>>::Cost;
    const int64_t cost_log = Eigen::internal::functor_traits<
        Eigen::internal::scalar_log_op<T>>::Cost;
    const int64_t cost_log_sum_exp =
        Eigen::TensorOpCost::AddCost<T>() + cost_exp + cost_log;
    const int64_t cost =
        max_seq_len * num_classes *
            (cost_exp + Eigen::TensorOpCost::DivCost<T>()) +
        max_seq_len * 2 * (2 * num_classes + 1) *
            (cost_log_sum_exp + cost_log) +
        max_seq_len *
            ((2 * num_classes + 1) * cost_log_sum_exp +
             num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<T>()));
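    // Shard splits the batch across the available worker threads; the per-item
    // cost estimate above guides how finely the [0, batch_size) range is
    // partitioned.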
    Shard(workers->num_threads, workers->workers, batch_size, cost,
          ComputeLossAndGradients);
  } else {
    ComputeLossAndGradients(0, batch_size);
  }
  return OkStatus();
}

template <class T>
template <typename Vector>
Status CTCLossCalculator<T>::PopulateLPrimes(
    bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
    int batch_size, int num_classes, const Vector& seq_len,
    const LabelSequences& labels, size_t* max_u_prime,
    LabelSequences* l_primes) const {
  // labels is a Label array of size batch_size
  if (labels.size() != batch_size) {
    return errors::InvalidArgument(
        "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size);
  }

  *max_u_prime = 0;  // keep track of longest l' modified label sequence.
  for (int b = 0; b < batch_size; b++) {
    // Assume label is in Label proto
    const std::vector<int>& label = labels[b];
    if (label.size() == 0) {
      return errors::InvalidArgument("Labels length is zero in batch ", b);
    }

    // If debugging: output the labels coming into training.
    //
    VLOG(2) << "label for batch: " << b << ": " << str_util::Join(label, " ");

    // Target indices, length = U.
    std::vector<int> l;

    // Convert label from DistBelief
    bool finished_sequence = false;
    for (int i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) {
        if (label[i] >= num_classes - 1) {
          finished_sequence = true;
        } else {
          if (finished_sequence) {
            // Saw an invalid sequence with a non-null label following null
            // labels.
            return errors::InvalidArgument(
                "Saw a non-null label (index < num_classes - 1) "
                "following a null label, batch: ", b,
                " num_classes: ", num_classes,
                " labels: ", str_util::Join(label, ","),
                " labels seen so far: ", str_util::Join(l, ","));
          }
          l.push_back(label[i]);
        }
      }
    }

    for (int l_i : l) {
      if (l_i < 0) {
        return errors::InvalidArgument(
            "All labels must be nonnegative integers, batch: ", b,
            " labels: ", str_util::Join(l, ","));
      } else if (l_i >= num_classes) {
        return errors::InvalidArgument(
            "No label may be greater than num_classes. ",
            "num_classes: ", num_classes, ", batch: ", b,
            " labels: ", str_util::Join(l, ","));
      }
    }
    if (!ignore_longer_outputs_than_inputs) {
      // Make sure there is enough time to output the target indices.
      int time = seq_len(b) - output_delay_;
      int required_time = label.size();
      if (required_time > time) {
        return errors::InvalidArgument(
            "Not enough time for target transition sequence (required: ",
            required_time, ", available: ", time, "), batch: ", b,
            ". You can turn this error into a warning by using the flag "
            "ignore_longer_outputs_than_inputs");
      }
    }
    // Target indices with blanks before each index and a blank at the end.
    // Length U' = 2U + 1.
    // Convert l to l_prime.
    GetLPrimeIndices(l, &l_primes->at(b));
    *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size());
  }
  return OkStatus();
}

// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
template <typename TT>
void CTCLossCalculator<TT>::CalculateForwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_alpha) const {
  using Eigen::numext::log;

  // Number of cols is the number of time steps = number of cols in target
  // after the output delay.
  log_alpha->setConstant(kLogZero<TT>());

  int U = l_prime.size();
  int T = log_alpha->cols();

  CHECK_EQ(U, log_alpha->rows());

  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
  log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
  // Below, l_prime[1] == labels[0].
  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
  log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
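  // All other log_alpha(u, 0) entries stay at kLogZero: at t = 0 a valid CTC
  // path can only be at the leading blank (u = 0) or at the first label
  // (u = 1).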

  for (int t = 1; t < T; ++t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_alpha(u, t) continue to
    // be kLogZero.
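    // Concretely: a path advances by at most two positions of l' per step, so
    // by time t it can have reached at most u < 2 * (t + 1), and it must still
    // satisfy u >= U - 2 * (T - t) to be able to finish l' in the remaining
    // steps.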
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.9
      // Add in the u, t - 1 term.
      auto sum_log_alpha = kLogZero<TT>();
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        sum_log_alpha = log_alpha->coeff(u, t - 1);
      }

      // Add in the u - 1, t - 1 term.
      if (u > 0) {
        sum_log_alpha =
            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
      }

      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
      if (u > 1) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha =
              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
        }
      }
      // Multiply the summed alphas with the activation log probability.
      log_alpha->coeffRef(u, t) =
          log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
    }  // End (GravesTh) Eq 7.9.
  }
}

// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
template <class TT>
void CTCLossCalculator<TT>::CalculateBackwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_beta) const {
  // Number of cols is the number of time steps = number of cols in target.
  // Matrix log_beta =
  //    Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
  //                     kLogZero);
  using Eigen::numext::log;

  log_beta->setConstant(kLogZero<TT>());
  int T = log_beta->cols();
  int U = l_prime.size();
  CHECK_EQ(U, log_beta->rows());

  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
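  // Both of these positions (the final label, u = U - 2, and the trailing
  // blank, u = U - 1) are valid end points of a complete path, hence
  // log-probability 0.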

  for (int t = T - 1 - 1; t >= 0; --t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_beta(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.15
      // Add in the u, t + 1 term.
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u, t + 1) +
                          log(y(l_prime[u], output_delay_ + t + 1)));
      }

      // Add in the u + 1, t + 1 term.
      if (u + 1 < U) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u + 1, t + 1) +
                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
      }

      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
      if (u + 2 < U) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          // Add in u + 2 term.
          log_beta->coeffRef(u, t) =
              LogSumExp(log_beta->coeff(u, t),
                        log_beta->coeff(u + 2, t + 1) +
                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
        }
      }  // End (GravesTh) Eq. 7.15
    }
  }
}

// Using (GravesTh) Eq 7.26 & 7.34.
template <typename TT>
void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
                                              const Matrix& y,
                                              const Matrix& log_alpha,
                                              const Matrix& log_beta,
                                              TT log_p_z_x, Matrix* dy) const {
  // Only working with the leftmost part of dy for this batch element.
  auto dy_b = dy->leftCols(y.cols());

  // It is possible that no valid path is found if the activations for the
  // targets are zero.
  if (log_p_z_x == kLogZero<TT>()) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
  }

  int L = y.rows();
  int T = y.cols();
  int U = l_prime.size();

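  // What follows is (GravesTh) Eq 7.34: for each time step t,
  //   dy(l, t) = y(l, t)
  //              - (1 / p(z|x)) * sum_{u : l'[u] == l} alpha(u, t) * beta(u, t),
  // i.e. the gradient of -log p(z|x) with respect to the unnormalized
  // (pre-softmax) input activations.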
  for (int t = 0; t < T - output_delay_; ++t) {
    Array prob_sum(L);
    prob_sum.setConstant(kLogZero<TT>());

    for (int u = 0; u < U; ++u) {
      int l = l_prime[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
    }

    for (int l = 0; l < L; ++l) {
      // Negative term in (GravesTh) Eq 7.28.
      auto negative_term = expf(prob_sum[l] - log_p_z_x);

      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
    }
  }
}

template <class TT>
void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
                                             std::vector<int>* l_prime) const {
  // Assumption is that l_prime is empty.
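  // For example, l = {1, 7, 7} with blank_index_ = 0 becomes
  // l' = {0, 1, 0, 7, 0, 7, 0}.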
  l_prime->reserve(2 * l.size() + 1);

  for (auto label : l) {
    l_prime->push_back(blank_index_);
    l_prime->push_back(label);
  }
  // Add final blank to l'.
  l_prime->push_back(blank_index_);
}

}  // namespace ctc
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_