xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/ctc/ctc_beam_entry.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // LINT.IfChange
16 
17 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
18 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <vector>
23 
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/lib/gtl/flatmap.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
30 
31 namespace tensorflow {
32 namespace ctc {
33 
34 // The ctc_beam_search namespace holds several classes meant to be accessed only
35 // in case of extending the CTCBeamSearch decoder to allow custom scoring
36 // functions.
37 //
38 // BeamEntry is exposed through template arguments BeamScorer and BeamComparer
39 // of CTCBeamSearch (ctc_beam_search.h).
40 namespace ctc_beam_search {
41 
// Stateless placeholder type. It is the default CTCBeamState argument of
// BeamEntry, BeamRoot and BeamComparer, for decoders that do not attach any
// custom per-beam state.
struct EmptyBeamState {
};
43 
44 template <typename T>
45 struct BeamProbability {
BeamProbabilityBeamProbability46   BeamProbability()
47       : total(kLogZero<T>()), blank(kLogZero<T>()), label(kLogZero<T>()) {}
ResetBeamProbability48   void Reset() {
49     total = kLogZero<T>();
50     blank = kLogZero<T>();
51     label = kLogZero<T>();
52   }
53   T total;
54   T blank;
55   T label;
56 };
57 
58 template <class T, class CTCBeamState>
59 class BeamRoot;
60 
61 template <class T, class CTCBeamState = EmptyBeamState>
62 struct BeamEntry {
63   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
64   friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
65       BeamEntry<T, CTCBeamState>* p, int l);
ActiveBeamEntry66   inline bool Active() const { return newp.total != kLogZero<T>(); }
67   // Return the child at the given index, or construct a new one in-place if
68   // none was found.
GetChildBeamEntry69   BeamEntry<T, CTCBeamState>& GetChild(int ind) {
70     auto entry = children.emplace(ind, nullptr);
71     auto& child_entry = entry.first->second;
72     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
73     if (entry.second) {
74       child_entry = beam_root->AddEntry(this, ind);
75     }
76     return *child_entry;
77   }
LabelSeqBeamEntry78   std::vector<int> LabelSeq(bool merge_repeated) const {
79     std::vector<int> labels;
80     int prev_label = -1;
81     const BeamEntry<T, CTCBeamState>* c = this;
82     while (c->parent != nullptr) {  // Checking c->parent to skip root leaf.
83       if (!merge_repeated || c->label != prev_label) {
84         labels.push_back(c->label);
85       }
86       prev_label = c->label;
87       c = c->parent;
88     }
89     std::reverse(labels.begin(), labels.end());
90     return labels;
91   }
92 
93   BeamEntry<T, CTCBeamState>* parent;
94   int label;
95   // All instances of child BeamEntry are owned by *beam_root.
96   gtl::FlatMap<int, BeamEntry<T, CTCBeamState>*> children;
97   BeamProbability<T> oldp;
98   BeamProbability<T> newp;
99   CTCBeamState state;
100 
101  private:
102   // Constructor giving parent, label, and the beam_root.
103   // The object pointed to by p cannot be copied and should not be moved,
104   // otherwise parent will become invalid.
105   // This private constructor is only called through the factory method
106   // BeamRoot<CTCBeamState>::AddEntry().
BeamEntryBeamEntry107   BeamEntry(BeamEntry* p, int l, BeamRoot<T, CTCBeamState>* beam_root)
108       : parent(p), label(l), beam_root(beam_root) {}
109   BeamRoot<T, CTCBeamState>* beam_root;
110   TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
111 };
112 
113 // This class owns all instances of BeamEntry.  This is used to avoid recursive
114 // destructor call during destruction.
115 template <class T, class CTCBeamState = EmptyBeamState>
116 class BeamRoot {
117  public:
BeamRoot(BeamEntry<T,CTCBeamState> * p,int l)118   BeamRoot(BeamEntry<T, CTCBeamState>* p, int l) {
119     root_entry_ = AddEntry(p, l);
120   }
121   BeamRoot(const BeamRoot&) = delete;
122   BeamRoot& operator=(const BeamRoot&) = delete;
123 
AddEntry(BeamEntry<T,CTCBeamState> * p,int l)124   BeamEntry<T, CTCBeamState>* AddEntry(BeamEntry<T, CTCBeamState>* p, int l) {
125     auto* new_entry = new BeamEntry<T, CTCBeamState>(p, l, this);
126     beam_entries_.emplace_back(new_entry);
127     return new_entry;
128   }
RootEntry()129   BeamEntry<T, CTCBeamState>* RootEntry() const { return root_entry_; }
130 
131  private:
132   BeamEntry<T, CTCBeamState>* root_entry_ = nullptr;
133   std::vector<std::unique_ptr<BeamEntry<T, CTCBeamState>>> beam_entries_;
134 };
135 
136 // BeamComparer is the default beam comparer provided in CTCBeamSearch.
137 template <class T, class CTCBeamState = EmptyBeamState>
138 class BeamComparer {
139  public:
~BeamComparer()140   virtual ~BeamComparer() {}
operator()141   virtual bool inline operator()(const BeamEntry<T, CTCBeamState>* a,
142                                  const BeamEntry<T, CTCBeamState>* b) const {
143     return a->newp.total > b->newp.total;
144   }
145 };
146 
147 }  // namespace ctc_beam_search
148 
149 }  // namespace ctc
150 }  // namespace tensorflow
151 
152 #endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
153 // LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_beam_entry.h)
154