# Owner(s): ["oncall: quantization"]
# Copied from pytorch/test/fx/test_subgraph_rewriter.py

import os
import sys

import torch
from torch.fx import symbolic_trace, subgraph_rewriter
from torch.fx.annotate import annotate
from torch.fx.experimental.rewriter import RewritingTracer

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")

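# Most tests below follow the same pattern: symbolically trace a module (or a
# free function), call subgraph_rewriter.replace_pattern() on the traced
# GraphModule, lint the mutated graph, and check that the rewritten module
# still produces the same output as a hand-written "comparison" reference.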
class TestSubgraphRewriter(JitTestCase):

    def test_subgraph_rewriter_preserves_logic(self):
        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x) + torch.relu(x)

        def comparison(x):
            val = torch.neg(x) + torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        # Replace `pattern` with the same pattern (shouldn't change
        # the underlying logic)
        subgraph_rewriter.replace_pattern(traced, pattern, pattern)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_with_oneliner_pattern(self):
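        """A single-node pattern (`torch.neg`) is matched and swapped for `torch.relu`."""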
        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x)

        def replacement(x):
            return torch.relu(x)

        def comparison(x):
            val = torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_single_pattern_match(self):
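        """A multi-node pattern (`neg(x) + relu(x)`) that occurs once is collapsed into a single `relu`."""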
        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.neg(x) + torch.relu(x)

        def replacement(x):
            return torch.relu(x)

        def comparison(x):
            val = torch.relu(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_output = comparison_fn(x)
        test_output = traced.forward(x)
        self.assertEqual(ref_output, test_output)

    def test_subgraph_rewriter_multiple_pattern_match(self):
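        """Every occurrence of the pattern is rewritten: both `cat(...).sum()` subgraphs become `stack(...)`."""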
        class M(torch.nn.Module):
            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        def comparison(x, w1, w2):
            m1 = torch.stack([w1, w2])
            m2 = torch.stack([w1, w2])
            return x + torch.max(m1) + torch.max(m2)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.rand(1, 3)
        w1 = torch.rand(1, 3)
        w2 = torch.rand(1, 3)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x, w1, w2)
        test_outs = traced.forward(x, w1, w2)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_graph_argument_order(self):
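        """Matching respects argument order for a non-commutative op (`torch.mm(x, y)`)."""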
        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.mm(x, y)

        def pattern(x, y):
            return torch.mm(x, y)

        def comparison(x, y):
            return torch.mm(x, y)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)
        y = torch.randn(4, 5)

        subgraph_rewriter.replace_pattern(traced, pattern, pattern)

        traced.graph.lint()

        ref_outs = comparison_fn(x, y)
        test_outs = traced.forward(x, y)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_correct_output_replacement(self):
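        """Only the matched `relu(x)` is rewritten to `neg`; the pre-existing `neg(y)` is left untouched."""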
        class M(torch.nn.Module):
            def forward(self, x, y):
                val = torch.neg(y) + torch.relu(x)
                return torch.add(val, val)

        def pattern(x):
            return torch.relu(x)

        def replacement(x):
            return torch.neg(x)

        def comparison(x, y):
            val = torch.neg(y) + torch.neg(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(4, 4)
        y = torch.randn(4, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x, y)
        test_outs = traced.forward(x, y)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_traced_as_callable(self):
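        """`replace_pattern` also accepts a pattern and replacement that have
        already been symbolically traced into GraphModules."""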
        class M(torch.nn.Module):
            def forward(self, x):
                val = torch.neg(x) + torch.relu(x)
                return torch.add(val, val)

        class Pattern(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x) + torch.relu(x)

        class Replacement(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        def comparison(x):
            val = torch.sigmoid(x)
            return torch.add(val, val)

        traced = symbolic_trace(M())
        traced_pattern = symbolic_trace(Pattern())
        traced_replacement = symbolic_trace(Replacement())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_pattern_is_entire_graph(self):
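        """The pattern spans the entire traced graph, so the whole body is replaced."""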
        class M(torch.nn.Module):
            def forward(self, x):
                a = torch.neg(x)
                return torch.add(a, a)

        def pattern(x):
            a = torch.neg(x)
            return torch.add(a, a)

        def replacement(x):
            a = torch.sigmoid(x)
            return torch.cat([a, a])

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(replacement)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self):
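        """Extra users of the pattern's output node (here, the `relu` result feeds both `neg` and the
        subtraction) do not block the match."""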
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.relu(x)
                return torch.neg(y) - y

        def pattern(x):
            return torch.relu(x)

        def replacement(x):
            return torch.sigmoid(x)

        def comparison(x):
            y = torch.sigmoid(x)
            return torch.neg(y) - y

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self):
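        """The intermediate `cat` nodes are shared by two `addmm` calls, so the pattern's internal
        nodes have users outside any single match and no replacement is performed."""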
        class M(torch.nn.Module):
            def forward(self, x, w1, w2, b1, b2):
                m0 = torch.cat([w1, w2])
                m1 = torch.cat([w1, w2])
                m2 = torch.cat([x, b2])
                t0 = torch.addmm(b1, m1, m2.t())
                t1 = torch.sum(w1, 1)
                t2 = torch.addmm(b1, m1, m2.t())
                return torch.sum(t1), torch.sum(t2)

        def pattern(x, w1, w2, b1, b2):
            m1 = torch.cat([w1, w2])
            m2 = torch.cat([x, b2])
            return torch.addmm(b1, m1, m2.t())

        def replacement(x, w1, w2, b1, b2):
            return torch.cat([x, w1, w2])

        traced = symbolic_trace(M())

        # Result should be [] since no matches can be found
        res = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        self.assertEqual(res, [])

    def test_subgraph_rewriter_placeholder_matching(self):
        """
        This tests that a placeholder Node can be matched to a Node with
        a different number of input Nodes. In the example below, the
        original traced Module looks like this:
            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_function  <built-in function add>                                     (x, 3)                    {}
            call_method    dequantize                                                  (add,)                    {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}
        while the pattern we want to match looks like this:
            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_method    dequantize                                                  (x,)                      {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}
        Here, we want to be able to match the original graph's
        `call_function.add` Node with the pattern graph's
        `placeholder.x` Node.
        Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dtype = torch.float16

            def forward(self, x):
                x += 3
                x = x.dequantize()
                x = torch.sigmoid(x)
                dtype = self.dtype
                x = x.to(dtype)
                return x

        def pattern(x):
            x = x.dequantize()
            x = torch.sigmoid(x)
            x = x.to(torch.float16)
            return x

        def replacement(x):
            return x

        def comparison(x):
            return x + 3

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

    def test_subgraph_rewriter_replaces_referenced_submodules(self):
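        """After replacement, submodules referenced only by the replacement (`id`) are copied into the
        traced module, while ones referenced only by the pattern (`sigmoid`) are removed."""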
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                x = x + 1
                return self.submod(self.sigmoid(x))

        class Pattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                return self.submod(self.sigmoid(x))

        class Replacement(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.id = torch.nn.Identity()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                return self.submod(self.id(x))

        class Comparison(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.id = torch.nn.Identity()
                self.submod = torch.nn.ReLU()

            def forward(self, x):
                x = x + 1
                return self.submod(self.id(x))

        traced = symbolic_trace(M())
        comparison = Comparison()

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement())

        traced.graph.lint()

        ref_outs = comparison(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)

        traced.get_submodule("id")
        with self.assertRaisesRegex(AttributeError, "has no attribute"):
            traced.get_submodule("sigmoid")

        submod = traced.get_submodule("submod")
        self.assertEqual(type(submod), torch.nn.ReLU)

    def test_subgraph_rewriter_annotations_int(self):
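        """An explicit `annotate(x, int)` call and an AST-level annotation (`y: int = x`, traced with
        RewritingTracer) should both attach the `int` type to the placeholder node."""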

        class M1(torch.nn.Module):
            def forward(self, x):
                y: int = x
                return torch.add(x, y)

        class M2(torch.nn.Module):
            def forward(self, x):
                y = annotate(x, int)
                return torch.add(x, y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M1())

        module = M2()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
            if n.op == 'placeholder':
                assert n.type == int
                assert m.type == int

    def test_subgraph_rewriter_replace_consecutive_submodules(self):
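        """Consecutive occurrences of the pattern are all replaced: three chained `sigmoid` calls each
        become `exp`, even when a previous replacement produced the node's input."""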

        def f(x):
            x = torch.sigmoid(x)
            x = torch.sigmoid(x)
            return torch.sigmoid(x)

        def pattern(x):
            return torch.sigmoid(x)

        def replacement(x):
            return torch.exp(x)

        def comparison(x):
            x = torch.exp(x)
            x = torch.exp(x)
            return torch.exp(x)

        traced = symbolic_trace(f)
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)