xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/convert/xplane_to_trace_events.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
17 
18 #include <stddef.h>
19 
20 #include <algorithm>
21 #include <iterator>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
29 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
30 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
31 #include "tensorflow/core/profiler/utils/trace_utils.h"
32 #include "tensorflow/core/profiler/utils/xplane_schema.h"
33 #include "tensorflow/core/profiler/utils/xplane_utils.h"
34 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
35 
36 namespace tensorflow {
37 namespace profiler {
38 
39 namespace {
40 
BuildDeviceAndResources(uint32 device_id,const XPlaneVisitor & plane,Device * device)41 void BuildDeviceAndResources(uint32 device_id, const XPlaneVisitor& plane,
42                              Device* device) {
43   device->set_name(std::string(plane.Name()));
44   device->set_device_id(device_id);
45 
46   bool sort_by_ordinal = (device_id == kHostThreadsDeviceId);
47   int ordinal = 0;
48   plane.ForEachLine([&](const XLineVisitor& line) {
49     uint32 resource_id = line.DisplayId();
50     Resource& resource = (*device->mutable_resources())[resource_id];
51     resource.set_resource_id(resource_id);
52     resource.set_name(std::string(line.DisplayName()));
53     if (sort_by_ordinal) {
54       // When sort_index is absent (i.e. 0), resource id will be used.
55       // Therefore sort_index starts with 1.
56       resource.set_sort_index(++ordinal);
57     }
58   });
59 }
60 
ConvertXPlaneToTraceEvents(uint32 device_id,const XPlaneVisitor & xplane,Trace * trace)61 void ConvertXPlaneToTraceEvents(uint32 device_id, const XPlaneVisitor& xplane,
62                                 Trace* trace) {
63   // Convert devices and resources.
64   BuildDeviceAndResources(device_id, xplane,
65                           &(*trace->mutable_devices())[device_id]);
66 
67   // Convert events.
68   xplane.ForEachLine([device_id, trace](const XLineVisitor& xline) {
69     uint32 resource_id = xline.DisplayId();
70     xline.ForEachEvent(
71         [device_id, resource_id, trace](const XEventVisitor& xevent) {
72           int64_t event_type =
73               xevent.Type().value_or(HostEventType::kUnknownHostEventType);
74           if (IsInternalEvent(event_type)) return;
75           auto* event = trace->add_trace_events();
76           auto& args = *event->mutable_args();
77           event->set_device_id(device_id);
78           event->set_resource_id(resource_id);
79           if (xevent.HasDisplayName()) {
80             event->set_name(std::string(xevent.DisplayName()));
81             args["long_name"] = std::string(xevent.Name());
82           } else {
83             event->set_name(std::string(xevent.Name()));
84           }
85           event->set_timestamp_ps(xevent.TimestampPs());
86           event->set_duration_ps(xevent.DurationPs());
87 
88           auto for_each_stat = [&](const XStatVisitor& stat) {
89             if (stat.ValueCase() == XStat::VALUE_NOT_SET) return;
90             if (IsInternalStat(stat.Type())) return;
91             if (stat.Type() == StatType::kStepName) {
92               event->set_name(stat.ToString());
93             }
94             args[std::string(stat.Name())] = stat.ToString();
95           };
96           // The metadata stats should appear before the per-occurrence stats.
97           xevent.Metadata().ForEachStat(for_each_stat);
98           xevent.ForEachStat(for_each_stat);
99         });
100   });
101 }
102 
103 }  // namespace
104 
MaybeDropEventsForTraceViewer(Trace * trace,uint32 limit)105 void MaybeDropEventsForTraceViewer(Trace* trace, uint32 limit) {
106   auto* trace_events = trace->mutable_trace_events();
107   size_t trace_event_size = trace_events->size();
108   if (trace_event_size <= limit) return;  // Nothing to do.
109   // Sort the events according to start time.
110   std::vector<uint64> timestamps;
111   timestamps.reserve(trace_event_size);
112   for (const auto& event : *trace_events) {
113     timestamps.push_back(event.timestamp_ps());
114   }
115   std::partial_sort(timestamps.begin(), timestamps.begin() + limit,
116                     timestamps.end(), std::less<uint64>());
117   uint64 cutoff_timestamp = timestamps[limit - 1];
118   trace_events->erase(std::remove_if(trace_events->begin(), trace_events->end(),
119                                      [&](const TraceEvent& event) {
120                                        return event.timestamp_ps() >
121                                               cutoff_timestamp;
122                                      }),
123                       trace_events->end());
124 }
125 
ConvertXSpaceToTraceEvents(const XSpace & xspace,Trace * trace)126 void ConvertXSpaceToTraceEvents(const XSpace& xspace, Trace* trace) {
127   const XPlane* host_plane = FindPlaneWithName(xspace, kHostThreadsPlaneName);
128   if (host_plane != nullptr) {
129     XPlaneVisitor xplane = CreateTfXPlaneVisitor(host_plane);
130     ConvertXPlaneToTraceEvents(kHostThreadsDeviceId, xplane, trace);
131   }
132   std::vector<const XPlane*> device_planes =
133       FindPlanesWithPrefix(xspace, kGpuPlanePrefix);
134   // We don't expect GPU and TPU planes and custom devices to be present in the
135   // same XSpace.
136   if (device_planes.empty()) {
137     device_planes = FindPlanesWithPrefix(xspace, kTpuPlanePrefix);
138   }
139   if (device_planes.empty()) {
140     device_planes = FindPlanesWithPrefix(xspace, kCustomPlanePrefix);
141   }
142   for (const XPlane* device_plane : device_planes) {
143     XPlaneVisitor xplane = CreateTfXPlaneVisitor(device_plane);
144     uint32 device_id = kFirstDeviceId + xplane.Id();
145     ConvertXPlaneToTraceEvents(device_id, xplane, trace);
146   }
147 
148   // Trace viewer (non-streaming) has scalability issues, we need to drop
149   // events to avoid loading failure for trace viewer.
150   constexpr uint64 kMaxEvents = 1000000;
151   MaybeDropEventsForTraceViewer(trace, kMaxEvents);
152 }
153 
ConvertXSpaceToTraceEventsString(const XSpace & xspace,std::string * content)154 void ConvertXSpaceToTraceEventsString(const XSpace& xspace,
155                                       std::string* content) {
156   Trace trace;
157   ConvertXSpaceToTraceEvents(xspace, &trace);
158   trace.SerializeToString(content);
159 }
160 
161 }  // namespace profiler
162 }  // namespace tensorflow
163