xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/pywrap_saved_model_metrics.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "absl/strings/string_view.h"
#include "pybind11/pybind11.h"
#include "tensorflow/cc/saved_model/metrics.h"

20 namespace tensorflow {
21 namespace saved_model {
22 namespace python {
23 
24 namespace py = pybind11;
25 
DefineMetricsModule(py::module main_module)26 void DefineMetricsModule(py::module main_module) {
27   auto m = main_module.def_submodule("metrics");
28 
29   m.doc() = "Python bindings for TensorFlow SavedModel and Checkpoint Metrics.";
30 
31   m.def(
32       "IncrementWrite",
33       [](const char* write_version) {
34         metrics::SavedModelWrite(write_version).IncrementBy(1);
35       },
36       py::kw_only(), py::arg("write_version"),
37       py::doc("Increment the '/tensorflow/core/saved_model/write/count' "
38               "counter."));
39 
40   m.def(
41       "GetWrite",
42       [](const char* write_version) {
43         return metrics::SavedModelWrite(write_version).value();
44       },
45       py::kw_only(), py::arg("write_version"),
46       py::doc("Get value of '/tensorflow/core/saved_model/write/count' "
47               "counter."));
48 
49   m.def(
50       "IncrementWriteApi",
51       [](const char* api_label) {
52         metrics::SavedModelWriteApi(api_label).IncrementBy(1);
53       },
54       py::doc("Increment the '/tensorflow/core/saved_model/write/api' "
55               "counter for API with `api_label`"));
56 
57   m.def(
58       "GetWriteApi",
59       [](const char* api_label) {
60         return metrics::SavedModelWriteApi(api_label).value();
61       },
62       py::doc("Get value of '/tensorflow/core/saved_model/write/api' "
63               "counter for `api_label` cell."));
64 
65   m.def(
66       "IncrementRead",
67       [](const char* write_version) {
68         metrics::SavedModelRead(write_version).IncrementBy(1);
69       },
70       py::kw_only(), py::arg("write_version"),
71       py::doc("Increment the '/tensorflow/core/saved_model/read/count' "
72               "counter after reading a SavedModel with the specifed "
73               "`write_version`."));
74 
75   m.def(
76       "GetRead",
77       [](const char* write_version) {
78         return metrics::SavedModelRead(write_version).value();
79       },
80       py::kw_only(), py::arg("write_version"),
81       py::doc("Get value of '/tensorflow/core/saved_model/read/count' "
82               "counter for SavedModels with the specified `write_version`."));
83 
84   m.def(
85       "IncrementReadApi",
86       [](const char* api_label) {
87         metrics::SavedModelReadApi(api_label).IncrementBy(1);
88       },
89       py::doc("Increment the '/tensorflow/core/saved_model/read/api' "
90               "counter for API with `api_label`."));
91 
92   m.def(
93       "GetReadApi",
94       [](const char* api_label) {
95         return metrics::SavedModelReadApi(api_label).value();
96       },
97       py::doc("Get value of '/tensorflow/core/saved_model/read/api' "
98               "counter for `api_label` cell."));
99 
100   m.def(
101       "AddCheckpointReadDuration",
102       [](const char* api_label, double microseconds) {
103         metrics::CheckpointReadDuration(api_label).Add(microseconds);
104       },
105       py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
106       py::doc("Add `microseconds` to the cell `api_label`for "
107               "'/tensorflow/core/checkpoint/read/read_durations'."));
108 
109   m.def(
110       "GetCheckpointReadDurations",
111       [](const char* api_label) {
112         // This function is called sparingly in unit tests, so protobuf
113         // (de)-serialization round trip is not an issue.
114         return py::bytes(metrics::CheckpointReadDuration(api_label)
115                              .value()
116                              .SerializeAsString());
117       },
118       py::kw_only(), py::arg("api_label"),
119       py::doc("Get serialized HistogramProto of `api_label` cell for "
120               "'/tensorflow/core/checkpoint/read/read_durations'."));
121 
122   m.def(
123       "AddCheckpointWriteDuration",
124       [](const char* api_label, double microseconds) {
125         metrics::CheckpointWriteDuration(api_label).Add(microseconds);
126       },
127       py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
128       py::doc("Add `microseconds` to the cell `api_label` for "
129               "'/tensorflow/core/checkpoint/write/write_durations'."));
130 
131   m.def(
132       "GetCheckpointWriteDurations",
133       [](const char* api_label) {
134         // This function is called sparingly, so protobuf (de)-serialization
135         // round trip is not an issue.
136         return py::bytes(metrics::CheckpointWriteDuration(api_label)
137                              .value()
138                              .SerializeAsString());
139       },
140       py::kw_only(), py::arg("api_label"),
141       py::doc("Get serialized HistogramProto of `api_label` cell for "
142               "'/tensorflow/core/checkpoint/write/write_durations'."));
143 
144   m.def(
145       "AddAsyncCheckpointWriteDuration",
146       [](const char* api_label, double microseconds) {
147         metrics::AsyncCheckpointWriteDuration(api_label).Add(microseconds);
148       },
149       py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
150       py::doc("Add `microseconds` to the cell `api_label` for "
151               "'/tensorflow/core/checkpoint/write/async_write_durations'."));
152 
153   m.def(
154       "GetAsyncCheckpointWriteDurations",
155       [](const char* api_label) {
156         // This function is called sparingly, so protobuf (de)-serialization
157         // round trip is not an issue.
158         return py::bytes(metrics::AsyncCheckpointWriteDuration(api_label)
159                              .value()
160                              .SerializeAsString());
161       },
162       py::kw_only(), py::arg("api_label"),
163       py::doc("Get serialized HistogramProto of `api_label` cell for "
164               "'/tensorflow/core/checkpoint/write/async_write_durations'."));
165 
166   m.def(
167       "AddTrainingTimeSaved",
168       [](const char* api_label, double microseconds) {
169         metrics::TrainingTimeSaved(api_label).IncrementBy(microseconds);
170       },
171       py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
172       py::doc("Add `microseconds` to the cell `api_label` for "
173               "'/tensorflow/core/checkpoint/write/training_time_saved'."));
174 
175   m.def(
176       "GetTrainingTimeSaved",
177       [](const char* api_label) {
178         return metrics::TrainingTimeSaved(api_label).value();
179       },
180       py::kw_only(), py::arg("api_label"),
181       py::doc("Get cell `api_label` for "
182               "'/tensorflow/core/checkpoint/write/training_time_saved'."));
183 
184   m.def(
185       "CalculateFileSize",
186       [](const char* filename) {
187         Env* env = Env::Default();
188         uint64 filesize = 0;
189         if (!env->GetFileSize(filename, &filesize).ok()) {
190           return (int64_t)-1;
191         }
192         // Convert to MB.
193         int64_t filesize_mb = filesize / 1000;
194         // Round to the nearest 100 MB.
195         // Smaller multiple.
196         int64_t a = (filesize_mb / 100) * 100;
197         // Larger multiple.
198         int64_t b = a + 100;
199         // Return closest of two.
200         return (filesize_mb - a > b - filesize_mb) ? b : a;
201       },
202       py::doc("Calculate filesize (MB) for `filename`, rounding to the nearest "
203               "100MB. Returns -1 if `filename` is invalid."));
204 
205   m.def(
206       "RecordCheckpointSize",
207       [](const char* api_label, int64_t filesize) {
208         metrics::CheckpointSize(api_label, filesize).IncrementBy(1);
209       },
210       py::kw_only(), py::arg("api_label"), py::arg("filesize"),
211       py::doc("Increment the "
212               "'/tensorflow/core/checkpoint/write/checkpoint_size' counter for "
213               "cell (api_label, filesize) after writing a checkpoint."));
214 
215   m.def(
216       "GetCheckpointSize",
217       [](const char* api_label, uint64 filesize) {
218         return metrics::CheckpointSize(api_label, filesize).value();
219       },
220       py::kw_only(), py::arg("api_label"), py::arg("filesize"),
221       py::doc("Get cell (api_label, filesize) for "
222               "'/tensorflow/core/checkpoint/write/checkpoint_size'."));
223 }
224 
225 }  // namespace python
226 }  // namespace saved_model
227 }  // namespace tensorflow
228