xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/group_events_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/group_events.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/numbers.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_split.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/profiler/lib/connected_traceme.h"
27 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
28 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
29 #include "tensorflow/core/profiler/utils/xplane_builder.h"
30 #include "tensorflow/core/profiler/utils/xplane_schema.h"
31 #include "tensorflow/core/profiler/utils/xplane_test_utils.h"
32 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
33 
34 namespace tensorflow {
35 namespace profiler {
36 namespace {
37 
// Legacy-root scenario: the host plane's outermost event is a kTraceContext
// event carrying graph type "train" and step number 123. A device "matmul"
// event shares a correlation id with a host "matmul" launch event. After
// grouping, the device event must carry a group_id stat and exactly one group
// named "train 123" must exist.
TEST(GroupEventsTest, GroupGpuTraceLegacyRootTest) {
  constexpr int64_t kStepNum = 123;
  constexpr int64_t kStepId = 0;
  constexpr int64_t kCorrelationId = 100;

  XSpace space;
  XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space));
  host_plane_builder.ReserveLines(2);

  // Host line 0: the legacy root event and the FunctionRun nested in it.
  auto main_thread = host_plane_builder.GetOrCreateLine(0);
  CreateXEvent(
      &host_plane_builder, &main_thread, HostEventType::kTraceContext, 0, 100,
      {{StatType::kGraphType, "train"}, {StatType::kStepNum, kStepNum}});
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
               10, 90, {{StatType::kStepId, kStepId}});

  // Host line 1: executor event with the same step id, plus the kernel launch.
  auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1);
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, kStepId}});
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70,
               {{StatType::kCorrelationId, kCorrelationId}});

  // Device plane: a single kernel event with the launch's correlation id.
  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, kCorrelationId}});

  EventForest event_forest;
  GroupTfEvents(&space, &event_forest);
  const GroupMetadataMap& group_metadata_map =
      event_forest.GetGroupMetadataMap();
  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const XEvent& device_event = device_plane->lines(0).events(0);
  // Fatal on mismatch: stats(1) below must not be read out of range.
  ASSERT_EQ(device_event.stats_size(), 3);
  EXPECT_EQ(
      device_plane_visitor.GetStatType(device_event.stats(1).metadata_id()),
      StatType::kGroupId);
  EXPECT_EQ(group_metadata_map.size(), 1);
  // Fatal on mismatch: at(0) requires the key to be present.
  ASSERT_TRUE(group_metadata_map.contains(0));
  EXPECT_EQ(group_metadata_map.at(0).name, "train 123");
}
81 
// Same topology as the legacy-root test, but the root is a plain "train" event
// marked with the is_root stat (value 1) and step number 123. After grouping,
// the device kernel must carry a group_id stat and the only group must be
// named "train 123".
TEST(GroupEventsTest, GroupGpuTraceTest) {
  constexpr int64_t kStepNum = 123;
  constexpr int64_t kStepId = 0;
  constexpr int64_t kCorrelationId = 100;

  XSpace space;
  XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space));
  host_plane_builder.ReserveLines(2);

  // Host line 0: explicit root event and the FunctionRun nested in it.
  auto main_thread = host_plane_builder.GetOrCreateLine(0);
  CreateXEvent(
      &host_plane_builder, &main_thread, "train", 0, 100,
      {{StatType::kStepNum, kStepNum}, {StatType::kIsRoot, int64_t{1}}});
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
               10, 90, {{StatType::kStepId, kStepId}});

  // Host line 1: executor event with the same step id, plus the kernel launch.
  auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1);
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, kStepId}});
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70,
               {{StatType::kCorrelationId, kCorrelationId}});

  // Device plane: a single kernel event with the launch's correlation id.
  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, kCorrelationId}});

  EventForest event_forest;
  GroupTfEvents(&space, &event_forest);
  const GroupMetadataMap& group_metadata_map =
      event_forest.GetGroupMetadataMap();
  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const XEvent& device_event = device_plane->lines(0).events(0);
  // Fatal on mismatch: stats(1) below must not be read out of range.
  ASSERT_EQ(device_event.stats_size(), 3);
  EXPECT_EQ(
      device_plane_visitor.GetStatType(device_event.stats(1).metadata_id()),
      StatType::kGroupId);
  EXPECT_EQ(group_metadata_map.size(), 1);
  // Fatal on mismatch: at(0) requires the key to be present.
  ASSERT_TRUE(group_metadata_map.contains(0));
  EXPECT_EQ(group_metadata_map.at(0).name, "train 123");
}
125 
// Single TF loop: two ExecutorState::Process events share step_id 0 and
// iter_num 10, and a "matmul" launch shares a correlation id with a device
// kernel. Expects one group with id 0 named after the iter_num ("10"), and the
// device kernel tagged with that group id.
TEST(GroupEventsTest, GroupTensorFlowLoopTest) {
  constexpr int64_t kStepId = 0;
  constexpr int64_t kIterNum = 10;
  constexpr int64_t kCorrelationId = 100;

  XSpace space;
  XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space));
  host_plane_builder.ReserveLines(1);

  auto tf_executor_thread = host_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 5, 10,
               {{StatType::kStepId, kStepId}, {StatType::kIterNum, kIterNum}});
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, kStepId}, {StatType::kIterNum, kIterNum}});
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70,
               {{StatType::kCorrelationId, kCorrelationId}});

  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, kCorrelationId}});

  EventForest event_forest;
  GroupTfEvents(&space, &event_forest);
  const GroupMetadataMap& group_metadata_map =
      event_forest.GetGroupMetadataMap();
  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const XEvent& device_event = device_plane->lines(0).events(0);
  // Fatal on mismatch: stats(1) is read twice below.
  ASSERT_EQ(device_event.stats_size(), 3);
  EXPECT_EQ(
      device_plane_visitor.GetStatType(device_event.stats(1).metadata_id()),
      StatType::kGroupId);
  // group_id is assigned using a list of consecutive numbers starting from 0.
  EXPECT_EQ(device_event.stats(1).int64_value(), 0);
  EXPECT_EQ(group_metadata_map.size(), 1);
  // group name of ExecutorState::Process event is assigned using iter_num.
  ASSERT_TRUE(group_metadata_map.contains(0));
  EXPECT_EQ(group_metadata_map.at(0).name, "10");
}
169 
170 // When there are multiple TF loops, group_id is assigned in the order of TF
171 // loops' start times and iter_num. In this test case, the profile captures the
172 // last two iterations (iter_num=10,11) of the first TF loop (step_id=0) and the
173 // first two iterations (iter_num=0,1) of the second TF loop (step_id=1).
174 // group_id is initialized to the first TF loop's first iter_num (10) and then
175 // monotonically increased.
TEST(GroupEventsTest,GroupMultipleTensorFlowLoopsTest)176 TEST(GroupEventsTest, GroupMultipleTensorFlowLoopsTest) {
177   constexpr int64_t kFirstStepId = 0;
178   constexpr int64_t kSecondStepId = 1;
179   constexpr int64_t kFirstIterNumStart = 10;
180   constexpr int64_t kSecondIterNumStart = 0;
181 
182   XSpace space;
183   XPlaneBuilder host_plane_builder(GetOrCreateHostXPlane(&space));
184   host_plane_builder.ReserveLines(2);
185 
186   auto first_tf_executor_thread = host_plane_builder.GetOrCreateLine(0);
187   CreateXEvent(&host_plane_builder, &first_tf_executor_thread,
188                HostEventType::kExecutorStateProcess, 220, 80,
189                {{StatType::kStepId, kSecondStepId},
190                 {StatType::kIterNum, kSecondIterNumStart}});
191   CreateXEvent(&host_plane_builder, &first_tf_executor_thread,
192                HostEventType::kExecutorStateProcess, 320, 80,
193                {{StatType::kStepId, kSecondStepId},
194                 {StatType::kIterNum, kSecondIterNumStart + 1}});
195   auto second_tf_executor_thread = host_plane_builder.GetOrCreateLine(1);
196   CreateXEvent(&host_plane_builder, &second_tf_executor_thread,
197                HostEventType::kExecutorStateProcess, 20, 80,
198                {{StatType::kStepId, kFirstStepId},
199                 {StatType::kIterNum, kFirstIterNumStart}});
200   CreateXEvent(&host_plane_builder, &second_tf_executor_thread,
201                HostEventType::kExecutorStateProcess, 120, 80,
202                {{StatType::kStepId, kFirstStepId},
203                 {StatType::kIterNum, kFirstIterNumStart + 1}});
204 
205   EventForest event_forest;
206   GroupTfEvents(&space, &event_forest);
207   const GroupMetadataMap& group_metadata_map =
208       event_forest.GetGroupMetadataMap();
209   EXPECT_EQ(group_metadata_map.size(), 4);
210   // group_id is assigned using a list of consecutive number starting from 0,
211   // event with an earlier start time will get a smaller group_id.
212   // group name of ExecutorState::Process event is assigned using iter_num.
213   ASSERT_TRUE(group_metadata_map.contains(0));
214   // iter_num 10 starts at timestamp 20, so it has the smallest group_id.
215   EXPECT_EQ(group_metadata_map.at(0).name, "10");
216   ASSERT_TRUE(group_metadata_map.contains(1));
217   EXPECT_EQ(group_metadata_map.at(1).name, "11");
218   ASSERT_TRUE(group_metadata_map.contains(2));
219   EXPECT_EQ(group_metadata_map.at(2).name, "0");
220   ASSERT_TRUE(group_metadata_map.contains(3));
221   // iter_num 1 starts at timestamp 320, so it has the largest group_id.
222   EXPECT_EQ(group_metadata_map.at(3).name, "1");
223 }
224 
TEST(GroupEventsTest,EagerOpTest)225 TEST(GroupEventsTest, EagerOpTest) {
226   XSpace space;
227   XPlane* host_plane = GetOrCreateHostXPlane(&space);
228   XPlaneBuilder host_plane_builder(host_plane);
229   host_plane_builder.ReserveLines(1);
230   auto main_thread = host_plane_builder.GetOrCreateLine(0);
231 
232   XPlane* device_plane = space.add_planes();
233   XPlaneBuilder device_plane_builder(device_plane);
234   device_plane_builder.ReserveLines(1);
235   auto gpu_stream = device_plane_builder.GetOrCreateLine(0);
236 
237   int64_t correlation_id = 100;
238   // TF1 ops are NOT scheduled under kEagerKernelExecute events, they should be
239   // considered NOT eager.
240   const char* kTF1GpuLaunchEvent = "tf1 matmul";
241   const char* kTF1GpuEvent = "tf1_kernel_matmul";
242   CreateXEvent(&host_plane_builder, &main_thread, kTF1GpuLaunchEvent, 10, 90,
243                {{StatType::kCorrelationId, correlation_id}});
244   CreateXEvent(&device_plane_builder, &gpu_stream, kTF1GpuEvent, 200, 300,
245                {{StatType::kCorrelationId, correlation_id}});
246   ++correlation_id;
247 
248   // Eagerly scheduled GPU operator w/o is_func Xstat (legacy). The legacy trace
249   // will also fall into this case, due to the fact we changed the EagerExecute
250   // TraceMe format. We treat them as NOT eager
251   const char* kLegacyGpuLaunchEvent = "legacy matmul";
252   const char* kLegacyGpuEvent = "legacy_kernel_matmul";
253   CreateXEvent(&host_plane_builder, &main_thread,
254                HostEventType::kEagerKernelExecute, 100, 200);
255   CreateXEvent(&host_plane_builder, &main_thread, kLegacyGpuLaunchEvent, 110,
256                190, {{StatType::kCorrelationId, correlation_id}});
257   CreateXEvent(&device_plane_builder, &gpu_stream, kLegacyGpuEvent, 300, 400,
258                {{StatType::kCorrelationId, correlation_id}});
259   ++correlation_id;
260 
261   // Eagerly scheduled GPU op with is_func Xstat.
262   const char* kEagerOpGpuLaunchEvent = "eager op matmul";
263   const char* kEagerOpGpuEvent = "eager_op_kernel_matmul";
264   CreateXEvent(&host_plane_builder, &main_thread,
265                HostEventType::kEagerKernelExecute, 200, 300,
266                {{StatType::kIsFunc, static_cast<int64_t>(0)}});
267   CreateXEvent(&host_plane_builder, &main_thread, kEagerOpGpuLaunchEvent, 210,
268                290, {{StatType::kCorrelationId, correlation_id}});
269   CreateXEvent(&device_plane_builder, &gpu_stream, kEagerOpGpuEvent, 400, 500,
270                {{StatType::kCorrelationId, correlation_id}});
271   ++correlation_id;
272 
273   // Eagerly scheduled GPU func with is_func Xstat.
274   const char* kEagerFuncGpuLaunchEvent = "eager func matmul";
275   const char* kEagerFuncGpuEvent = "eager_func_kernel_matmul";
276   CreateXEvent(&host_plane_builder, &main_thread,
277                HostEventType::kEagerKernelExecute, 300, 400,
278                {{StatType::kIsFunc, static_cast<int64_t>(1)}});
279   CreateXEvent(&host_plane_builder, &main_thread, kEagerFuncGpuLaunchEvent, 310,
280                390, {{StatType::kCorrelationId, correlation_id}});
281   CreateXEvent(&device_plane_builder, &gpu_stream, kEagerFuncGpuEvent, 500, 600,
282                {{StatType::kCorrelationId, correlation_id}});
283   ++correlation_id;
284 
285   // Eagerly executed CPU TF op.
286   const char* kEagerOpCpuEvent = "eager_op_cpu_kernel:Matmul";
287   CreateXEvent(&host_plane_builder, &main_thread,
288                HostEventType::kEagerKernelExecute, 400, 500,
289                {{StatType::kIsFunc, static_cast<int64_t>(0)}});
290   CreateXEvent(&host_plane_builder, &main_thread, kEagerOpCpuEvent, 410, 490);
291 
292   // Eagerly executed CPU TF function.
293   const char* kEagerFuncCpuEvent = "eager_func_cpu_kernel:Matmul";
294   CreateXEvent(&host_plane_builder, &main_thread,
295                HostEventType::kEagerKernelExecute, 500, 600,
296                {{StatType::kIsFunc, static_cast<int64_t>(1)}});
297   CreateXEvent(&host_plane_builder, &main_thread, kEagerFuncCpuEvent, 510, 590);
298 
299   GroupTfEvents(&space);
300 
301   auto is_eager = [](const XEventVisitor& event) {
302     auto eager_stats = event.GetStat(StatType::kIsEager);
303     return eager_stats && eager_stats->IntValue();
304   };
305   // verify host ops.
306   XPlaneVisitor host_plane_visitor = CreateTfXPlaneVisitor(host_plane);
307   int interested_events_encountered = 0;
308   host_plane_visitor.ForEachLine([&](const XLineVisitor& line) {
309     line.ForEachEvent([&](const XEventVisitor& event) {
310       if (event.Name() == kEagerOpCpuEvent) {
311         interested_events_encountered++;
312         EXPECT_TRUE(is_eager(event));
313       } else if (event.Name() == kEagerFuncCpuEvent) {
314         interested_events_encountered++;
315         EXPECT_FALSE(is_eager(event));
316       }
317     });
318   });
319   EXPECT_EQ(interested_events_encountered, 2);
320 
321   // verify device ops.
322   XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
323   interested_events_encountered = 0;
324   device_plane_visitor.ForEachLine([&](const XLineVisitor& line) {
325     line.ForEachEvent([&](const XEventVisitor& event) {
326       if (event.Name() == kTF1GpuEvent) {
327         interested_events_encountered++;
328         EXPECT_FALSE(is_eager(event));
329       } else if (event.Name() == kLegacyGpuEvent) {
330         interested_events_encountered++;
331         EXPECT_FALSE(is_eager(event));
332       } else if (event.Name() == kEagerOpGpuEvent) {
333         interested_events_encountered++;
334         EXPECT_TRUE(is_eager(event));
335       } else if (event.Name() == kEagerFuncGpuEvent) {
336         interested_events_encountered++;
337         EXPECT_FALSE(is_eager(event));
338       }
339     });
340   });
341   EXPECT_EQ(interested_events_encountered, 4);
342 }
343 
// Ops run inside a tf.function: a kEagerKernelExecute event (no is_func stat)
// and a FunctionRun sit under a kTraceContext event, and an executor line
// holds a GPU kernel launch and a CPU op. Expects both the CPU op and the
// device kernel to carry an is_eager stat equal to 0.
TEST(GroupEventsTest, FunctionOpTest) {
  constexpr int64_t kStepNum = 123;
  constexpr int64_t kStepId = 0;
  constexpr int64_t kCorrelationId = 100;

  XSpace space;
  XPlane* host_plane = GetOrCreateHostXPlane(&space);
  XPlaneBuilder host_plane_builder(host_plane);
  host_plane_builder.ReserveLines(2);

  auto main_thread = host_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext,
               0, 100, {{StatType::kStepNum, kStepNum}});
  CreateXEvent(&host_plane_builder, &main_thread,
               HostEventType::kEagerKernelExecute, 10, 90);
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
               10, 90, {{StatType::kStepId, kStepId}});

  auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1);
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, kStepId}});
  // GPU kernel scheduled inside tf.function.
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 30,
               {{StatType::kCorrelationId, kCorrelationId}});
  // CPU TF op executed inside tf.function.
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "add:Add", 70, 20);

  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  // GPU kernel executed as part of tf.function.
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, kCorrelationId}});

  GroupTfEvents(&space);

  // The CPU op is the third event created on host line 1. Bounds and stat
  // counts are asserted fatally so the indexed reads below stay in range.
  XPlaneVisitor host_plane_visitor = CreateTfXPlaneVisitor(host_plane);
  ASSERT_GE(host_plane->lines_size(), 2);
  ASSERT_GE(host_plane->lines(1).events_size(), 3);
  const XEvent& cpu_tf_op = host_plane->lines(1).events(2);
  ASSERT_EQ(cpu_tf_op.stats_size(), 2);
  EXPECT_EQ(host_plane_visitor.GetStatType(cpu_tf_op.stats(1).metadata_id()),
            StatType::kIsEager);
  EXPECT_EQ(cpu_tf_op.stats(1).int64_value(), 0);

  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const XEvent& gpu_kernel = device_plane->lines(0).events(0);
  ASSERT_EQ(gpu_kernel.stats_size(), 3);
  EXPECT_EQ(device_plane_visitor.GetStatType(gpu_kernel.stats(2).metadata_id()),
            StatType::kIsEager);
  EXPECT_EQ(gpu_kernel.stats(2).int64_value(), 0);
}
395 
// Producer/consumer linking: a FunctionRun under the root advertises
// (producer_type, producer_id) = (123, 456) and an executor event on another
// line advertises the same pair as consumer. All three events are expected to
// end up in group 0.
TEST(GroupEventsTest, SemanticArgTest) {
  constexpr int64_t kIsRoot = 1;
  constexpr int64_t kStepNum = 100;
  constexpr int64_t kContextType = 123;
  constexpr uint64 kContextId = 456;

  XSpace raw_space;
  XPlane* raw_plane = raw_space.add_planes();
  XPlaneBuilder plane(raw_plane);
  plane.ReserveLines(2);
  auto root_producer = plane.GetOrCreateLine(0);
  CreateXEvent(&plane, &root_producer, HostEventType::kTraceContext, 0, 100,
               {{StatType::kIsRoot, kIsRoot}, {StatType::kStepNum, kStepNum}});
  CreateXEvent(&plane, &root_producer, HostEventType::kFunctionRun, 10, 90,
               {{StatType::kProducerType, kContextType},
                {StatType::kProducerId, kContextId}});
  auto consumer = plane.GetOrCreateLine(1);
  CreateXEvent(&plane, &consumer, HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kConsumerType, kContextType},
                {StatType::kConsumerId, kContextId}});

  GroupTfEvents(&raw_space);
  int num_events = 0;
  CreateTfXPlaneVisitor(raw_plane).ForEachLine([&](const XLineVisitor& line) {
    num_events += line.NumEvents();
    line.ForEachEvent([&](const XEventVisitor& event) {
      absl::optional<int64_t> group_id;
      if (absl::optional<XStatVisitor> stat =
              event.GetStat(StatType::kGroupId)) {
        group_id = stat->IntValue();
      }
      EXPECT_TRUE(group_id.has_value());
      // Read the value only when present: dereferencing an empty optional is
      // undefined behavior, and the expectation above already reports it.
      if (group_id.has_value()) {
        EXPECT_EQ(*group_id, 0);
      }
    });
  });
  // Guards against the lambdas never running.
  EXPECT_EQ(num_events, 3);
}
435 
// Producer/consumer ids that differ (456 vs 789) must not link: the root and
// the FunctionRun are expected in group 0, while the executor event on the
// consumer line is expected to have no group_id stat.
TEST(GroupEventsTest, SemanticIntArgNoMatchTest) {
  constexpr int64_t kIsRoot = 1;
  constexpr int64_t kStepNum = 100;
  constexpr int64_t kContextType = 123;
  constexpr uint64 kProducerId = 456;
  constexpr uint64 kConsumerId = 789;

  XSpace raw_space;
  XPlane* raw_plane = raw_space.add_planes();
  XPlaneBuilder plane(raw_plane);
  plane.ReserveLines(2);
  auto root_producer = plane.GetOrCreateLine(0);
  CreateXEvent(&plane, &root_producer, HostEventType::kTraceContext, 0, 100,
               {{StatType::kIsRoot, kIsRoot}, {StatType::kStepNum, kStepNum}});
  CreateXEvent(&plane, &root_producer, HostEventType::kFunctionRun, 10, 90,
               {{StatType::kProducerType, kContextType},
                {StatType::kProducerId, kProducerId}});
  auto consumer = plane.GetOrCreateLine(1);
  CreateXEvent(&plane, &consumer, HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kConsumerType, kContextType},
                {StatType::kConsumerId, kConsumerId}});

  GroupTfEvents(&raw_space);
  int num_events = 0;
  CreateTfXPlaneVisitor(raw_plane).ForEachLine([&](const XLineVisitor& line) {
    num_events += line.NumEvents();
    line.ForEachEvent([&](const XEventVisitor& event) {
      absl::optional<int64_t> group_id;
      if (absl::optional<XStatVisitor> stat =
              event.GetStat(StatType::kGroupId)) {
        group_id = stat->IntValue();
      }
      if (event.Type() == HostEventType::kExecutorStateProcess) {
        EXPECT_FALSE(group_id.has_value());
      } else {
        EXPECT_TRUE(group_id.has_value());
        // Read the value only when present: dereferencing an empty optional
        // is undefined behavior.
        if (group_id.has_value()) {
          EXPECT_EQ(*group_id, 0);
        }
      }
    });
  });
  // Guards against the lambdas never running.
  EXPECT_EQ(num_events, 3);
}
480 
// Same as the int no-match test, but with ids at the top of the uint64 range
// (UINT64_MAX vs UINT64_MAX - 1). The two ids differ, so the executor event
// is expected to stay ungrouped while the other two events land in group 0.
TEST(GroupEventsTest, SemanticUintArgNoMatchTest) {
  constexpr int64_t kIsRoot = 1;
  constexpr int64_t kStepNum = 100;
  constexpr int64_t kContextType = 123;
  constexpr uint64 kProducerId = UINT64_MAX;
  constexpr uint64 kConsumerId = UINT64_MAX - 1;

  XSpace raw_space;
  XPlane* raw_plane = raw_space.add_planes();
  XPlaneBuilder plane(raw_plane);
  plane.ReserveLines(2);
  auto root_producer = plane.GetOrCreateLine(0);
  CreateXEvent(&plane, &root_producer, HostEventType::kTraceContext, 0, 100,
               {{StatType::kIsRoot, kIsRoot}, {StatType::kStepNum, kStepNum}});
  CreateXEvent(&plane, &root_producer, HostEventType::kFunctionRun, 10, 90,
               {{StatType::kProducerType, kContextType},
                {StatType::kProducerId, kProducerId}});
  auto consumer = plane.GetOrCreateLine(1);
  CreateXEvent(&plane, &consumer, HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kConsumerType, kContextType},
                {StatType::kConsumerId, kConsumerId}});

  GroupTfEvents(&raw_space);
  int num_events = 0;
  CreateTfXPlaneVisitor(raw_plane).ForEachLine([&](const XLineVisitor& line) {
    num_events += line.NumEvents();
    line.ForEachEvent([&](const XEventVisitor& event) {
      absl::optional<int64_t> group_id;
      if (absl::optional<XStatVisitor> stat =
              event.GetStat(StatType::kGroupId)) {
        group_id = stat->IntValue();
      }
      if (event.Type() == HostEventType::kExecutorStateProcess) {
        EXPECT_FALSE(group_id.has_value());
      } else {
        EXPECT_TRUE(group_id.has_value());
        // Read the value only when present: dereferencing an empty optional
        // is undefined behavior.
        if (group_id.has_value()) {
          EXPECT_EQ(*group_id, 0);
        }
      }
    });
  });
  // Guards against the lambdas never running.
  EXPECT_EQ(num_events, 3);
}
525 
// One line holds a root "parent" event, an "async" event flagged is_async
// that extends past the parent, and a "child" event. Expects parent and child
// in group 0 and the async event left without a group_id stat.
TEST(GroupEventsTest, AsyncEventTest) {
  constexpr int64_t kIsRoot = 1;
  constexpr int64_t kIsAsync = 1;
  constexpr absl::string_view kParent = "parent";
  constexpr absl::string_view kAsync = "async";
  constexpr absl::string_view kChild = "child";

  XSpace raw_space;
  XPlane* raw_plane = raw_space.add_planes();
  XPlaneBuilder plane(raw_plane);
  plane.ReserveLines(1);
  auto line = plane.GetOrCreateLine(0);
  CreateXEvent(&plane, &line, kParent, 0, 100, {{StatType::kIsRoot, kIsRoot}});
  CreateXEvent(&plane, &line, kAsync, 10, 200,
               {{StatType::kIsAsync, kIsAsync}});
  CreateXEvent(&plane, &line, kChild, 20, 80);

  GroupTfEvents(&raw_space);
  // The lambda parameter is named differently from the `line` builder above
  // to avoid shadowing it.
  CreateTfXPlaneVisitor(raw_plane).ForEachLine(
      [&](const XLineVisitor& line_visitor) {
        EXPECT_EQ(line_visitor.NumEvents(), 3);
        line_visitor.ForEachEvent([&](const XEventVisitor& event) {
          absl::optional<int64_t> group_id;
          if (absl::optional<XStatVisitor> stat =
                  event.GetStat(StatType::kGroupId)) {
            group_id = stat->IntValue();
          }
          if (event.Name() == kAsync) {
            EXPECT_FALSE(group_id.has_value());
          } else {
            EXPECT_TRUE(group_id.has_value());
            // Read the value only when present: dereferencing an empty
            // optional is undefined behavior.
            if (group_id.has_value()) {
              EXPECT_EQ(*group_id, 0);
            }
          }
        });
      });
}
563 
TEST(GroupEventsTest,WorkerTest)564 TEST(GroupEventsTest, WorkerTest) {
565   constexpr uint64 kEagerKernelExecuteDuration = 100;
566   constexpr uint64 kFunctionRunDuration = 50;
567   constexpr uint64 kFirstEagerKernelExecuteStartTime = 0;
568   constexpr uint64 kSecondEagerKernelExecuteStartTime = 200;
569   constexpr uint64 kThirdEagerKernelExecuteStartTime = 400;
570   constexpr uint64 kFourthEagerKernelExecuteStartTime = 600;
571   constexpr uint64 kFirstFunctionRunStartTime = 210;
572   constexpr uint64 kSecondFunctionRunStartTime = 610;
573 
574   XSpace raw_space;
575   XPlane* raw_plane = raw_space.add_planes();
576   XPlaneBuilder plane(raw_plane);
577   plane.ReserveLines(1);
578   auto line = plane.GetOrCreateLine(0);
579   // Eager op. It doesn't belong to any group.
580   CreateXEvent(&plane, &line, HostEventType::kEagerKernelExecute,
581                kFirstEagerKernelExecuteStartTime, kEagerKernelExecuteDuration);
582   // First function. It creates the first group.
583   CreateXEvent(&plane, &line, HostEventType::kEagerKernelExecute,
584                kSecondEagerKernelExecuteStartTime, kEagerKernelExecuteDuration);
585   CreateXEvent(&plane, &line, HostEventType::kFunctionRun,
586                kFirstFunctionRunStartTime, kFunctionRunDuration);
587   // Eager op. It belongs to the first group.
588   CreateXEvent(&plane, &line, HostEventType::kEagerKernelExecute,
589                kThirdEagerKernelExecuteStartTime, kEagerKernelExecuteDuration);
590   // Second function. It creates the second group.
591   CreateXEvent(&plane, &line, HostEventType::kEagerKernelExecute,
592                kFourthEagerKernelExecuteStartTime, kEagerKernelExecuteDuration);
593   CreateXEvent(&plane, &line, HostEventType::kFunctionRun,
594                kSecondFunctionRunStartTime, kFunctionRunDuration);
595 
596   GroupTfEvents(&raw_space);
597   CreateTfXPlaneVisitor(raw_plane).ForEachLine(
598       [&](const tensorflow::profiler::XLineVisitor& line) {
599         EXPECT_EQ(line.NumEvents(), 6);
600         line.ForEachEvent(
601             [&](const tensorflow::profiler::XEventVisitor& event) {
602               absl::optional<int64_t> group_id;
603               if (absl::optional<XStatVisitor> stat =
604                       event.GetStat(StatType::kGroupId)) {
605                 group_id = stat->IntValue();
606               }
607               if (event.TimestampPs() < kSecondEagerKernelExecuteStartTime) {
608                 EXPECT_FALSE(group_id.has_value());
609               } else if (event.TimestampPs() <
610                          kFourthEagerKernelExecuteStartTime) {
611                 EXPECT_TRUE(group_id.has_value());
612                 EXPECT_EQ(*group_id, 0);
613               } else {
614                 EXPECT_TRUE(group_id.has_value());
615                 EXPECT_EQ(*group_id, 1);
616               }
617             });
618       });
619 }
620 
TEST(GroupEventsTest,BatchingSessionTest)621 TEST(GroupEventsTest, BatchingSessionTest) {
622   constexpr absl::string_view kSchedule = "Schedule";
623   constexpr int64_t kBatchContextType =
624       static_cast<int64_t>(ContextType::kSharedBatchScheduler);
625   constexpr int64_t kBatchContextId = 123;
626   constexpr int64_t kBatchingSessionRunRootLevel = 1;
627   constexpr int64_t kProcessBatchRootLevel = 2;
628 
629   XSpace raw_space;
630   XPlane* raw_plane = raw_space.add_planes();
631   XPlaneBuilder plane(raw_plane);
632   plane.ReserveLines(2);
633   auto request_thread = plane.GetOrCreateLine(0);
634   // First request.
635   CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSessionRun, 0,
636                100, {{StatType::kIsRoot, kBatchingSessionRunRootLevel}});
637   CreateXEvent(&plane, &request_thread, kSchedule, 0, 100,
638                {{StatType::kProducerType, kBatchContextType},
639                 {StatType::kProducerId, kBatchContextId}});
640   // Second request.
641   CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSessionRun, 200,
642                100, {{StatType::kIsRoot, kBatchingSessionRunRootLevel}});
643   CreateXEvent(&plane, &request_thread, kSchedule, 200, 100,
644                {{StatType::kProducerType, kBatchContextType},
645                 {StatType::kProducerId, kBatchContextId}});
646   auto batch_thread = plane.GetOrCreateLine(1);
647   CreateXEvent(&plane, &batch_thread, HostEventType::kProcessBatch, 200, 100,
648                {{StatType::kConsumerType, kBatchContextType},
649                 {StatType::kConsumerId, kBatchContextId},
650                 {StatType::kIsRoot, kProcessBatchRootLevel}});
651 
652   EventForest event_forest;
653   GroupTfEvents(&raw_space, &event_forest);
654   const GroupMetadataMap& group_metadata_map =
655       event_forest.GetGroupMetadataMap();
656   EXPECT_EQ(group_metadata_map.size(), 3);
657   // Check that the ProcessBatch group has two BatchingSessionRun groups as
658   // parents.
659   EXPECT_EQ(group_metadata_map.at(0).parents.size(), 2);
660   // Check that the BatchingSessionRun groups have one ProcessBatch group as a
661   // child.
662   EXPECT_EQ(group_metadata_map.at(1).children.size(), 1);
663   EXPECT_EQ(group_metadata_map.at(2).children.size(), 1);
664   // Check that the events have the selected_group_ids stat set.
665   uint64 num_checked = 0;
666   CreateTfXPlaneVisitor(raw_plane).ForEachLine(
667       [&](const tensorflow::profiler::XLineVisitor& line) {
668         line.ForEachEvent(
669             [&](const tensorflow::profiler::XEventVisitor& event) {
670               absl::optional<int64_t> group_id;
671               if (absl::optional<XStatVisitor> stat =
672                       event.GetStat(StatType::kGroupId)) {
673                 group_id = stat->IntValue();
674               }
675               EXPECT_TRUE(group_id.has_value());
676               if (line.Id() == 0 &&
677                   event.Type() == HostEventType::kBatchingSessionRun) {
678                 ++num_checked;
679               } else if (line.Id() == 1 &&
680                          event.Type() == HostEventType::kProcessBatch) {
681                 ++num_checked;
682               }
683             });
684       });
685   EXPECT_EQ(num_checked, 3);
686 }
687 
688 }  // namespace
689 }  // namespace profiler
690 }  // namespace tensorflow
691