# xref: /aosp_15_r20/external/pytorch/test/test_fx_passes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: fx.passes"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
4*da0073e9SAndroid Build Coastguard Workerimport operator
5*da0073e9SAndroid Build Coastguard Workerimport logging
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._symbolic_trace import symbolic_trace
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
12*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.operator_support import OperatorSupport
13*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.fuser_utils import fuse_by_partitions
14*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.matcher_utils import SubgraphMatcher
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
18*da0073e9SAndroid Build Coastguard Worker
# Module-level logging setup. basicConfig configures the root logger at
# WARNING level (it is a no-op if the root logger already has handlers), so
# importing this test module has that process-wide side effect.
logging.basicConfig(level=logging.WARNING)
# Module-scoped logger; not used in the visible part of the file, presumably
# used by tests further down — confirm before removing.
logger = logging.getLogger(__name__)
21*da0073e9SAndroid Build Coastguard Worker
class TestModule(torch.nn.Module):
    """Small module whose traced graph mixes the common FX node kinds.

    Tracing it yields call_function nodes (``add`` ... ``add_6``), call_module
    nodes (``linear``, ``linear2``), a get_attr node (``param``) and a
    call_method node (``relu``). The fuser tests address these nodes by name,
    and FX assigns the names in the order the operations are traced. The
    attribute names and the statement/operand order in ``forward`` must
    therefore be kept exactly as they are.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        """Return ``(add_4, add_6, relu)`` computed from inputs ``a``, ``b``, ``c``."""
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        # Reading self.param here is what creates the ``param`` get_attr node.
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu
46*da0073e9SAndroid Build Coastguard Worker
class TestDeepModule(torch.nn.Module):
    """Module whose traced graph is longer than Python's recursion limit.

    ``forward`` chains ``sys.getrecursionlimit() + 1`` subtractions, so a graph
    pass that recurses once per node (e.g. a recursive DFS) would raise
    ``RecursionError`` on it.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): not used by forward().
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        """Return ``a + b + 1.0`` minus ``c`` applied recursionlimit + 1 times."""
        o = a + b
        o = o + 1.0

        # One graph node per iteration: the chain is deliberately longer than
        # the recursion limit, so passes over it must not use recursive DFS.
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
class TestPartitionFunctions:
    """Free functions traced by the partitioner tests.

    The expected partitions in ``TestFXGraphPasses`` refer to FX node names
    (``add``, ``add_1``, ``permute_1``, ``getitem``, ...). FX derives those
    names from the traced op and the order in which ops are recorded, not
    from the Python variable names. Local names below therefore do not always
    match node names (e.g. ``add_3`` in ``forward2`` traces to node
    ``add_2``), and the statement and operand order of every function is
    load-bearing. Unused parameters keep the ``fn(a, b, c)`` calling
    convention shared by these functions (``forward17`` takes six inputs).
    """

    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # unsupported op; splits the adds into two partitions
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused with the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should get its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, so all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches end up in the same partition, so add should join it too
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2: depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 depends on x0, and through it on b0. When {c0, c1, c2} are merged,
        # that dependency must be carried over to the fused group so the
        # partitioner declines to fuse b0 & b1: together with the
        # c0 -> x1 -> b1 edge, that fusion would form a cycle in the new graph.
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0
241*da0073e9SAndroid Build Coastguard Worker
# A mock OperatorSupport that accepts only call_function nodes whose target is
# operator.add, operator.getitem, or the aten view / permute / std_mean ops;
# every other node (placeholders, outputs, methods, other functions) is
# reported as unsupported.
class MockOperatorSupport(OperatorSupport):
    # Built once at class creation rather than on every is_node_supported call.
    _SUPPORTED_TARGETS = frozenset({
        operator.add,
        operator.getitem,
        torch.ops.aten.view,
        torch.ops.aten.permute,
        torch.ops.aten.std_mean,
    })

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Return True iff ``node`` is a call_function to a supported target.

        ``submodules`` is part of the OperatorSupport interface and is ignored.
        """
        return node.op == "call_function" and node.target in self._SUPPORTED_TARGETS
250*da0073e9SAndroid Build Coastguard Worker
@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):
    """Tests for CapabilityBasedPartitioner and fuse_by_partitions."""

    def _assert_partitions_match(self, partitions, expected_partition):
        """Assert that proposed ``partitions`` match ``expected_partition``.

        ``expected_partition`` is a list of node-name lists. Partitions are
        compared in order and must have the same count. Node names inside each
        partition are compared as sets, since order within a partition is not
        significant.
        """
        actual = [{node.name for node in partition.nodes} for partition in partitions]
        expected = [set(names) for names in expected_partition]
        self.assertEqual(actual, expected)

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1 horizontal fusion with common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2 branches cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3 branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4 not necessarily the only valid partitioning; this just verifies there is no cyclic dependency after partitioning
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5 getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6 bookend non_compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        # should produce no partitions at all, not a partition with no nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
        """Check the proposed partitions for ``fn`` and that fusing them preserves results."""
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
        if bookend_non_compute_pass:
            # Modifies `partitions` in place. Compare the True/False expectations
            # for forward15/16: leading/trailing view/permute nodes are dropped.
            partitioner.remove_bookend_non_compute_ops(partitions)

        self._assert_partitions_match(partitions, expected_partition)

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        """Mutually independent supported ops feeding separate outputs share one partition."""
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        self._assert_partitions_match(partitions, expected_partition)

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],         # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],     # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],           # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],   # includes call_function + call_module nodes
        [['add_6', 'relu']],   # includes call_function + call_method nodes
        [['param', 'add_2']],   # includes get_attr + call_function nodes
        [['param', 'add_1', 'linear']],   # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
        """Fusing valid hand-picked partitions of TestModule preserves its outputs."""
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = [[nodes_by_name[name] for name in node_names] for node_names in partition]

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],    # invalid partition: circular dependency
        [['add_4', 'add_5']],    # invalid partition: circular dependency
        [['relu', 'add_5']],    # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
        """fuse_by_partitions must reject overlapping or cycle-inducing partitions."""
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = [[nodes_by_name[name] for name in node_names] for node_names in partition]

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
        """Partitioning a graph deeper than the recursion limit must not recurse per node."""
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        # The main check is that this completes without a RecursionError.
        partitions = partitioner.propose_partitions()
        # The two leading adds are supported, so at least one partition exists.
        self.assertGreater(len(partitions), 0)
402*da0073e9SAndroid Build Coastguard Worker
@dataclass
class TestCase:
    """One pattern-matching configuration and its expected number of matches.

    NOTE(review): the consumers are outside this chunk; these fields presumably
    map onto SubgraphMatcher's options of the same names — confirm there.
    """

    # Whether the pattern's output must also be an output of the traced graph
    # (presumably the SubgraphMatcher ``match_output`` flag).
    match_output: bool
    # Whether pattern placeholders must match graph placeholders
    # (presumably the SubgraphMatcher ``match_placeholder`` flag).
    match_placeholder: bool
    # Expected number of matches for this configuration.
    num_matches: int
    # Presumably the SubgraphMatcher ``remove_overlapping_matches`` flag;
    # defaults to True.
    remove_overlapping_matches: bool = True
409*da0073e9SAndroid Build Coastguard Worker
class SingleNodePattern:
    """A lone ``neg`` pattern searched in a graph where ``neg`` feeds an ``add``.

    The ``neg`` is not the graph's output. The cases that set ``match_output``
    expect no match; the others expect exactly one.
    """

    @staticmethod
    def forward(x):
        negated = torch.neg(x)
        return torch.add(negated, negated)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=1),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]

class SimplePattern:
    """Pattern ``cat([a, b]).sum()`` against a graph containing three such expressions.

    Two occurrences concatenate graph inputs directly. The third concatenates the
    results of the first two. With ``match_placeholder`` set, only 2 matches are
    expected, versus 3 without it.
    """

    @staticmethod
    def forward(x, w1, w2):
        sum_12 = torch.cat([w1, w2]).sum()
        sum_21 = torch.cat([w2, w1]).sum()
        combined = torch.cat([sum_12, sum_21]).sum()
        return x + torch.max(sum_12) + torch.max(sum_21) + combined

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=3),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=2),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
446*da0073e9SAndroid Build Coastguard Worker
class SimpleFullGraphMatching:
    """Pattern and target are the same graph (``neg`` followed by ``add``).

    Every combination of ``match_output`` / ``match_placeholder`` expects exactly
    one match.
    """

    @staticmethod
    def forward(x):
        negated = torch.neg(x)
        return torch.add(negated, negated)

    @staticmethod
    def pattern(x):
        negated = torch.neg(x)
        return torch.add(negated, negated)

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=1),
        TestCase(match_output=True, match_placeholder=True, num_matches=1),
    ]
465*da0073e9SAndroid Build Coastguard Worker
class DiamondShapePatternTestCase:
    """Diamond pattern: ``relu`` fans out to ``sigmoid`` and ``relu``, which are re-joined by ``+``.

    In ``forward``, the pattern's input corresponds to the result of ``neg``
    rather than a graph input. The cases that set ``match_placeholder`` expect
    no match.
    """

    @staticmethod
    def forward(x):
        negated = torch.neg(x)

        trunk = negated.relu()
        left = trunk.sigmoid()
        right = trunk.relu()
        return left + right

    @staticmethod
    def pattern(a):
        trunk = a.relu()
        left = trunk.sigmoid()
        right = trunk.relu()
        return left + right

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
493*da0073e9SAndroid Build Coastguard Worker
class NonFullyContainedMatches:
    """Two occurrences of the ``cat``/``cat``/``addmm`` pattern; only one is self-contained.

    In the first occurrence, only the ``addmm`` result is consumed outside the
    pattern. In the second, an interior ``cat`` result is also consumed outside
    (it "leaks"). The cases expect a single match.
    """

    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # Fully contained occurrence: only the pattern's output is used afterwards.
        lhs0 = torch.cat([w1, w2])
        rhs0 = torch.cat([x, b2])
        mm0 = torch.addmm(b1, lhs0, rhs0.t())
        mm0_sum = torch.sum(mm0)  # consuming the pattern's output is not a leak

        # Leaking occurrence: lhs1 is an interior pattern node that is also
        # consumed by the sum below. The addmm result is intentionally unused.
        lhs1 = torch.cat([w1, w2])
        rhs1 = torch.cat([x, b2])
        torch.addmm(b1, lhs1, rhs1.t())
        lhs1_sum = torch.sum(lhs1)

        return mm0_sum, lhs1_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        lhs = torch.cat([w1, w2])
        rhs = torch.cat([x, b2])
        return torch.addmm(b1, lhs, rhs.t())

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        # Uses of a placeholder outside the pattern do not count as a leak.
        TestCase(match_output=False, match_placeholder=True, num_matches=1),
    ]
525*da0073e9SAndroid Build Coastguard Worker
class ChainRepeatedPattern:
    """Four chained ``sigmoid`` calls searched for a two-``sigmoid`` chain.

    Expected matches:
    - 3 when overlapping matches are kept, 2 when they are removed;
    - 1 when either the output or the placeholder must line up;
    - 0 when both must.
    """

    @staticmethod
    def forward(x):
        out = x
        # The Python loop is unrolled by tracing into four sequential sigmoid nodes.
        for _ in range(4):
            out = torch.sigmoid(out)
        return out

    @staticmethod
    def pattern(x):
        inner = torch.sigmoid(x)
        return torch.sigmoid(inner)

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=3, remove_overlapping_matches=False),
        TestCase(match_output=False, match_placeholder=False, num_matches=2, remove_overlapping_matches=True),
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=1),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
546*da0073e9SAndroid Build Coastguard Worker
class QuantizationModel:
    """``dequantize`` -> ``sigmoid`` -> ``to(float16)`` chain preceded by an in-place add.

    In ``forward``, the chain's input is the result of ``x += 3`` rather than a
    graph input. The cases that set ``match_placeholder`` expect no match.
    """

    @staticmethod
    def forward(x):
        x += 3
        dequantized = x.dequantize()
        activated = torch.sigmoid(dequantized)
        return activated.to(torch.float16)

    @staticmethod
    def pattern(x):
        dequantized = x.dequantize()
        activated = torch.sigmoid(dequantized)
        return activated.to(torch.float16)

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
570*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsWithDependency:
    """Two-output pattern whose outputs depend on each other: ``relu`` and ``sigmoid(relu)``.

    ``forward`` returns the same two values in the opposite order. The cases that
    set ``match_output`` expect no match.
    """

    @staticmethod
    def forward(x):
        relu_out = x.relu()
        return relu_out.sigmoid(), relu_out

    @staticmethod
    def pattern(a):
        relu_out = a.relu()
        return relu_out, relu_out.sigmoid()

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=1),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
591*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsWithoutDependency:
    """Two independent outputs (``sigmoid`` and ``sum``) branching off one ``relu``.

    In ``forward``, the branches are created in the order ``sum`` then
    ``sigmoid``. In ``pattern``, they are created in the order ``sigmoid`` then
    ``sum``. Only the case with neither flag set expects a match.
    """

    @staticmethod
    def forward(x):
        shifted = x + 1

        # target subgraph to match
        activated = shifted.relu()
        reduced = activated.sum()
        squashed = activated.sigmoid()

        return squashed.sigmoid() + reduced.sum()

    @staticmethod
    def pattern(a):
        trunk = a.relu()
        squashed = trunk.sigmoid()
        reduced = trunk.sum()
        return squashed, reduced

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
619*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsMultipleOverlappingMatches:
    """One ``relu`` feeding two ``sum`` and two ``sigmoid`` nodes.

    The pattern returns the ``relu`` plus one ``sigmoid`` and one ``sum``.
    Expected matches:
    - 4 when overlapping matches are kept (presumably one per sigmoid/sum
      pairing);
    - 1 when overlapping matches are removed.
    """

    @staticmethod
    def forward(x):
        shifted = x + 1

        # target subgraph to match
        activated = shifted.relu()
        sum_a = activated.sum()
        sum_b = activated.sum()
        sig_a = activated.sigmoid()
        sig_b = activated.sigmoid()

        return sum_a + sum_b + sig_a + sig_b

    @staticmethod
    def pattern(a):
        trunk = a.relu()
        squashed = trunk.sigmoid()
        reduced = trunk.sum()
        return trunk, squashed, reduced

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=4, remove_overlapping_matches=False),
        TestCase(match_output=False, match_placeholder=False, num_matches=1, remove_overlapping_matches=True),
    ]
646*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsMultipleNonOverlappingMatches:
    """Two ``relu``/``sum``/``sigmoid`` groups chained through the ``relu`` nodes.

    The first group's ``relu`` also feeds the second group's ``relu``, which lies
    outside the pattern. The pattern returns only the ``sigmoid`` and ``sum``.
    A single match is expected.
    """

    @staticmethod
    def forward(x):
        shifted = x + 1

        # target subgraph to match
        first = shifted.relu()
        first_sum = first.sum()
        first_sig = first.sigmoid()

        second = first.relu()
        second_sum = second.sum()
        second_sig = second.sigmoid()

        return first_sum + second_sum + first_sig + second_sig

    @staticmethod
    def pattern(a):
        trunk = a.relu()
        squashed = trunk.sigmoid()
        reduced = trunk.sum()
        return squashed, reduced

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
    ]
674*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsIdenticalAnchor:
    """Pattern with two identical output anchors: two ``sigmoid`` nodes on one ``relu``."""

    @staticmethod
    def forward(x):
        shifted = x + 1

        # target subgraph to match
        activated = shifted.relu()
        return activated.sigmoid(), activated.sigmoid()

    @staticmethod
    def pattern(a):
        trunk = a.relu()
        return trunk.sigmoid(), trunk.sigmoid()

    test_cases = [
        # FIXME: with match_output=False and match_placeholder=False the matcher
        # currently returns 2 matches, but 1 is expected, so that case is omitted.
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
    ]
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker
class MultipleOutputsHorizontalPattern:
    """Horizontal pattern: ``relu`` and ``sigmoid`` side by side on the same input.

    In ``forward``, that input is the result of ``x + 1``. The cases that set
    ``match_placeholder`` expect no match.
    """

    @staticmethod
    def forward(x):
        shifted = x + 1

        # target subgraph to match
        return shifted.relu(), shifted.sigmoid()

    @staticmethod
    def pattern(a):
        return a.relu(), a.sigmoid()

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=1),
        TestCase(match_output=True, match_placeholder=False, num_matches=1),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
727*da0073e9SAndroid Build Coastguard Worker
class MultiOutputWithWithInvalidMatches:
    """Pattern returning ``linear`` and ``linear * c``, where no valid match exists.

    In ``forward``, the ``linear`` result feeds both ``sigmoid`` and the
    multiplication, and the multiplication's other operand is that ``sigmoid``.
    No case expects a match.
    """

    @staticmethod
    def forward(x):
        # torch.rand has no traced inputs, so it runs eagerly during tracing and
        # the tensor it returns is embedded in the graph as a constant.
        projected = torch.nn.functional.linear(x, torch.rand(3, 3))
        gate = torch.sigmoid(projected)
        gated = projected * gate
        return torch.sum(gated, dim=1)

    @staticmethod
    def pattern(a, b, c):
        projected = torch.nn.functional.linear(a, b)
        return projected, projected * c

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=0),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
    ]
749*da0073e9SAndroid Build Coastguard Worker
class QuantizationFp8Pattern:
    """fp8 quantize/dequantize custom ops around an ``add``.

    ``setup`` defines the ``fp8_quantization`` op namespace through
    ``torch.library``. ``tearDown`` deletes the class's reference to that library
    object. The subgraph-matcher test calls ``setup`` before tracing and
    ``tearDown`` afterwards.
    """

    @classmethod
    def setup(cls):
        """Define the quantize/dequantize op schemas used by ``forward`` and ``pattern``."""
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        """Delete the class attribute holding the library created by ``setup``."""
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
        # ``self`` is an ordinary positional parameter of this staticmethod, not a
        # bound instance; ``self._scale_0`` is read from whatever is passed in.
        # Each input goes through quantize -> dequantize, the two results are
        # added, and the sum goes through quantize -> dequantize again.
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        # Dequantize both inputs, add them, and quantize the sum. The output is
        # quantized with ``a_dtype`` (the pattern has no separate output dtype).
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)
        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]
792*da0073e9SAndroid Build Coastguard Worker
class NoAnchorFound:
    """Pattern whose anchor never occurs in the target graph.

    The anchor is the starting point of pattern matching, usually the pattern's
    returning (boundary) nodes. Here the pattern returns a ``relu``, while
    ``forward`` contains only an ``add``. Every case expects zero matches.
    """

    @staticmethod
    def forward(x):
        return x + 1

    @staticmethod
    def pattern(a):
        return a.relu()

    test_cases = [
        TestCase(match_output=False, match_placeholder=False, num_matches=0),
        TestCase(match_output=True, match_placeholder=False, num_matches=0),
        TestCase(match_output=False, match_placeholder=True, num_matches=0),
        TestCase(match_output=True, match_placeholder=True, num_matches=0),
    ]
813*da0073e9SAndroid Build Coastguard Worker
@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
    """Parametrized ``SubgraphMatcher`` checks over the pattern model classes in this file."""

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):
        """Trace ``test_model.forward`` and ``test_model.pattern`` and check each case in ``test_cases``.

        For every case:
        - the matcher must return exactly ``num_matches`` matches;
        - every pattern node not excluded by the case's ``match_placeholder`` /
          ``match_output`` flags must appear in each match's ``nodes_map``.

        A model may define optional ``setup`` / ``tearDown`` callables (e.g. to
        register custom ops). ``tearDown`` runs in a ``finally`` so it is invoked
        even when tracing or an assertion fails. This keeps registrations such as
        a ``torch.library`` namespace from leaking into later runs.
        """
        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        try:
            traced = symbolic_trace(test_model.forward)
            pattern_traced = symbolic_trace(test_model.pattern)

            for test_case in test_model.test_cases:
                matcher = SubgraphMatcher(pattern_traced.graph,
                                          match_output=test_case.match_output,
                                          match_placeholder=test_case.match_placeholder,
                                          remove_overlapping_matches=test_case.remove_overlapping_matches)
                matches = matcher.match(traced.graph)

                self.assertEqual(len(matches), test_case.num_matches,
                                 msg=f"{test_model.__name__}: {test_case}")

                for match in matches:
                    for node in pattern_traced.graph.nodes:
                        # Nodes the case does not ask to be matched are skipped.
                        if not test_case.match_placeholder and node.op == "placeholder":
                            continue
                        if not test_case.match_output and node.op == "output":
                            continue
                        self.assertIn(node, match.nodes_map)
        finally:
            # Previously this checked `callable(setup)` before calling tearDown,
            # which calls None for a model with setup but no tearDown.
            tearDown = getattr(test_model, "tearDown", None)
            if callable(tearDown):
                tearDown()
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # When this file is executed directly, hand control to the PyTorch test
    # runner imported from torch.testing._internal.common_utils.
    run_tests()
869