/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/group_events.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"  // Needed for absl::flat_hash_set used below.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {
namespace {

// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder builder(plane);
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
  builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
}

// Returns the KernelLaunch or KernelExecute event type if the event carries a
// correlation-id stat, which identifies it as a kernel event.
std::optional<int64_t> GetKernelEventType(bool is_host_plane,
                                          const XEventVisitor& event) {
  if (event.GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return std::nullopt;
}

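// Returns the event's registered type if available; otherwise falls back to
// the kernel event type inferred from the correlation-id stat, and finally to
// kUnknownHostEventType.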
int64_t GetEventType(bool is_host_plane, const XEventVisitor& event) {
  if (std::optional<int64_t> event_type = event.Type()) {
    return *event_type;
  } else if (std::optional<int64_t> kernel_event_type =
                 GetKernelEventType(is_host_plane, event)) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    return *kernel_event_type;
  } else {
    return HostEventType::kUnknownHostEventType;
  }
}

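// Returns true for host event types that act as producers under the legacy
// connection scheme, where producer/consumer stats are not emitted and the
// kStepId stat is used as the context id instead.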
bool IsLegacyProducerEvent(const XEventVisitor& event) {
  static const auto* const kProducerEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kProducerEvents->contains(*event.Type());
}

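// Returns true for host event types that act as consumers under the legacy
// connection scheme (see IsLegacyProducerEvent above).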
bool IsLegacyConsumerEvent(const XEventVisitor& event) {
  static const auto* const kConsumerEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kExecutorStateProcess,
      HostEventType::kExecutorDoneCallback, HostEventType::kRunGraphDone};
  return event.Type().has_value() && kConsumerEvents->contains(*event.Type());
}

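// Returns true for host event types that are treated as root events when no
// explicit kIsRoot stat is present.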
bool IsLegacyRootEvent(const XEventVisitor& event) {
  static const auto* const kRootEvents = new absl::flat_hash_set<int64_t>{
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kRunGraph};
  return event.Type().has_value() && kRootEvents->contains(*event.Type());
}

// Stats used in ConnectIntraThread.
struct GroupingEventStats {
  explicit GroupingEventStats(const XEventVisitor& event);

  std::optional<int> producer_type;
  std::optional<uint64_t> producer_id;
  std::optional<int> consumer_type;
  std::optional<uint64_t> consumer_id;
  std::optional<int> root_level;
  bool is_async = false;
};

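// Extracts the grouping-related stats from `event` in a single pass. For
// events that predate the producer/consumer stats, falls back to the legacy
// event types keyed by kStepId, and marks legacy root events with root
// level 1.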
GroupingEventStats::GroupingEventStats(const XEventVisitor& event) {
  std::optional<int64_t> step_id;
  event.ForEachStat([&](const XStatVisitor& stat) {
    if (!stat.Type().has_value()) return;
    switch (*stat.Type()) {
      case StatType::kProducerType:
        producer_type = stat.IntValue();
        break;
      case StatType::kProducerId:
        producer_id = stat.IntOrUintValue();
        break;
      case StatType::kConsumerType:
        consumer_type = stat.IntValue();
        break;
      case StatType::kConsumerId:
        consumer_id = stat.IntOrUintValue();
        break;
      case StatType::kIsRoot:
        root_level = stat.IntValue();
        break;
      case StatType::kIsAsync:
        is_async = stat.BoolValue();
        break;
      case StatType::kStepId:
        step_id = stat.IntValue();
        break;
      default:
        break;
    }
  });
  if (!producer_type.has_value() || !producer_id.has_value()) {
    if (step_id.has_value() && IsLegacyProducerEvent(event)) {
      producer_type = static_cast<int>(ContextType::kTfExecutor);
      producer_id = *step_id;
    }
  }
  if (!consumer_type.has_value() || !consumer_id.has_value()) {
    if (step_id.has_value() && IsLegacyConsumerEvent(event)) {
      consumer_type = static_cast<int>(ContextType::kTfExecutor);
      consumer_id = *step_id;
    }
  }
  if (!root_level.has_value() && IsLegacyRootEvent(event)) {
    root_level = 1;
  }
}

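// Registers `event` as a producer and/or consumer in the context group keyed
// by its (type, id) pair, for later cross-thread connection.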
void SetContextGroup(const GroupingEventStats& stats, EventNode* event,
                     ContextGroupMap* context_groups) {
  if (stats.producer_type.has_value() && stats.producer_id.has_value()) {
    ((*context_groups)[*stats.producer_type][*stats.producer_id])
        .producers.push_back(event);
  }
  if (stats.consumer_type.has_value() && stats.consumer_id.has_value()) {
    ((*context_groups)[*stats.consumer_type][*stats.consumer_id])
        .consumers.push_back(event);
  }
}

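// Connects each producer event to every consumer event within the same
// context group.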
void ConnectContextGroups(const ContextGroupMap& context_groups) {
  for (auto& type_id_group : context_groups) {
    for (auto& id_group : type_id_group.second) {
      const ContextGroup& group = id_group.second;
      for (EventNode* parent : group.producers) {
        for (EventNode* child : group.consumers) {
          parent->AddChild(child);
        }
      }
    }
  }
}

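// Returns true if any direct child of `event_node` is a FunctionRun event.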
bool HasFunctionRun(EventNode* event_node) {
  for (EventNode* child : event_node->GetChildren()) {
    if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
      return true;
    }
  }
  return false;
}

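// Returns true for root events that are roots only implicitly (per-step
// framework events); unlike user-defined roots, they do not receive a
// step_name stat in ProcessRootEvent.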
bool IsImplicitRootEvent(const XEventVisitor& event) {
  static const auto* const kImplicitRootEvents =
      new absl::flat_hash_set<int64_t>{
          HostEventType::kFunctionRun, HostEventType::kSessionRun,
          HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
  return event.Type().has_value() &&
         kImplicitRootEvents->contains(*event.Type());
}

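// Propagates `group_id` from `root_event` and records the group's name in
// `group_metadata_map`.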
void ProcessRootEvent(int64_t group_id, EventNode* root_event,
                      GroupMetadataMap* group_metadata_map) {
  root_event->PropagateGroupId(group_id, group_metadata_map);
  std::string group_name = root_event->GetGroupName();
  // TODO(b/160255693): Change the event name instead.
  if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
    // Add the `step_name` stat for the user-defined root events only. When an
    // XEvent is converted to a trace event, the trace event name is set to the
    // `step_name` stat's value if present.
    root_event->AddStepName(group_name);
  }
  (*group_metadata_map)[group_id].name = std::move(group_name);
}

using Comparator = std::function<bool(const EventNode*)>;

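// Does a breadth-first search up the parent chain of `node` (optionally
// including `node` itself) and returns the first ancestor that satisfies
// `comparator`, or nullptr if none does.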
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {node};
  if (include_self) {
    nodes.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (comparator(node)) return node;
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return nullptr;
}

// Returns true if the event node map contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
  return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}

bool IsIteratorEventType(absl::optional<int64_t> event_type) {
  return event_type == HostEventType::kIterator ||
         event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}

}  // namespace

// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
  for (const XPlane& plane : space.planes()) {
    for (const auto& event_metadata : plane.event_metadata()) {
      absl::optional<int64_t> event_type =
          FindHostEventType(event_metadata.second.name());
      if (!event_type.has_value()) continue;
      switch (*event_type) {
        case HostEventType::kWhileOpEvalCond:
        case HostEventType::kWhileOpStartBody:
        case HostEventType::kForOp:
        case HostEventType::kParallelForOp:
        case HostEventType::kForeverOp:
          return true;
        default:
          break;
      }
    }
  }
  return false;
}

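// Searches this event and its ancestors (breadth-first) for the first stat of
// `stat_type`.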
absl::optional<XStatVisitor> EventNode::GetContextStat(
    int64_t stat_type) const {
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}

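// Builds a human-readable group name of the form
// "<graph type or event name> <step number>", where the step number comes
// from the kIterNum or kStepNum context stat when available and otherwise
// defaults to the group id.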
std::string EventNode::GetGroupName() const {
  std::string name;
  if (absl::optional<XStatVisitor> stat =
          GetContextStat(StatType::kGraphType)) {
    absl::StrAppend(&name, stat->StrOrRefValue(), " ");
  } else if (!IsImplicitRootEvent(visitor_)) {
    absl::StrAppend(&name, GetEventVisitor().Name(), " ");
  }
  int64_t step_num = group_id_.value_or(0);
  if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
    step_num = stat->IntValue();
  } else if (absl::optional<XStatVisitor> stat =
                 GetContextStat(StatType::kStepNum)) {
    step_num = stat->IntValue();
  }
  absl::StrAppend(&name, step_num);
  return name;
}

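// Returns the XStat of `stat_type` on the raw event, creating it if absent.
// Requires that the stat metadata was already created (see
// CreateStatMetadata).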
XStat* EventNode::FindOrAddStatByType(int64_t stat_type) {
  const XPlaneVisitor& plane = visitor_.Plane();
  const XStatMetadata* stat_metadata = plane.GetStatMetadataByType(stat_type);
  DCHECK(stat_metadata != nullptr);
  auto* raw_event = const_cast<XEvent*>(&visitor_.RawEvent());  // NOLINT
  return FindOrAddMutableStat(*stat_metadata, raw_event);
}

void EventNode::SetGroupId(int64_t group_id) {
  group_id_ = group_id;
  FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}

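// Assigns `group_id` to this event and all of its transitive children that do
// not yet have a group. When an already-grouped descendant with a different
// group id is encountered, a parent/child relationship between the two groups
// is recorded in `group_metadata_map` instead of overwriting it.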
void EventNode::PropagateGroupId(int64_t group_id,
                                 GroupMetadataMap* group_metadata_map) {
  std::queue<EventNode*> nodes;
  absl::flat_hash_set<EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    EventNode* node = nodes.front();
    nodes.pop();
    absl::optional<int64_t> node_group_id = node->GetGroupId();
    if (node_group_id.has_value()) {
      if (*node_group_id != group_id) {
        (*group_metadata_map)[group_id].children.insert(*node_group_id);
        (*group_metadata_map)[*node_group_id].parents.insert(group_id);
      }
    } else {
      node->SetGroupId(group_id);
      for (EventNode* child : node->GetChildren()) {
        if (seen.contains(child)) continue;
        nodes.push(child);
        seen.insert(child);
      }
    }
  }
}

void EventNode::AddStepName(absl::string_view step_name) {
  FindOrAddStatByType(StatType::kStepName)
      ->set_str_value(step_name.data(), step_name.size());
}

void EventNode::SetIsEager(bool is_eager) {
  FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}

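// An event is treated as a compiled function unless it carries an explicit
// kIsFunc stat with value 0.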
bool EventNode::IsCompiledFunc() const {
  auto is_func = visitor_.GetStat(StatType::kIsFunc);
  return !is_func || is_func->IntValue();
}

bool EventNode::IsEager() const {
  /* Both eager mode (op-by-op) and non-eager mode (eager functions) of eager
   * execution are unified and forwarded to the TF1 executor now. Therefore we
   * check the following conditions:
   */
  const EventNode* node = FindParent(HostEventType::kEagerKernelExecute);
  if (node == nullptr) {
    // If the current op is NOT scheduled under "EagerKernelExecute", it likely
    // comes from TF1 and is therefore not eager.
    return false;
  }

  // Otherwise, it is an eager-mode execution of an op if and only if it is
  // not an eager-mode execution of a compiled function.
  return !node->IsCompiledFunc();
}

const EventNode* EventNode::FindParent(int64_t event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}

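// Creates an EventNode for every event on `plane` and establishes the
// parent/child nesting within each line: a stack of candidate parents is
// maintained, and an event becomes a child of the innermost event (by
// timespan) that encloses it on the same thread.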
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
                                     ContextGroupMap* context_groups) {
  // TODO(b/149095099): avoid string comparison.
  bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
  for (auto& line : *plane->mutable_lines()) {
    std::vector<EventNode*> parent_nodes;
    for (auto& event : *line.mutable_events()) {
      XEventVisitor event_visitor(visitor, &line, &event);
      int64_t event_type = GetEventType(is_host_plane, event_visitor);
      EventNode* cur_node =
          &event_node_map_[event_type].emplace_back(std::move(event_visitor));
      GroupingEventStats stats(cur_node->GetEventVisitor());
      if (stats.root_level.has_value()) {
        cur_node->SetRootLevel(*stats.root_level);
      }
      // Update `context_groups` for `ConnectInterThread`.
      SetContextGroup(stats, cur_node, context_groups);
      // Async events are ignored when processing the nesting relationship.
      if (!stats.is_async) {
        while (!parent_nodes.empty()) {
          EventNode* parent_node = parent_nodes.back();
          if (parent_node->GetEventVisitor().GetTimespan().Includes(
                  cur_node->GetEventVisitor().GetTimespan())) {
            parent_node->AddChild(cur_node);
            break;
          } else {
            parent_nodes.pop_back();
          }
        }
        parent_nodes.push_back(cur_node);
      }
    }
  }
}

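// Connects parent and child events across threads: for each
// InterThreadConnectInfo, parent events are indexed by the values of their
// parent stats, and each child event whose child stats produce the same key
// is attached to the matching parent.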
void EventForest::ConnectInterThread(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  for (const auto& connect_info : connect_info_list) {
    absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
    const std::vector<int64_t>& parent_stat_types =
        connect_info.parent_stat_types;
    const std::vector<int64_t>* child_stat_types =
        &connect_info.child_stat_types;
    if (child_stat_types->empty()) {
      child_stat_types = &parent_stat_types;
    }
    if (auto parent_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
      for (EventNode& parent_event_node : *parent_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : parent_stat_types) {
          absl::optional<XStatVisitor> stat =
              parent_event_node.GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == parent_stat_types.size()) {
          connect_map[stats] = &parent_event_node;
        }
      }
    }
    if (auto child_event_node_list =
            gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
      for (EventNode& child_event_node : *child_event_node_list) {
        std::vector<uint64> stats;
        for (auto stat_type : *child_stat_types) {
          absl::optional<XStatVisitor> stat =
              child_event_node.GetContextStat(stat_type);
          if (!stat) break;
          stats.push_back(stat->IntOrUintValue());
        }
        if (stats.size() == child_stat_types->size()) {
          if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
            parent_event_node->AddChild(&child_event_node);
          }
        }
      }
    }
  }
}

// Returns whether a root event needs grouping.
bool RootNeedsGrouping(const EventNode* root) {
  // No grouping is needed if it is already grouped.
  if (root->GetGroupId().has_value()) return false;
  // If there is a parent node with the same root level, skip grouping at
  // <root> and later apply grouping at the parent node.
  // If there is a parent node with a different root level, apply grouping at
  // <root>, and later apply grouping at the parent node. Root events with
  // different levels are grouped separately.
  const EventNode* root_parent = FindParentWithComparator(
      [root](const EventNode* parent) {
        return parent->RootLevel() == root->RootLevel();
      },
      root,
      /*include_self=*/false);
  return root_parent == nullptr;
}

// Sorts root events based on root level and timestamp.
void SortRootEventList(EventList* event_list) {
  absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
    // If two root events have the same root level, the root event with an
    // earlier timestamp will be processed first. Otherwise, the event with a
    // larger root level will be processed first.
    return e1->RootLevel() == e2->RootLevel()
               ? *e1 < *e2
               : e1->RootLevel() > e2->RootLevel();
  });
}

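// Creates event groups: in non-JAX profiles with a TF loop, one group per
// loop iteration; otherwise one group per qualifying root event, processed
// in root-level/timestamp order.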
void EventForest::CreateEventGroups() {
  // Create a group for each TF loop iteration in non-JAX profiles.
  int64_t group_id = 0;
  if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
    for (EventNode* root_event : tf_loop_root_events_) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
    return;
  }

  // Iterate over all events and collect all root events.
  EventList root_events;
  for (auto& [event_type, events] : event_node_map_) {
    for (EventNode& event : events) {
      if (!event.RootLevel()) continue;
      absl::optional<XStatVisitor> step_id_stat =
          event.GetEventVisitor().GetStat(StatType::kStepId);
      // If this root event is associated with tf.data, skip it.
      if (step_id_stat && tf_data_step_ids_.contains(step_id_stat->IntValue()))
        continue;
      root_events.push_back(&event);
    }
  }

  SortRootEventList(&root_events);

  for (EventNode* root_event : root_events) {
    if (RootNeedsGrouping(root_event) &&
        // Ignores legacy TF root events for JAX profiles.
        (!HasJaxEvent(event_node_map_) ||
         !IsLegacyRootEvent(root_event->GetEventVisitor()))) {
      ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
    }
  }
}

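// Stamps each KernelExecute event with a kIsEager stat indicating whether it
// was executed eagerly.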
void EventForest::MarkEagerlyExecutedGpuKernels() {
  auto kernel_execute_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
  if (!kernel_execute_event_node_list) return;
  for (EventNode& kernel_execute_event_node : *kernel_execute_event_node_list) {
    kernel_execute_event_node.SetIsEager(kernel_execute_event_node.IsEager());
  }
}

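// Same as above, but for TfOpRun events on the CPU.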
void EventForest::MarkEagerlyExecutedCpuTfOps() {
  auto tf_op_run_event_node_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
  if (!tf_op_run_event_node_list) return;
  for (EventNode& tf_op_run_event_node : *tf_op_run_event_node_list) {
    tf_op_run_event_node.SetIsEager(tf_op_run_event_node.IsEager());
  }
}

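// Records the step ids of all tf.data captured-function runs so that
// tf.data-related executor events can be excluded from grouping.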
void EventForest::ProcessTfDataSteps() {
  const int64_t tf_data_event_types[] = {
      HostEventType::kTfDataCapturedFunctionRun,
      HostEventType::kTfDataCapturedFunctionRunAsync,
      HostEventType::kTfDataCapturedFunctionRunInstantiated,
      HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs};
  for (const int64_t tf_data_event_type : tf_data_event_types) {
    auto tf_data_events = gtl::FindOrNull(event_node_map_, tf_data_event_type);
    if (!tf_data_events) continue;
    for (const EventNode& tf_data_event : *tf_data_events) {
      absl::optional<XStatVisitor> step_id_stat =
          tf_data_event.GetEventVisitor().GetStat(StatType::kStepId);
      if (!step_id_stat) continue;
      tf_data_step_ids_.insert(step_id_stat->IntValue());
    }
  }
}

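// Detects TF while-loop iterations among the executor events: events are
// bucketed by (step_id, iter_num), single-iteration "loops" are filtered out,
// and the earliest event of each remaining iteration becomes a root event
// with the iteration's other events as its children.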
void EventForest::ProcessTensorFlowLoop() {
  struct TensorFlowLoopIteration {
    EventNode* first_event = nullptr;
    std::vector<EventNode*> events;
  };
  using TensorFlowLoop =
      absl::flat_hash_map<int64_t /*iter_num*/, TensorFlowLoopIteration>;
  absl::flat_hash_map<int64_t /*step_id*/, TensorFlowLoop> tf_loops;

  // Sort the TF executor events by TF function/session (step_id) and iter_num.
  auto executor_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
  if (!executor_event_list) return;
  for (EventNode& executor_event : *executor_event_list) {
    absl::optional<XStatVisitor> step_id_stat =
        executor_event.GetEventVisitor().GetStat(StatType::kStepId);
    absl::optional<XStatVisitor> iter_num_stat =
        executor_event.GetEventVisitor().GetStat(StatType::kIterNum);
    if (!step_id_stat || !iter_num_stat) continue;
    int64_t step_id = step_id_stat->IntValue();
    // Skip tf.data events.
    if (tf_data_step_ids_.contains(step_id)) continue;
    TensorFlowLoop& tf_loop = tf_loops[step_id];
    TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
    if (!iteration.first_event || executor_event < *iteration.first_event) {
      iteration.first_event = &executor_event;
    }
    iteration.events.push_back(&executor_event);
  }

  std::vector<const TensorFlowLoopIteration*> iters;
  for (const auto& step_id_and_tf_loop : tf_loops) {
    const TensorFlowLoop& tf_loop = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops.
    if (tf_loop.size() == 1 && tf_loop.contains(0)) continue;
    for (const auto& iter_num_and_iter : tf_loop) {
      iters.push_back(&iter_num_and_iter.second);
    }
  }

  // Sort iterations based on the timestamp of their first event.
  absl::c_sort(iters, [](const auto& iter1, const auto& iter2) {
    return *iter1->first_event < *iter2->first_event;
  });

  // Register the first event of each iteration as a root event, and add the
  // iteration's other events as its children.
  for (const TensorFlowLoopIteration* iter : iters) {
    EventNode* root_event = iter->first_event;
    tf_loop_root_events_.push_back(root_event);
    for (EventNode* event : iter->events) {
      if (event == root_event) continue;
      root_event->AddChild(event);
    }
  }
}

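// Treats each EagerKernelExecute event that has a FunctionRun child as a root
// event (a function dispatched to the worker) and attaches the subsequent
// non-function eager ops to it as children.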
void EventForest::ProcessWorker() {
  auto eager_kernel_execute_event_list =
      gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
  if (!eager_kernel_execute_event_list) return;
  // The last EagerKernelExecute with a FunctionRun child.
  EventNode* root_event = nullptr;
  for (EventNode& eager_kernel_execute_event :
       *eager_kernel_execute_event_list) {
    if (HasFunctionRun(&eager_kernel_execute_event)) {
      // A function op becomes a new root.
      root_event = &eager_kernel_execute_event;
      root_event->SetRootLevel(1);
    } else if (root_event) {
      // Add non-function eager ops as children.
      root_event->AddChild(&eager_kernel_execute_event);
    }
  }
}

void EventForest::AddPlane(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XPlane* plane) {
  CreateStatMetadata(plane);
  planes_.push_back({plane, visitor_factory(plane)});
}

void EventForest::AddSpace(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  for (XPlane& plane : *space->mutable_planes()) {
    AddPlane(visitor_factory, &plane);
  }
}

void EventForest::AddPlanes(
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    const std::vector<XPlane*>& planes) {
  for (XPlane* plane : planes) {
    AddPlane(visitor_factory, plane);
  }
}

void EventForest::ConnectEvents(
    const std::vector<InterThreadConnectInfo>& connect_info_list) {
  ContextGroupMap context_groups;
  for (auto& plane_visitor : planes_) {
    ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
                       &context_groups);
  }
  ConnectInterThread(connect_info_list);
  ConnectContextGroups(context_groups);
}

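// Connects tf.data producer and consumer iterator events: producer iterators
// are indexed by (iterator_id, element_id), and each matching consumer
// iterator adopts them as children.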
void EventForest::ConnectTfDataEvents() {
  absl::flat_hash_map<
      std::pair<int64_t /*iterator_id*/, int64_t /*element_id*/>,
      std::vector<EventNode*>>
      produce_iterator_map;
  uint64 num_producers = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchProduce,
        HostEventType::kParallelInterleaveProduce,
        HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
        HostEventType::kParseExampleProduce,
        HostEventType::kParallelBatchProduce}) {
    auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!produce_event_list) continue;
    VLOG(1) << produce_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (EventNode& produce_event : *produce_event_list) {
      absl::optional<XStatVisitor> element_id =
          produce_event.GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      for (EventNode* produce_iterator : produce_event.GetChildren()) {
        if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
          absl::optional<XStatVisitor> iterator_id =
              produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
          if (!iterator_id.has_value()) break;
          produce_iterator_map[{iterator_id->IntValue(),
                                element_id->IntValue()}]
              .push_back(produce_iterator);
          ++num_producers;
          break;
        }
      }
    }
  }
  VLOG(1) << num_producers << " producer iterators found.";
  uint64 num_matched = 0;
  for (HostEventType event_type :
       {HostEventType::kPrefetchConsume,
        HostEventType::kParallelInterleaveConsume,
        HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
        HostEventType::kParseExampleConsume,
        HostEventType::kParallelBatchConsume}) {
    auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
    if (!consume_event_list) continue;
    VLOG(1) << consume_event_list->size() << " "
            << GetHostEventTypeStr(event_type) << " events found.";
    for (EventNode& consume_event : *consume_event_list) {
      absl::optional<XStatVisitor> element_id =
          consume_event.GetEventVisitor().GetStat(StatType::kElementId);
      if (!element_id.has_value()) continue;
      if (consume_event.GetParents().empty()) continue;
      // `consume_event` is nested inside its consumer iterator and has no
      // other parents.
      EventNode* consume_iterator = consume_event.GetParents().at(0);
      if (!consume_iterator ||
          !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
        continue;
      }
      absl::optional<XStatVisitor> iterator_id =
          consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
      if (!iterator_id.has_value()) continue;
      if (auto produce_iterators = gtl::FindOrNull(
              produce_iterator_map, std::make_pair(iterator_id->IntValue(),
                                                   element_id->IntValue()))) {
        for (EventNode* produce_iterator : *produce_iterators) {
          consume_iterator->AddChild(produce_iterator);
          ++num_matched;
        }
      }
    }
  }
  VLOG(1) << num_matched << " consumer iterators matched.";
}

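// Runs the full grouping pipeline over the events collected from the added
// planes.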
void EventForest::GroupEvents() {
  ProcessTfDataSteps();
  ProcessTensorFlowLoop();
  ProcessWorker();
  CreateEventGroups();
  MarkEagerlyExecutedGpuKernels();
  MarkEagerlyExecutedCpuTfOps();
}

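// Returns the standard cross-thread connections: executor state process to
// iterator-get-next ops (keyed by step id and iteration number), and kernel
// launches to kernel executions (keyed by correlation id).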
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kExecutorStateProcess,
       HostEventType::kIteratorGetNextAsOptionalOp,
       {StatType::kStepId, StatType::kIterNum}},
      {HostEventType::kKernelLaunch,
       HostEventType::kKernelExecute,
       {StatType::kCorrelationId}}};
  return connect_info_list;
}

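// Groups TF events in `space` in place, recording results in `event_forest`.
// Currently bails out if TF loop ops are present (see the TODO below).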
void GroupTfEvents(XSpace* space, EventForest* event_forest) {
  if (CheckLoopOp(*space)) {
    // TODO(b/154510598): Support TF's loop ops.
    return;
  }
  std::vector<InterThreadConnectInfo> connect_info_list =
      CreateInterThreadConnectInfoList();
  event_forest->AddSpace(CreateTfXPlaneVisitor, space);
  event_forest->ConnectEvents(connect_info_list);
  event_forest->GroupEvents();
}

void GroupTfEvents(XSpace* space) {
  EventForest event_forest;
  GroupTfEvents(space, &event_forest);
}

}  // namespace profiler
}  // namespace tensorflow