xref: /aosp_15_r20/external/pytorch/torch/fx/passes/tests/test_pass_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import unittest
2
3from ..pass_manager import (
4    inplace_wrapper,
5    PassManager,
6    these_before_those_pass_constraint,
7    this_before_that_pass_constraint,
8)
9
10
class TestPassManager(unittest.TestCase):
    """Unit tests for PassManager construction, ordering constraints and
    per-instance state isolation."""

    @staticmethod
    def _scaling_passes(factor: int, count: int) -> list:
        """Return ``count`` distinct callables, each multiplying its input by
        ``factor``. Every element is a separate function object, so
        constraints can tell them apart."""
        return [lambda x: factor * x for _ in range(count)]

    def test_pass_manager_builder(self) -> None:
        """A manager built from a plain list of passes, with no constraints,
        validates cleanly."""
        manager = PassManager(self._scaling_passes(2, 10))
        manager.validate()

    def test_this_before_that_pass_constraint(self) -> None:
        """Requiring the last registered pass to precede the first one must
        make validation fail."""
        scaling = self._scaling_passes(2, 10)
        manager = PassManager(scaling)
        first, last = scaling[0], scaling[-1]

        # `first` is registered ahead of `last`, so asking for the reverse
        # order is an unfulfillable constraint.
        manager.add_constraint(this_before_that_pass_constraint(last, first))

        with self.assertRaises(RuntimeError):
            manager.validate()

    def test_these_before_those_pass_constraint(self) -> None:
        """Same unfulfillable ordering, but the manager holds inplace-wrapped
        passes while the constraint names the raw callables."""
        scaling = self._scaling_passes(2, 10)
        # The constraint is built from the unwrapped functions; the test
        # expects it to still apply to their wrapped counterparts below.
        constraint = these_before_those_pass_constraint(scaling[-1], scaling[0])
        manager = PassManager([inplace_wrapper(fn) for fn in scaling])

        manager.add_constraint(constraint)

        with self.assertRaises(RuntimeError):
            manager.validate()

    def test_two_pass_managers(self) -> None:
        """Build two PassManagers back to back and check that the second does
        not inherit passes or constraints from the first."""
        for factor in (2, 3):
            scaling = self._scaling_passes(factor, 3)
            manager = PassManager()
            for fn in scaling:
                manager.add_pass(fn)
            # Satisfiable: the passes are added in this very order.
            manager.add_constraint(
                these_before_those_pass_constraint(scaling[0], scaling[1])
            )
            # Three passes, each multiplying by `factor`, applied to 1. Any
            # leaked passes from an earlier manager would change this value.
            self.assertEqual(manager(1), factor ** 3)
59