xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/core/ScalarType.h>
4 #include <c10/util/Exception.h>
5 #include <torch/csrc/lazy/core/config.h>
6 #include <torch/csrc/lazy/core/debug_util.h>
7 #include <torch/csrc/lazy/core/dynamic_ir.h>
8 #include <torch/csrc/lazy/core/ir.h>
9 #include <torch/csrc/lazy/core/ir_builder.h>
10 #include <torch/csrc/lazy/core/ir_metadata.h>
11 #include <torch/csrc/lazy/generated/LazyIr.h>
12 #include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
13 #include <torch/csrc/lazy/ts_backend/ts_node.h>
14 #include <memory>
15 
16 namespace torch {
17 namespace lazy {
18 
19 class TestLeafNode : public Node {
20  public:
ClassOpKind()21   static OpKind ClassOpKind() {
22     return OpKind();
23   }
24 
TestLeafNode(size_t param)25   explicit TestLeafNode(size_t param)
26       : Node(ClassOpKind(), /* num_outputs */ 1), hash_(Hash(param)) {}
27   ~TestLeafNode() override = default;
28 
operands() const29   const std::vector<Output>& operands() const override {
30     TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node");
31   }
32 
operand(size_t i) const33   const Output& operand(size_t i) const override {
34     TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node");
35   }
36 
hash() const37   hash_t hash() const override {
38     return hash_;
39   }
shapeHash() const40   hash_t shapeHash() const override {
41     return hash_;
42   }
43 
44  private:
45   hash_t hash_;
46 };
47 
// Two leaf nodes built from different parameters must hash differently, a
// leaf node reports exactly one output, and NodeCast must recover the
// concrete TestLeafNode type from a generic node pointer.
TEST(IrTest, BasicTest) {
  const NodePtr first_leaf = MakeNode<TestLeafNode>(1);
  const NodePtr second_leaf = MakeNode<TestLeafNode>(2);
  EXPECT_NE(first_leaf->hash(), second_leaf->hash());

  EXPECT_EQ(first_leaf->num_outputs(), 1);

  const auto* as_leaf = NodeCast<TestLeafNode>(first_leaf.get());
  EXPECT_TRUE(as_leaf != nullptr);
}
58 
// Checks the metadata attached to a freshly created node under four setups:
// debug flag off, debug flag on, debug flag on inside a ScopePusher, and
// debug flag on with a custom Python-frames hook installed.
//
// The test mutates two pieces of state that outlive it
// (FLAGS_torch_lazy_ir_debug and the callable behind
// GetPythonFramesFunction()); both are snapshotted first and put back at the
// end so later tests are unaffected.
TEST(IrTest, MetaDataTest) {
  const bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug;
  const auto restore_python_frames_function = GetPythonFramesFunction();

  // Debug disabled: no scope and no frame info are expected.
  FLAGS_torch_lazy_ir_debug = false;
  NodePtr node = MakeNode<TestLeafNode>(1);
  auto metaWithoutDebug = node->metadata();
  EXPECT_EQ(metaWithoutDebug.scope.size(), 0);
  EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0);

  // Debug enabled, no scope pushed: empty scope, one frame-info entry.
  FLAGS_torch_lazy_ir_debug = true;
  node = MakeNode<TestLeafNode>(1);
  auto metaWithEmptyDebug = node->metadata();
  EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0);
  EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1);

  {
    // While the ScopePusher is alive the node's scope is "TestScope.1".
    ScopePusher scope("TestScope");
    node = MakeNode<TestLeafNode>(1);
    auto metaWithScope = node->metadata();
    EXPECT_EQ(metaWithScope.scope, "TestScope.1");
    EXPECT_EQ(metaWithScope.frame_info.size(), 1);
  }

  SourceLocation dummySourceLocation;
  dummySourceLocation.file = "file";
  dummySourceLocation.function = "function";
  dummySourceLocation.line = 10;
  // Capture by value: the callable is stored in state that outlives this
  // function, so a by-reference capture of the local would dangle once the
  // test returns.
  GetPythonFramesFunction() =
      [dummySourceLocation]() -> std::vector<SourceLocation> {
    return {dummySourceLocation};
  };
  node = MakeNode<TestLeafNode>(1);
  auto metaWithSourceLoc = node->metadata();
  EXPECT_EQ(metaWithSourceLoc.scope.size(), 0);
  EXPECT_EQ(metaWithSourceLoc.frame_info.size(), 1);
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10);

  // Undo the global changes made above.
  GetPythonFramesFunction() = restore_python_frames_function;
  FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
}
97 
// Two TsNodes built from identical arguments must hash identically, report a
// single output, and be downcastable to TsNode.
TEST(IrTest, TsNodeTest) {
  // Both nodes come from the same recipe so that every input to the hash
  // matches.
  const auto make_view_node = []() -> NodePtr {
    return MakeNode<TsNode>(
        OpKind(at::aten::view),
        Shape(),
        /*num_outputs*/ 1,
        /*hash_seed*/ kHashSeed);
  };
  const NodePtr lhs = make_view_node();
  const NodePtr rhs = make_view_node();
  EXPECT_EQ(lhs->hash(), rhs->hash());

  EXPECT_EQ(lhs->num_outputs(), 1);

  const auto* as_ts_node = dynamic_cast<const TsNode*>(lhs.get());
  EXPECT_TRUE(as_ts_node != nullptr);
}
116 
// Builds a TsNode with a static {DIM0, DIM1} float shape and checks that
// SizeNode, SizeAdd and SizeMul report the expected static values.
//
// Every std::dynamic_pointer_cast result is asserted non-null before it is
// dereferenced, so a failed cast shows up as a test failure rather than a
// null-pointer dereference.
TEST(IrTest, DimensionNodeTest) {
  const size_t DIM0 = 5;
  const size_t DIM1 = 8;
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      Shape(c10::kFloat, {DIM0, DIM1}),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);

  auto size0 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
  ASSERT_TRUE(size0 != nullptr);
  auto size1 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
  ASSERT_TRUE(size1 != nullptr);

  ASSERT_EQ(DIM0, size0->getStaticValue());
  ASSERT_EQ(DIM1, size1->getStaticValue());

  // The same node viewed through a generic NodePtr must still be castable to
  // DimensionNode and report the same value.
  NodePtr size0_np = size0;
  auto size0_dn = std::dynamic_pointer_cast<DimensionNode>(size0_np);
  ASSERT_TRUE(size0_dn != nullptr);
  ASSERT_EQ(DIM0, size0_dn->getStaticValue());

  auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size0}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue());

  auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
      MakeNode<SizeMul>(Value{size0}, Value{size1}));
  ASSERT_TRUE(mul_dim != nullptr);
  ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue());
}
146 
// Builds a TsNode whose shape marks dimension 0 as symbolic and dimension 1
// as not symbolic, then checks isSymbolic() on the SizeNodes and on SizeAdd /
// SizeMul nodes combining them.
//
// Every std::dynamic_pointer_cast result is asserted non-null before it is
// dereferenced, so a failed cast shows up as a test failure rather than a
// null-pointer dereference.
TEST(IrTest, DimensionIsDynamicTest) {
  const size_t DIM0 = 5;
  const size_t DIM1 = 8;
  const auto shape = Shape(c10::kFloat, {DIM0, DIM1});
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      shape.with_symbolic_dims(std::vector<bool>{true, false}),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);

  auto size0 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
  ASSERT_TRUE(size0 != nullptr);
  auto size1 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
  ASSERT_TRUE(size1 != nullptr);

  ASSERT_TRUE(size0->isSymbolic());
  ASSERT_FALSE(size1->isSymbolic());

  // symbolic + non-symbolic is expected to be symbolic.
  auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size0}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_TRUE(add_dim->isSymbolic());

  // non-symbolic + non-symbolic is expected to be non-symbolic.
  add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size1}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_FALSE(add_dim->isSymbolic());

  // symbolic * symbolic is expected to be symbolic.
  auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
      MakeNode<SizeMul>(Value{size0}, Value{size0}));
  ASSERT_TRUE(mul_dim != nullptr);
  ASSERT_TRUE(mul_dim->isSymbolic());
}
177 
178 } // namespace lazy
179 } // namespace torch
180