xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_ir_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/lazy/core/config.h>
5 #include <torch/csrc/lazy/core/ir.h>
6 #include <torch/csrc/lazy/core/ir_builder.h>
7 #include <torch/csrc/lazy/core/ir_metadata.h>
8 #include <torch/csrc/lazy/core/ir_util.h>
9 
10 namespace torch {
11 namespace lazy {
12 
13 class IrUtilNode : public Node {
14  public:
IrUtilNode()15   explicit IrUtilNode() : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(0)) {}
16   ~IrUtilNode() override = default;
17 
AddOperand(Value v)18   void AddOperand(Value v) {
19     if (!v.node) {
20       return;
21     }
22     operands_as_outputs_.emplace_back(v.node.get(), v.index);
23     operands_.push_back(std::move(v.node));
24   }
25 
hash() const26   hash_t hash() const override {
27     return hash_;
28   }
shapeHash() const29   hash_t shapeHash() const override {
30     return hash_;
31   }
32 
33  private:
34   hash_t hash_;
35 };
36 
37 /*  a
38  * / \
39  *b   c
40  * \ /
41  *  d
42  * Post-order: d c b a
43  */
TEST(IrUtilTest,BasicTest)44 TEST(IrUtilTest, BasicTest) {
45   NodePtr a = MakeNode<IrUtilNode>();
46   NodePtr b = MakeNode<IrUtilNode>();
47   NodePtr c = MakeNode<IrUtilNode>();
48   NodePtr d = MakeNode<IrUtilNode>();
49 
50   dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
51   dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(c, 1));
52   dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(d, 0));
53   dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(d, 0));
54 
55   auto postorder = Util::ComputePostOrder({a.get()});
56   EXPECT_EQ(postorder.size(), 4);
57   EXPECT_EQ(postorder.at(0), d.get());
58   EXPECT_EQ(postorder.at(1), c.get());
59   EXPECT_EQ(postorder.at(2), b.get());
60   EXPECT_EQ(postorder.at(3), a.get());
61 }
62 
63 /*  a
64  * / \
65  *b---c
66  * Post-order: not valid
67  */
TEST(IrUtilTest,TestCircle)68 TEST(IrUtilTest, TestCircle) {
69   NodePtr a = MakeNode<IrUtilNode>();
70   NodePtr b = MakeNode<IrUtilNode>();
71   NodePtr c = MakeNode<IrUtilNode>();
72 
73   dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
74   dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(c, 0));
75   dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(a, 0));
76 
77   EXPECT_THROW(Util::ComputePostOrder({a.get()}), c10::Error);
78 }
79 
80 } // namespace lazy
81 } // namespace torch
82