/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// LINT.IfChange

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ctc/ctc_loss_util.h"

namespace tensorflow {
namespace ctc {

// The ctc_beam_search namespace holds several classes meant to be accessed only
// in case of extending the CTCBeamSearch decoder to allow custom scoring
// functions.
//
// BeamEntry is exposed through template arguments BeamScorer and BeamComparer
// of CTCBeamSearch (ctc_beam_search.h).
40 namespace ctc_beam_search { 41 42 struct EmptyBeamState {}; 43 44 template <typename T> 45 struct BeamProbability { BeamProbabilityBeamProbability46 BeamProbability() 47 : total(kLogZero<T>()), blank(kLogZero<T>()), label(kLogZero<T>()) {} ResetBeamProbability48 void Reset() { 49 total = kLogZero<T>(); 50 blank = kLogZero<T>(); 51 label = kLogZero<T>(); 52 } 53 T total; 54 T blank; 55 T label; 56 }; 57 58 template <class T, class CTCBeamState> 59 class BeamRoot; 60 61 template <class T, class CTCBeamState = EmptyBeamState> 62 struct BeamEntry { 63 // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method. 64 friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry( 65 BeamEntry<T, CTCBeamState>* p, int l); ActiveBeamEntry66 inline bool Active() const { return newp.total != kLogZero<T>(); } 67 // Return the child at the given index, or construct a new one in-place if 68 // none was found. GetChildBeamEntry69 BeamEntry<T, CTCBeamState>& GetChild(int ind) { 70 auto entry = children.emplace(ind, nullptr); 71 auto& child_entry = entry.first->second; 72 // If this is a new child, populate the BeamEntry<CTCBeamState>*. 73 if (entry.second) { 74 child_entry = beam_root->AddEntry(this, ind); 75 } 76 return *child_entry; 77 } LabelSeqBeamEntry78 std::vector<int> LabelSeq(bool merge_repeated) const { 79 std::vector<int> labels; 80 int prev_label = -1; 81 const BeamEntry<T, CTCBeamState>* c = this; 82 while (c->parent != nullptr) { // Checking c->parent to skip root leaf. 83 if (!merge_repeated || c->label != prev_label) { 84 labels.push_back(c->label); 85 } 86 prev_label = c->label; 87 c = c->parent; 88 } 89 std::reverse(labels.begin(), labels.end()); 90 return labels; 91 } 92 93 BeamEntry<T, CTCBeamState>* parent; 94 int label; 95 // All instances of child BeamEntry are owned by *beam_root. 
96 gtl::FlatMap<int, BeamEntry<T, CTCBeamState>*> children; 97 BeamProbability<T> oldp; 98 BeamProbability<T> newp; 99 CTCBeamState state; 100 101 private: 102 // Constructor giving parent, label, and the beam_root. 103 // The object pointed to by p cannot be copied and should not be moved, 104 // otherwise parent will become invalid. 105 // This private constructor is only called through the factory method 106 // BeamRoot<CTCBeamState>::AddEntry(). BeamEntryBeamEntry107 BeamEntry(BeamEntry* p, int l, BeamRoot<T, CTCBeamState>* beam_root) 108 : parent(p), label(l), beam_root(beam_root) {} 109 BeamRoot<T, CTCBeamState>* beam_root; 110 TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); 111 }; 112 113 // This class owns all instances of BeamEntry. This is used to avoid recursive 114 // destructor call during destruction. 115 template <class T, class CTCBeamState = EmptyBeamState> 116 class BeamRoot { 117 public: BeamRoot(BeamEntry<T,CTCBeamState> * p,int l)118 BeamRoot(BeamEntry<T, CTCBeamState>* p, int l) { 119 root_entry_ = AddEntry(p, l); 120 } 121 BeamRoot(const BeamRoot&) = delete; 122 BeamRoot& operator=(const BeamRoot&) = delete; 123 AddEntry(BeamEntry<T,CTCBeamState> * p,int l)124 BeamEntry<T, CTCBeamState>* AddEntry(BeamEntry<T, CTCBeamState>* p, int l) { 125 auto* new_entry = new BeamEntry<T, CTCBeamState>(p, l, this); 126 beam_entries_.emplace_back(new_entry); 127 return new_entry; 128 } RootEntry()129 BeamEntry<T, CTCBeamState>* RootEntry() const { return root_entry_; } 130 131 private: 132 BeamEntry<T, CTCBeamState>* root_entry_ = nullptr; 133 std::vector<std::unique_ptr<BeamEntry<T, CTCBeamState>>> beam_entries_; 134 }; 135 136 // BeamComparer is the default beam comparer provided in CTCBeamSearch. 
137 template <class T, class CTCBeamState = EmptyBeamState> 138 class BeamComparer { 139 public: ~BeamComparer()140 virtual ~BeamComparer() {} operator()141 virtual bool inline operator()(const BeamEntry<T, CTCBeamState>* a, 142 const BeamEntry<T, CTCBeamState>* b) const { 143 return a->newp.total > b->newp.total; 144 } 145 }; 146 147 } // namespace ctc_beam_search 148 149 } // namespace ctc 150 } // namespace tensorflow 151 152 #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ 153 // LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_beam_entry.h) 154