xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_script_profile.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/Optional.h>
4 #include <test/cpp/jit/test_utils.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/runtime/script_profile.h>
8 
9 namespace torch {
10 namespace jit {
11 
// Attaches a source range to the first node of a small parsed graph, records
// three InstructionSpans for that node while a ScriptProfile is enabled, and
// checks that the dumped stats contain exactly one source with exactly one
// line entry whose count is 3.
TEST(ScriptProfileTest, Basic) {
  const std::string source_string = R"V0G0N(
    def foo(a, b):
      return a + b #
  )V0G0N";
  // The range attached to the node starts at "return" and stops just before
  // the trailing " #" marker.
  auto begin = source_string.find("return");
  auto end = source_string.find(" #");

  Graph g;
  const auto graph_string = R"IR(
    graph(%a : Tensor,
          %b : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::add(%a, %b, %2)
      return (%3))IR";

  torch::jit::parseIR(graph_string, &g);
  auto source = std::make_shared<Source>(source_string, "", 0);
  auto node = *g.nodes().begin();
  node->setSourceRange(SourceRange{source, begin, end});

  ScriptProfile p;
  p.enable();
  {
    // Three spans on the same node; all are destroyed at the end of this
    // scope, before the profile is disabled.
    profiling::InstructionSpan g0(*node);
    profiling::InstructionSpan g1(*node);
    profiling::InstructionSpan g2(*node);
  }
  p.disable();

  // Bind by const reference: `p` outlives `stats`, so no copy is needed.
  const auto& stats = p.dumpStats();
  EXPECT_EQ(stats.size(), 1);
  auto it = stats.find(*source);
  // Must be fatal: the next line dereferences `it`, which would be undefined
  // behavior if the lookup returned end().
  ASSERT_NE(it, stats.end());
  const auto& lines = it->second;
  EXPECT_EQ(lines.size(), 1);
  const auto& stat = lines.at(source->lineno_for_offset(begin));
  EXPECT_EQ(stat.count, 3);
}
51 
// Exercises misuse of the ScriptProfile API. The test expects c10::Error in
// two cases: calling dumpStats() while profiling is still enabled, and
// calling addDatapoint() after profiling has been disabled.
TEST(ScriptProfileTest, CallingOrder) {
  ScriptProfile profile;

  // Phase 1: stats must not be dumpable mid-profiling.
  profile.enable();
  EXPECT_THROW(profile.dumpStats(), c10::Error);
  profile.disable();

  // Phase 2: once disabled, new datapoints must be rejected.
  auto datapoint = std::make_shared<profiling::Datapoint>(SourceRange{});
  EXPECT_THROW(profile.addDatapoint(std::move(datapoint)), c10::Error);
}
60 
61 } // namespace jit
62 } // namespace torch
63