xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/monitoring/sampler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 
28 #include <memory>
29 
30 #include "tensorflow/core/framework/summary.pb.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/monitoring/metric_def.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 namespace monitoring {
38 
39 // SamplerCell which has a null implementation.
40 class SamplerCell {
41  public:
SamplerCell()42   SamplerCell() {}
~SamplerCell()43   ~SamplerCell() {}
44 
Add(double value)45   void Add(double value) {}
value()46   HistogramProto value() const { return HistogramProto(); }
47 
48  private:
49   TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
50 };
51 
52 // Buckets which has a null implementation.
53 class Buckets {
54  public:
55   Buckets() = default;
56   ~Buckets() = default;
57 
Explicit(std::initializer_list<double> bucket_limits)58   static std::unique_ptr<Buckets> Explicit(
59       std::initializer_list<double> bucket_limits) {
60     return std::unique_ptr<Buckets>(new Buckets());
61   }
62 
Exponential(double scale,double growth_factor,int bucket_count)63   static std::unique_ptr<Buckets> Exponential(double scale,
64                                               double growth_factor,
65                                               int bucket_count) {
66     return std::unique_ptr<Buckets>(new Buckets());
67   }
68 
explicit_bounds()69   const std::vector<double>& explicit_bounds() const {
70     return explicit_bounds_;
71   }
72 
73  private:
74   std::vector<double> explicit_bounds_;
75 
76   TF_DISALLOW_COPY_AND_ASSIGN(Buckets);
77 };
78 
79 // Sampler which has a null implementation.
80 template <int NumLabels>
81 class Sampler {
82  public:
~Sampler()83   ~Sampler() {}
84 
85   template <typename... MetricDefArgs>
New(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)86   static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
87                                       NumLabels>& metric_def,
88                       std::unique_ptr<Buckets> buckets) {
89     return new Sampler<NumLabels>(std::move(buckets));
90   }
91 
92   template <typename... Labels>
GetCell(const Labels &...labels)93   SamplerCell* GetCell(const Labels&... labels) {
94     return &default_sampler_cell_;
95   }
96 
GetStatus()97   Status GetStatus() { return Status::OK(); }
98 
99  private:
Sampler(std::unique_ptr<Buckets> buckets)100   Sampler(std::unique_ptr<Buckets> buckets) : buckets_(std::move(buckets)) {}
101 
102   SamplerCell default_sampler_cell_;
103   std::unique_ptr<Buckets> buckets_;
104 
105   TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
106 };
107 
108 }  // namespace monitoring
109 }  // namespace tensorflow
110 
111 #else  // IS_MOBILE_PLATFORM
112 
113 #include <float.h>
114 
115 #include <map>
116 
117 #include "tensorflow/core/framework/summary.pb.h"
118 #include "tensorflow/core/lib/core/status.h"
119 #include "tensorflow/core/lib/histogram/histogram.h"
120 #include "tensorflow/core/lib/monitoring/collection_registry.h"
121 #include "tensorflow/core/lib/monitoring/metric_def.h"
122 #include "tensorflow/core/platform/macros.h"
123 #include "tensorflow/core/platform/mutex.h"
124 #include "tensorflow/core/platform/thread_annotations.h"
125 
126 namespace tensorflow {
127 namespace monitoring {
128 
129 // SamplerCell stores each value of an Sampler.
130 //
131 // A cell can be passed off to a module which may repeatedly update it without
132 // needing further map-indexing computations. This improves both encapsulation
133 // (separate modules can own a cell each, without needing to know about the map
134 // to which both cells belong) and performance (since map indexing and
135 // associated locking are both avoided).
136 //
137 // This class is thread-safe.
138 class SamplerCell {
139  public:
SamplerCell(const std::vector<double> & bucket_limits)140   SamplerCell(const std::vector<double>& bucket_limits)
141       : histogram_(bucket_limits) {}
142 
~SamplerCell()143   ~SamplerCell() {}
144 
145   // Atomically adds a sample.
146   void Add(double sample);
147 
148   // Returns the current histogram value as a proto.
149   HistogramProto value() const;
150 
151  private:
152   histogram::ThreadSafeHistogram histogram_;
153 
154   TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
155 };
156 
// Describes how samplers partition the value range into buckets.
//
// -DBL_MAX and DBL_MAX are added to the ranges automatically, so that no
// sample goes out of bounds.
//
// WARNING: If you are changing the interface here, please do change the same in
// mobile_sampler.h.
class Buckets {
 public:
  virtual ~Buckets() = default;

  // Explicit buckets, of the form:
  // [-DBL_MAX, ..., bucket_limits[i], bucket_limits[i + 1], ..., DBL_MAX].
  static std::unique_ptr<Buckets> Explicit(
      std::initializer_list<double> bucket_limits);

  // Vector-taking variant of Explicit(). Primarily meant for the CLIF layer
  // code paths that are incompatible with initializer_lists.
  static std::unique_ptr<Buckets> Explicit(std::vector<double> bucket_limits);

  // Exponential buckets, of the form:
  // [-DBL_MAX, ..., scale * growth_factor^i,
  //   scale * growth_factor^(i + 1), ..., DBL_MAX].
  //
  // For powers of 2 with a bucket count of 10, pass (1, 2, 10).
  static std::unique_ptr<Buckets> Exponential(double scale,
                                              double growth_factor,
                                              int bucket_count);

  // Accessor for the bucket bounds; supplied by the concrete bucket type.
  virtual const std::vector<double>& explicit_bounds() const = 0;
};
189 
190 // A stateful class for updating a cumulative histogram metric.
191 //
192 // This class encapsulates a set of histograms (or a single histogram for a
193 // label-less metric) configured with a list of increasing bucket boundaries.
194 // Each histogram is identified by a tuple of labels. The class allows the
195 // user to add a sample to each histogram value.
196 //
197 // Sampler allocates storage and maintains a cell for each value. You can
198 // retrieve an individual cell using a label-tuple and update it separately.
199 // This improves performance since operations related to retrieval, like
200 // map-indexing and locking, are avoided.
201 //
202 // This class is thread-safe.
203 template <int NumLabels>
204 class Sampler {
205  public:
~Sampler()206   ~Sampler() {
207     // Deleted here, before the metric_def is destroyed.
208     registration_handle_.reset();
209   }
210 
211   // Creates the metric based on the metric-definition arguments and buckets.
212   //
213   // Example;
214   // auto* sampler_with_label = Sampler<1>::New({"/tensorflow/sampler",
215   //   "Tensorflow sampler", "MyLabelName"}, {10.0, 20.0, 30.0});
216   static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
217                                       NumLabels>& metric_def,
218                       std::unique_ptr<Buckets> buckets);
219 
220   // Retrieves the cell for the specified labels, creating it on demand if
221   // not already present.
222   template <typename... Labels>
223   SamplerCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
224 
GetStatus()225   Status GetStatus() { return status_; }
226 
227  private:
228   friend class SamplerCell;
229 
Sampler(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)230   Sampler(const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
231               metric_def,
232           std::unique_ptr<Buckets> buckets)
233       : metric_def_(metric_def),
234         buckets_(std::move(buckets)),
235         registration_handle_(CollectionRegistry::Default()->Register(
236             &metric_def_, [&](MetricCollectorGetter getter) {
237               auto metric_collector = getter.Get(&metric_def_);
238 
239               mutex_lock l(mu_);
240               for (const auto& cell : cells_) {
241                 metric_collector.CollectValue(cell.first, cell.second.value());
242               }
243             })) {
244     if (registration_handle_) {
245       status_ = OkStatus();
246     } else {
247       status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
248                        "Another metric with the same name already exists.");
249     }
250   }
251 
252   mutable mutex mu_;
253 
254   Status status_;
255 
256   // The metric definition. This will be used to identify the metric when we
257   // register it for collection.
258   const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>
259       metric_def_;
260 
261   // Bucket limits for the histograms in the cells.
262   std::unique_ptr<Buckets> buckets_;
263 
264   // Registration handle with the CollectionRegistry.
265   std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
266 
267   using LabelArray = std::array<string, NumLabels>;
268   // we need a container here that guarantees pointer stability of the value,
269   // namely, the pointer of the value should remain valid even after more cells
270   // are inserted.
271   std::map<LabelArray, SamplerCell> cells_ TF_GUARDED_BY(mu_);
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
274 };
275 
276 ////
277 //  Implementation details follow. API readers may skip.
278 ////
279 
Add(const double sample)280 inline void SamplerCell::Add(const double sample) { histogram_.Add(sample); }
281 
value()282 inline HistogramProto SamplerCell::value() const {
283   HistogramProto pb;
284   histogram_.EncodeToProto(&pb, true /* preserve_zero_buckets */);
285   return pb;
286 }
287 
288 template <int NumLabels>
New(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)289 Sampler<NumLabels>* Sampler<NumLabels>::New(
290     const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
291         metric_def,
292     std::unique_ptr<Buckets> buckets) {
293   return new Sampler<NumLabels>(metric_def, std::move(buckets));
294 }
295 
296 template <int NumLabels>
297 template <typename... Labels>
GetCell(const Labels &...labels)298 SamplerCell* Sampler<NumLabels>::GetCell(const Labels&... labels)
299     TF_LOCKS_EXCLUDED(mu_) {
300   // Provides a more informative error message than the one during array
301   // construction below.
302   static_assert(sizeof...(Labels) == NumLabels,
303                 "Mismatch between Sampler<NumLabels> and number of labels "
304                 "provided in GetCell(...).");
305 
306   const LabelArray& label_array = {{labels...}};
307   mutex_lock l(mu_);
308   const auto found_it = cells_.find(label_array);
309   if (found_it != cells_.end()) {
310     return &(found_it->second);
311   }
312   return &(cells_
313                .emplace(std::piecewise_construct,
314                         std::forward_as_tuple(label_array),
315                         std::forward_as_tuple(buckets_->explicit_bounds()))
316                .first->second);
317 }
318 
319 }  // namespace monitoring
320 }  // namespace tensorflow
321 
322 #endif  // IS_MOBILE_PLATFORM
323 #endif  // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
324