xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/internal/tfprof_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/internal/tfprof_op.h"
17 
18 #include <stdio.h>
19 
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "tensorflow/core/platform/regexp.h"
25 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
26 #include "tensorflow/core/profiler/internal/tfprof_tensor.h"
27 
28 namespace tensorflow {
29 namespace tfprof {
30 namespace {
FormatToalExecTime(const ShowMultiNode * node,const ShowMultiNode * root)31 string FormatToalExecTime(const ShowMultiNode* node,
32                           const ShowMultiNode* root) {
33   double accu_pct = 0.0;
34   double pct = 0.0;
35   if (node->proto().total_exec_micros() > 0) {
36     accu_pct = 100.0 * node->proto().total_exec_micros() /
37                root->proto().total_exec_micros();
38     pct =
39         100.0 * node->proto().exec_micros() / root->proto().total_exec_micros();
40   }
41 
42   return absl::StrFormat(
43       "%30s",
44       absl::StrFormat("%s (%.2f%%, %.2f%%)",
45                       FormatTime(node->proto().exec_micros()), accu_pct, pct));
46 }
FormatCPUExecTime(const ShowMultiNode * node,const ShowMultiNode * root)47 string FormatCPUExecTime(const ShowMultiNode* node, const ShowMultiNode* root) {
48   double accu_pct = 0.0;
49   double pct = 0.0;
50   if (node->proto().total_cpu_exec_micros() > 0) {
51     accu_pct = 100.0 * node->proto().total_cpu_exec_micros() /
52                root->proto().total_cpu_exec_micros();
53     pct = 100.0 * node->proto().cpu_exec_micros() /
54           root->proto().total_cpu_exec_micros();
55   }
56 
57   return absl::StrFormat(
58       "%30s", absl::StrFormat("%s (%.2f%%, %.2f%%)",
59                               FormatTime(node->proto().cpu_exec_micros()),
60                               accu_pct, pct));
61 }
FormatAcceleratorExecTime(const ShowMultiNode * node,const ShowMultiNode * root)62 string FormatAcceleratorExecTime(const ShowMultiNode* node,
63                                  const ShowMultiNode* root) {
64   double accu_pct = 0.0;
65   double pct = 0.0;
66   if (node->proto().total_accelerator_exec_micros() > 0) {
67     accu_pct = 100.0 * node->proto().total_accelerator_exec_micros() /
68                root->proto().total_accelerator_exec_micros();
69     pct = 100.0 * node->proto().accelerator_exec_micros() /
70           root->proto().total_accelerator_exec_micros();
71   }
72 
73   return absl::StrFormat(
74       "%30s",
75       absl::StrFormat("%s (%.2f%%, %.2f%%)",
76                       FormatTime(node->proto().accelerator_exec_micros()),
77                       accu_pct, pct));
78 }
79 }  // namespace
80 
AddNode(TFGraphNode * node)81 void TFOp::AddNode(TFGraphNode* node) {
82   const string& op = node->op();
83   if (tfcnodes_map_.find(op) == tfcnodes_map_.end()) {
84     tfcnodes_map_[op] =
85         std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(op));
86   }
87   TFMultiGraphNode* tfcnode = tfcnodes_map_[op].get();
88   tfcnode->AddGraphNode(node);
89 }
90 
Build()91 void TFOp::Build() {
92   for (auto& tn : tfcnodes_map_) {
93     cnodes_map_[tn.first] =
94         std::unique_ptr<OpNode>(new OpNode(tn.second.get()));
95   }
96 
97   tfcnodes_map_[kTFProfRoot] =
98       std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(kTFProfRoot));
99   root_.reset(new OpNode(tfcnodes_map_[kTFProfRoot].get()));
100 }
101 
ShowInternal(const Options & opts,Timeline * timeline)102 const ShowMultiNode* TFOp::ShowInternal(const Options& opts,
103                                         Timeline* timeline) {
104   root_->ResetTotalStats();
105   if (opts.output_type == kOutput[3]) {
106     absl::FPrintF(stderr, "Only 'code' view supports pprof output now.\n");
107     return root_.get();
108   }
109   if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
110     root_->formatted_str = FormatNode(root_.get(), root_.get(), opts);
111   }
112   if (timeline) {
113     absl::FPrintF(stderr,
114                   "op view doesn't support timeline yet. "
115                   "Consider graph/scope/code view.\n");
116     return root_.get();
117   }
118   if (cnodes_map_.empty()) {
119     return root_.get();
120   }
121 
122   std::vector<OpNode*> nodes;
123   for (auto& n : cnodes_map_) {
124     n.second->account = ReAccount(n.second.get(), opts);
125     n.second->ResetTotalStats();
126     n.second->AddSelfToTotalStats();
127     nodes.push_back(n.second.get());
128   }
129   nodes = SortNodes(nodes, opts);
130   // pre keeps track of previous visited node.
131   OpNode* pre = nullptr;
132   std::vector<OpNode*> account_nodes;
133   for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
134     if ((*it)->account) {
135       if (pre) (*it)->AggregateTotalStats(pre);
136       account_nodes.push_back(*it);
137       pre = *it;
138     }
139   }
140   std::reverse(std::begin(account_nodes), std::end(account_nodes));
141   if (pre) {
142     root_->AggregateTotalStats(pre);
143   }
144 
145   // Perform the display and optionally redo accounting.
146   int64_t depth = 0;
147   std::vector<OpNode*> show_nodes;
148   int64_t start = SearchRoot(account_nodes, opts.start_name_regexes);
149   for (int64_t i = start, end = account_nodes.size(); i < end; ++i, ++depth) {
150     OpNode* n = account_nodes[i];
151     if (ShouldTrim(n, opts.trim_name_regexes) || depth > opts.max_depth) {
152       break;
153     }
154     n->show = ShouldShow(n, opts, depth);
155     if (n->show) show_nodes.push_back(n);
156   }
157 
158   pre = nullptr;
159   for (auto it = show_nodes.rbegin(); it != show_nodes.rend(); ++it) {
160     if (opts.account_displayed_op_only) {
161       (*it)->ResetTotalStats();
162       (*it)->AddSelfToTotalStats();
163       if (pre) (*it)->AggregateTotalStats(pre);
164     }
165     pre = *it;
166   }
167   if (opts.account_displayed_op_only) {
168     root_->ResetTotalStats();
169     if (pre) {
170       root_->AggregateTotalStats(pre);
171     }
172   }
173   if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
174     string display_str = FormatLegend(opts);
175     for (OpNode* node : show_nodes) {
176       display_str += FormatNode(node, root_.get(), opts);
177     }
178     // In op view, we don't show root (total). But it will still in proto.
179     // TODO(xpan): Is it the right choice?
180     root_->formatted_str = display_str;
181   }
182   // Populate the children field.
183   auto* pre_pb = root_->mutable_proto();
184   for (auto& show_node : show_nodes) {
185     pre_pb->clear_children();
186     pre_pb->add_children()->Swap(show_node->mutable_proto());
187     pre_pb = pre_pb->mutable_children(0);
188   }
189   return root_.get();
190 }
191 
SearchRoot(const std::vector<OpNode * > nodes,const std::vector<string> & regexes)192 int64_t TFOp::SearchRoot(const std::vector<OpNode*> nodes,
193                          const std::vector<string>& regexes) {
194   if (regexes.empty() || (regexes.size() == 1 && regexes[0] == ".*")) {
195     return 0;
196   }
197   int64_t i = 0;
198   const int64_t nodes_size = nodes.size();
199   for (; i < nodes_size; ++i) {
200     for (const string& regex : regexes) {
201       if (RE2::FullMatch(nodes[i]->name(), regex)) {
202         return i;
203       }
204     }
205   }
206   return i;
207 }
208 
FormatMemoryNode(int64_t node_total_bytes,int64_t root_total_bytes,int64_t node_bytes) const209 string TFOp::FormatMemoryNode(int64_t node_total_bytes,
210                               int64_t root_total_bytes,
211                               int64_t node_bytes) const {
212   double accu_pct = 0.0;
213   double pct = 0.0;
214   if (node_bytes > 0) {
215     accu_pct = 100.0 * node_total_bytes / root_total_bytes;
216     pct = 100.0 * node_bytes / root_total_bytes;
217   }
218   return absl::StrFormat(
219       "%30s", absl::StrFormat("%s (%.2f%%, %.2f%%)", FormatMemory(node_bytes),
220                               accu_pct, pct));
221 }
222 
FormatNode(OpNode * node,OpNode * root,const Options & opts) const223 string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const {
224   std::vector<string> attrs;
225 
226   if (opts.select.find(kShown[0]) != opts.select.end()) {
227     attrs.push_back(FormatMemoryNode(node->proto().total_requested_bytes(),
228                                      root->proto().total_requested_bytes(),
229                                      node->proto().requested_bytes()));
230   }
231 
232   if (opts.select.find(kShown[11]) != opts.select.end()) {
233     attrs.push_back(FormatMemoryNode(node->proto().total_peak_bytes(),
234                                      root->proto().total_peak_bytes(),
235                                      node->proto().peak_bytes()));
236   }
237 
238   if (opts.select.find(kShown[12]) != opts.select.end()) {
239     attrs.push_back(FormatMemoryNode(node->proto().total_residual_bytes(),
240                                      root->proto().total_residual_bytes(),
241                                      node->proto().residual_bytes()));
242   }
243   if (opts.select.find(kShown[13]) != opts.select.end()) {
244     attrs.push_back(FormatMemoryNode(node->proto().total_output_bytes(),
245                                      root->proto().total_output_bytes(),
246                                      node->proto().output_bytes()));
247   }
248 
249   if (opts.select.find(kShown[1]) != opts.select.end()) {
250     attrs.push_back(FormatToalExecTime(node, root));
251     attrs.push_back(FormatAcceleratorExecTime(node, root));
252     attrs.push_back(FormatCPUExecTime(node, root));
253   }
254   if (opts.select.find(kShown[9]) != opts.select.end() &&
255       opts.select.find(kShown[1]) == opts.select.end()) {
256     attrs.push_back(FormatAcceleratorExecTime(node, root));
257   }
258   if (opts.select.find(kShown[10]) != opts.select.end() &&
259       opts.select.find(kShown[1]) == opts.select.end()) {
260     attrs.push_back(FormatCPUExecTime(node, root));
261   }
262   if (opts.select.find(kShown[2]) != opts.select.end()) {
263     double accu_pct = 0.0;
264     double pct = 0.0;
265     if (node->proto().total_parameters() > 0) {
266       accu_pct = 100.0 * node->proto().total_parameters() /
267                  root->proto().total_parameters();
268       pct =
269           100.0 * node->proto().parameters() / root->proto().total_parameters();
270     }
271     attrs.push_back(absl::StrFormat(
272         "%30s", absl::StrFormat("%s params (%.2f%%, %.2f%%)",
273                                 FormatNumber(node->proto().parameters()),
274                                 accu_pct, pct)));
275   }
276 
277   if (opts.select.find(kShown[3]) != opts.select.end()) {
278     double accu_pct = 0.0;
279     double pct = 0.0;
280     if (node->proto().total_float_ops() > 0) {
281       accu_pct = 100.0 * node->proto().total_float_ops() /
282                  root->proto().total_float_ops();
283       pct = 100.0 * node->proto().float_ops() / root->proto().total_float_ops();
284     }
285 
286     attrs.push_back(absl::StrFormat(
287         "%30s", absl::StrFormat("%s float_ops (%.2f%%, %.2f%%)",
288                                 FormatNumber(node->proto().float_ops()),
289                                 accu_pct, pct)));
290   }
291 
292   if (opts.select.find(kShown[5]) != opts.select.end()) {
293     attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
294   }
295 
296   if (opts.select.find(kShown[6]) != opts.select.end()) {
297     std::set<string> op_types = node->node->op_types();
298     attrs.push_back(absl::StrJoin(op_types, "|"));
299   }
300 
301   if (opts.select.find(kShown[7]) != opts.select.end()) {
302     int64_t total_runs = 0;
303     for (const auto& gnode : node->proto().graph_nodes()) {
304       total_runs += gnode.run_count();
305     }
306     attrs.push_back(absl::StrFormat(
307         "%10s", absl::StrFormat("%d|%d", total_runs,
308                                 node->proto().graph_nodes_size())));
309   }
310 
311   string node_str =
312       absl::StrFormat("%-25s%s\n", node->name(), absl::StrJoin(attrs, ", "));
313 
314   if (opts.select.find(kShown[8]) != opts.select.end()) {
315     string input_shape_str = FormatInputShapes(node->proto());
316     if (!input_shape_str.empty()) {
317       node_str = absl::StrFormat("%s\n%s\n\n", node_str, input_shape_str);
318     }
319   }
320   return node_str;
321 }
322 }  // namespace tfprof
323 }  // namespace tensorflow
324