/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
}

// Returns event type if it is a KernelLaunch or KernelExecute event.
std::optional<int64_t> GetKernelEventType(bool is_host_plane,
                                          const XEventVisitor& event) {
  if (event.GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return std::nullopt;
}

int64_t GetEventType(bool is_host_plane, const XEventVisitor& event) {
  if (std::optional<int64_t> event_type = event.Type()) {
    return *event_type;
  } else if (std::optional<int64_t> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    return HostEventType::kUnknownHostEventType;
  }
}

bool IsLegacyProducerEvent(const XEventVisitor& event) {
  static const auto* const kProducerEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kProducerEvents->contains(*event.Type());
}

bool IsLegacyConsumerEvent(const XEventVisitor& event) {
  static const auto* const kConsumerEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kExecutorStateProcess,
      HostEventType::kExecutorDoneCallback, HostEventType::kRunGraphDone};
  return event.Type().has_value() && kConsumerEvents->contains(*event.Type());
}

bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

// Stats used in ConnectIntraThread.
struct GroupingEventStats {
  explicit GroupingEventStats(const XEventVisitor& event);

  std::optional<int> producer_type;
  std::optional<uint64_t> producer_id;
  std::optional<int> consumer_type;
  std::optional<uint64_t> consumer_id;
  std::optional<int> root_level;
  bool is_async = false;
};

GroupingEventStats::GroupingEventStats(const XEventVisitor& event) {
  std::optional<int64_t> step_id;
  event.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        root_level = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async = stat.BoolValue();
        break;
      case StatType::kStepId:
        step_id = stat.IntValue();
        break;
      default:
        break;
    }
  });
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (step_id.has_value() && IsLegacyProducerEvent(event)) {
      producer_type = static_cast<int>(ContextType::kTfExecutor);
      producer_id = *step_id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (step_id.has_value() && IsLegacyConsumerEvent(event)) {
      consumer_type = static_cast<int>(ContextType::kTfExecutor);
      consumer_id = *step_id;
    }
  }
  if (!root_level.has_value() && IsLegacyRootEvent(event)) {
    root_level = 1;
  }
}

void SetContextGroup(const GroupingEventStats& stats, EventNode* event,
                     ContextGroupMap* context_groups) {
  if (stats.producer_type.has_value() && stats.producer_id.has_value()) {
    ((*context_groups)[*stats.producer_type][*stats.producer_id])
        .producers.push_back(event);
  }
  if (stats.consumer_type.has_value() && stats.consumer_id.has_value()) {
    ((*context_groups)[*stats.consumer_type][*stats.consumer_id])
        .consumers.push_back(event);
  }
}

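// Links every producer event in a context group to every consumer event that
// shares the same (context type, context id) key, making producers the parents
// of consumers. This is what connects, e.g., a TraceMe producer on one thread
// to its consumers on other threads once all planes have been scanned by
// ConnectIntraThread.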
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents =
      new absl::flat_hash_set<int64_t>{
          HostEventType::kFunctionRun, HostEventType::kSessionRun,
          HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

void ProcessRootEvent(int64_t group_id, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  std::string group_name = root_event->GetGroupName();
  // TODO(b/160255693): Change the event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for the user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

using Comparator = std::function<bool(const EventNode*)>;

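// Breadth-first search up the parent links, starting from `node` (or from its
// direct parents when `include_self` is false); returns the first ancestor for
// which `comparator` returns true, or nullptr if none is found.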
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if the event node map contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

bool IsIteratorEventType(absl::optional<int64_t> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64_t> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

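// Searches this event and its ancestors (breadth-first over parent links) for
// the first occurrence of the requested stat type. This lets an event inherit
// context such as a step id or iteration number that was recorded only on an
// enclosing or producer event.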
absl::optional<XStatVisitor> EventNode::GetContextStat(
    int64_t stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

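// Builds a human-readable group name from the event's context: the graph type
// (or the event name for explicit root events) followed by the step number,
// e.g. "train 42" (illustrative; the actual prefix depends on the kGraphType
// stat recorded for the profile).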
std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!(IsImplicitRootEvent(visitor_))) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64_t step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

XStat* EventNode::FindOrAddStatByType(int64_t stat_type) {
  const XPlaneVisitor& plane = visitor_.Plane();
  const XStatMetadata* stat_metadata = plane.GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  auto* raw_event = const_cast<XEvent*>(&visitor_.RawEvent());  // NOLINT
  return FindOrAddMutableStat(*stat_metadata, raw_event);
}

void EventNode::SetGroupId(int64_t group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

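// Assigns `group_id` to this event and to every ungrouped event reachable
// through child links (breadth-first). When a descendant already belongs to a
// different group, the two groups are recorded as parent/child in
// `group_metadata_map` instead of being re-grouped.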
void EventNode::PropagateGroupId(int64_t group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64_t> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

bool EventNode::IsCompiledFunc() const {
  auto is_func = visitor_.GetStat(StatType::kIsFunc);
  return !is_func || is_func->IntValue();
}

bool EventNode::IsEager() const {
  /* Both eager-mode (op-by-op) and non-eager-mode (eager functions) eager
   * executions are now unified and forwarded to the TF1 executor. Therefore we
   * check the following conditions:
   */
  const EventNode* node = FindParent(HostEventType::kEagerKernelExecute);
  if (node == nullptr) {
    // If the current op is NOT scheduled under "EagerKernelExecute", it most
    // likely comes from TF1 and is therefore not eager.
    return false;
  }

  // Otherwise, it is an eager-mode execution of an operation if and only if it
  // is not an eager-mode execution of a compiled function.
  return !node->IsCompiledFunc();
}

const EventNode* EventNode::FindParent(int64_t event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

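// Builds nesting (parent/child) links between events on the same thread.
// Events on each line are assumed to arrive in timestamp order; a stack of
// still-open events is kept, and the innermost event whose timespan includes
// the current event becomes its parent. Producer/consumer stats are collected
// into `context_groups` along the way for later cross-thread connection.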
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      XEventVisitor event_visitor(visitor, &line, &event);
      int64_t event_type = GetEventType(is_host_plane, event_visitor);
      EventNode* cur_node =
          &event_node_map_[event_type].emplace_back(std::move(event_visitor));
      GroupingEventStats stats(cur_node->GetEventVisitor());
      if (stats.root_level.has_value()) {
        cur_node->SetRootLevel(*stats.root_level);
      }
      // Update `context_groups` for `ConnectInterThread`.
      SetContextGroup(stats, cur_node, context_groups);
      // Async events are ignored when processing the nesting relationship.
      if (!stats.is_async) {
        while (!parent_nodes.empty()) {
          EventNode* parent_node = parent_nodes.back();
          if (parent_node->GetEventVisitor().GetTimespan().Includes(
                  cur_node->GetEventVisitor().GetTimespan())) {
            parent_node->AddChild(cur_node);
            break;
          } else {
            parent_nodes.pop_back();
          }
        }
        parent_nodes.push_back(cur_node);
      }
    }
  }
}

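// Connects parent and child events across threads by matching tuples of
// context-stat values. For each InterThreadConnectInfo, parent events are
// indexed by the values of `parent_stat_types`, and each child event is
// attached to the parent with identical values for `child_stat_types` (which
// defaults to `parent_stat_types`). For example, with the connect info
//
//   {HostEventType::kExecutorStateProcess,
//    HostEventType::kIteratorGetNextOp,
//    {StatType::kStepId, StatType::kIterNum}}
//
// (see CreateInterThreadConnectInfoList below), an ExecutorStateProcess event
// becomes the parent of the IteratorGetNextOp events that share its step id
// and iteration number.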
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64_t>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64_t>* child_stat_types =
        &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (EventNode& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node.GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = &parent_event_node;
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (EventNode& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node.GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(&child_event_node);
          }
        }
      }
    }
  }
}

// Returns whether a root event needs grouping.
bool RootNeedsGrouping(const EventNode* root) {
  // No grouping is needed if it is already grouped.
  if (root->GetGroupId().has_value()) return false;
  // If there is a parent node with the same root level, skip grouping at
  // <root> and later apply grouping at the parent node.
  // If there is a parent node with a different root level, apply grouping at
  // <root>, and later apply grouping at the parent node. Root events with
  // different levels are grouped separately.
  const EventNode* root_parent = FindParentWithComparator(
      [root](const EventNode* parent) {
        return parent->RootLevel() == root->RootLevel();
      },
      root,
      /*include_self=*/false);
  return root_parent == nullptr;
}

// Sorts root events based on root level and timestamp.
void SortRootEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    // If two root events have the same root level, the root event with an
    // earlier timestamp will be processed first. Otherwise, the event with a
    // larger root level will be processed first.
    return e1->RootLevel() == e2->RootLevel()
               ? *e1 < *e2
               : e1->RootLevel() > e2->RootLevel();
  });
}

void EventForest::CreateEventGroups() {
  // Create a group for each TF loop iteration in non-JAX profiles.
  int64_t group_id = 0;
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
    return;
  }

  // Iterate over all events and collect all root events.
  EventList root_events;
  for (auto& [event_type, events] : event_node_map_) {
    for (EventNode& event : events) {
      if (!event.RootLevel()) continue;
      absl::optional<XStatVisitor> step_id_stat =
          event.GetEventVisitor().GetStat(StatType::kStepId);
      // Skip root events that are associated with tf.data.
      if (step_id_stat && tf_data_step_ids_.contains(step_id_stat->IntValue()))
        continue;
      root_events.push_back(&event);
    }
  }

  SortRootEventList(&root_events);

  for (EventNode* root_event : root_events) {
    if (RootNeedsGrouping(root_event) &&
        // Ignore legacy TF root events for JAX profiles.
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
  }
}

void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (EventNode& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node.SetIsEager(kernel_execute_event_node.IsEager());
  }
}

void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (EventNode& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node.SetIsEager(tf_op_run_event_node.IsEager());
  }
}

void EventForest::ProcessTfDataSteps() {
  const int64_t tf_data_event_types[] = {
      HostEventType::kTfDataCapturedFunctionRun,
      HostEventType::kTfDataCapturedFunctionRunAsync,
      HostEventType::kTfDataCapturedFunctionRunInstantiated,
      HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs};
  for (const int64_t tf_data_event_type : tf_data_event_types) {
    auto tf_data_events = gtl::FindOrNull(event_node_map_, tf_data_event_type);
    if (!tf_data_events) continue;
    for (const EventNode& tf_data_event : *tf_data_events) {
      absl::optional<XStatVisitor> step_id_stat =
          tf_data_event.GetEventVisitor().GetStat(StatType::kStepId);
      if (!step_id_stat) continue;
      tf_data_step_ids_.insert(step_id_stat->IntValue());
    }
  }
}

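// Detects TF executor training loops from ExecutorStateProcess events: events
// are bucketed by (step_id, iter_num), the earliest event of each iteration is
// registered as a loop root, and the iteration's remaining events are attached
// to it as children. Functions/sessions that only ever ran iteration 0 are
// treated as non-loops and skipped.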
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop =
      absl::flat_hash_map<int64_t /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64_t /*step_id*/, TensorFlowLoop> tf_loops;

  // Group the TF executor events by TF function/session (step_id) and
  // iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (EventNode& executor_event : *executor_event_list) {
    absl::optional<XStatVisitor> step_id_stat =
        executor_event.GetEventVisitor().GetStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event.GetEventVisitor().GetStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64_t step_id = step_id_stat->IntValue();
    // Skip tf.data events.
    if (tf_data_step_ids_.contains(step_id)) continue;
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event || executor_event < *iteration.first_event) {
      iteration.first_event = &executor_event;
    }
    iteration.events.push_back(&executor_event);
  }

  std::vector<const TensorFlowLoopIteration*> iters;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    const TensorFlowLoop& tf_loop = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (tf_loop.size() == 1 && tf_loop.contains(0)) continue;
    for (const auto& iter_num_and_iter : tf_loop) {
      iters.push_back(&iter_num_and_iter.second);
    }
  }

  // Sort iterations based on the timestamp of the first event in each
  // iteration.
  absl::c_sort(iters, [](const auto& iter1, const auto& iter2) {
    return *iter1->first_event < *iter2->first_event;
  });

  // Register the first event of each iteration as a root event, and add the
  // other events of the iteration as its children.
  for (const TensorFlowLoopIteration* iter : iters) {
    EventNode* root_event = iter->first_event;
    tf_loop_root_events_.push_back(root_event);
    for (EventNode* event : iter->events) {
      if (event == root_event) continue;
      root_event->AddChild(event);
    }
  }
}

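// Groups eager ops executed on a worker: each EagerKernelExecute event with a
// FunctionRun child is marked as a root, and the non-function eager ops that
// follow it (until the next such root) are attached to it as children.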
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (EventNode& eager_kernel_execute_event :
       *eager_kernel_execute_event_list) {
    if (HasFunctionRun(&eager_kernel_execute_event)) {
      // A function op becomes a new root.
      root_event = &eager_kernel_execute_event;
      root_event->SetRootLevel(1);
    } else if (root_event) {
      // Add non-function eager ops as children.
      root_event->AddChild(&eager_kernel_execute_event);
    }
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

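// Connects tf.data producer events to the consumer events that read the same
// element: producer iterators are indexed by (iterator_id, element_id), and
// each consumer iterator with a matching key adopts the producers as children.
// This stitches asynchronous input-pipeline stages (e.g. Prefetch,
// ParallelMap) across threads.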
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<
      std::pair<int64_t /*iterator_id*/, int64_t /*element_id*/>,
      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce,
        HostEventType::kParallelBatchProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (EventNode& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event.GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event.GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume,
        HostEventType::kParallelBatchConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (EventNode& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event.GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event.GetParents().empty()) continue;
      // consume_event is nested within consume_iterator and has no other
      // parents.
      EventNode* consume_iterator = consume_event.GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

void EventForest::GroupEvents() {
  ProcessTfDataSteps();
  ProcessTensorFlowLoop();
  ProcessWorker();
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
}

std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

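// Typical usage (illustrative sketch): after an XSpace has been collected by
// the profiler, group its host events so downstream tools can attribute
// host/device activity to training steps.
//
//   XSpace space = ...;     // collected profile
//   GroupTfEvents(&space);  // annotates events with group_id / step_name /
//                           // is_eager stats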
void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow