#include <gtest/gtest.h>

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/dynamic_ir.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/generated/LazyIr.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
15
16 namespace torch {
17 namespace lazy {
18
19 class TestLeafNode : public Node {
20 public:
ClassOpKind()21 static OpKind ClassOpKind() {
22 return OpKind();
23 }
24
TestLeafNode(size_t param)25 explicit TestLeafNode(size_t param)
26 : Node(ClassOpKind(), /* num_outputs */ 1), hash_(Hash(param)) {}
27 ~TestLeafNode() override = default;
28
operands() const29 const std::vector<Output>& operands() const override {
30 TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node");
31 }
32
operand(size_t i) const33 const Output& operand(size_t i) const override {
34 TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node");
35 }
36
hash() const37 hash_t hash() const override {
38 return hash_;
39 }
shapeHash() const40 hash_t shapeHash() const override {
41 return hash_;
42 }
43
44 private:
45 hash_t hash_;
46 };
47
// Leaf nodes built from different params must hash differently; a leaf node
// reports a single output and can be recovered through NodeCast.
TEST(IrTest, BasicTest) {
  const NodePtr first = MakeNode<TestLeafNode>(1);
  const NodePtr second = MakeNode<TestLeafNode>(2);
  EXPECT_NE(first->hash(), second->hash());

  EXPECT_EQ(first->num_outputs(), 1);

  EXPECT_TRUE(NodeCast<TestLeafNode>(first.get()) != nullptr);
}
58
// Verifies the metadata (scope name and Python frame info) recorded on a node
// under different settings of FLAGS_torch_lazy_ir_debug, an active
// ScopePusher, and the GetPythonFramesFunction() hook.
TEST(IrTest, MetaDataTest) {
  // This test mutates two pieces of global state: the IR debug flag and the
  // Python-frames hook. Restore both on every exit path (normal return, an
  // ASSERT_* early return, or an exception) so later tests are unaffected.
  struct GlobalStateRestorer {
    bool ir_debug = FLAGS_torch_lazy_ir_debug;
    std::function<std::vector<SourceLocation>()> python_frames_fn =
        GetPythonFramesFunction();
    ~GlobalStateRestorer() {
      GetPythonFramesFunction() = python_frames_fn;
      FLAGS_torch_lazy_ir_debug = ir_debug;
    }
  } restorer;

  FLAGS_torch_lazy_ir_debug = false;
  NodePtr node = MakeNode<TestLeafNode>(1);
  auto metaWithoutDebug = node->metadata();
  EXPECT_EQ(metaWithoutDebug.scope.size(), 0u);
  EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0u);

  FLAGS_torch_lazy_ir_debug = true;
  node = MakeNode<TestLeafNode>(1);
  auto metaWithEmptyDebug = node->metadata();
  EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0u);
  EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1u);

  {
    ScopePusher scope("TestScope");
    node = MakeNode<TestLeafNode>(1);
    auto metaWithScope = node->metadata();
    EXPECT_EQ(metaWithScope.scope, "TestScope.1");
    EXPECT_EQ(metaWithScope.frame_info.size(), 1u);
  }

  SourceLocation dummySourceLocation;
  dummySourceLocation.file = "file";
  dummySourceLocation.function = "function";
  dummySourceLocation.line = 10;
  // Capture by value: the hook is global and must never hold a reference to
  // this stack-local object (the original `[&]` capture dangled after return).
  GetPythonFramesFunction() =
      [dummySourceLocation]() -> std::vector<SourceLocation> {
    return {dummySourceLocation};
  };
  node = MakeNode<TestLeafNode>(1);
  auto metaWithSourceLoc = node->metadata();
  EXPECT_EQ(metaWithSourceLoc.scope.size(), 0u);
  // ASSERT (fatal) so frame_info[0] below is never read out of bounds.
  ASSERT_EQ(metaWithSourceLoc.frame_info.size(), 1u);
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10);
}
97
// Two TsNodes built from identical op, shape, output count and hash seed must
// hash identically, and the produced node must dynamically be a TsNode.
TEST(IrTest, TsNodeTest) {
  // Both nodes come from exactly the same constructor arguments.
  auto make_view_node = []() {
    return MakeNode<TsNode>(
        OpKind(at::aten::view),
        Shape(),
        /*num_outputs*/ 1,
        /*hash_seed*/ kHashSeed);
  };
  const NodePtr node1 = make_view_node();
  const NodePtr node2 = make_view_node();
  EXPECT_EQ(node1->hash(), node2->hash());

  EXPECT_EQ(node1->num_outputs(), 1);

  EXPECT_TRUE(dynamic_cast<const TsNode*>(node1.get()) != nullptr);
}
116
// Checks that SizeNodes over a statically shaped TsNode report the static
// dimension sizes (also through the DimensionNode interface), and that
// SizeAdd / SizeMul report the sum / product of those static values.
TEST(IrTest, DimensionNodeTest) {
  // int64_t to match both Shape's size type and getStaticValue()'s return.
  const int64_t DIM0 = 5;
  const int64_t DIM1 = 8;
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      Shape(c10::kFloat, {DIM0, DIM1}),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);

  auto size0 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
  auto size1 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
  // A failed cast must fail the test, not crash the binary on a null deref.
  ASSERT_TRUE(size0 != nullptr);
  ASSERT_TRUE(size1 != nullptr);

  ASSERT_EQ(DIM0, size0->getStaticValue());
  ASSERT_EQ(DIM1, size1->getStaticValue());

  NodePtr size0_np = size0;
  auto size0_dn = std::dynamic_pointer_cast<DimensionNode>(size0_np);
  ASSERT_TRUE(size0_dn != nullptr);
  ASSERT_EQ(DIM0, size0_dn->getStaticValue());

  auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size0}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue());

  auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
      MakeNode<SizeMul>(Value{size0}, Value{size1}));
  ASSERT_TRUE(mul_dim != nullptr);
  ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue());
}
146
// Checks symbolic-ness propagation: a SizeNode over a symbolic dim is symbolic
// and one over a static dim is not; SizeAdd / SizeMul are symbolic when an
// operand is symbolic and static when both operands are static.
TEST(IrTest, DimensionIsDynamicTest) {
  // int64_t to match Shape's size type.
  const int64_t DIM0 = 5;
  const int64_t DIM1 = 8;
  const auto shape = Shape(c10::kFloat, {DIM0, DIM1});
  // Dim 0 is marked symbolic, dim 1 static.
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      shape.with_symbolic_dims(std::vector<bool>{true, false}),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);

  auto size0 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
  auto size1 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));
  // A failed cast must fail the test, not crash the binary on a null deref.
  ASSERT_TRUE(size0 != nullptr);
  ASSERT_TRUE(size1 != nullptr);

  ASSERT_TRUE(size0->isSymbolic());
  ASSERT_FALSE(size1->isSymbolic());

  auto add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size0}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_TRUE(add_dim->isSymbolic());

  add_dim = std::dynamic_pointer_cast<SizeAdd>(
      MakeNode<SizeAdd>(Value{size1}, Value{size1}));
  ASSERT_TRUE(add_dim != nullptr);
  ASSERT_FALSE(add_dim->isSymbolic());

  auto mul_dim = std::dynamic_pointer_cast<SizeMul>(
      MakeNode<SizeMul>(Value{size0}, Value{size0}));
  ASSERT_TRUE(mul_dim != nullptr);
  ASSERT_TRUE(mul_dim->isSymbolic());
}
177
178 } // namespace lazy
179 } // namespace torch
180