# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import executorch.exir as exir

import torch
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import ScalarToTensorPass
from executorch.exir.passes.pass_registry import PassRegistry
from torch.fx.passes.infra.pass_base import PassBase


class TestPassInfra(unittest.TestCase):
    """
    Tests for the EXIR pass infrastructure: PassBase subclassing rules,
    PassRegistry registration, PassManager execution/validation, and the
    metadata produced by ScalarToTensorPass.
    """

    def test_fail_passbase(self) -> None:
        """
        Tests if we catch errors when we do not inherit PassBase correctly
        """

        # Catches error if we do not implement call(): instantiating a
        # PassBase subclass without call() is expected to raise TypeError.
        class TestPass3(PassBase):
            def __init__(self) -> None:
                pass

        with self.assertRaises(TypeError):
            # pyre-ignore
            TestPass3()

    def test_pass_registry_func(self) -> None:
        """
        Test if we register a callable correctly
        """
        # NOTE(review): PassRegistry appears to be process-global state, so the
        # registry keys used in these tests must stay unique — confirm.

        # Registering w/o specifying pass_name: the function's own name is
        # expected to become the registry key.
        @PassRegistry.register()
        def test_pass1(graph_module: torch.fx.GraphModule) -> None:
            pass

        self.assertEqual(len(PassRegistry.get("test_pass1")), 1)

        # Registering with a specified pass_name overrides the function name.
        @PassRegistry.register(pass_name="test_pass1_1")
        def test_pass11(graph_module: torch.fx.GraphModule) -> None:
            pass

        self.assertEqual(len(PassRegistry.get("test_pass1_1")), 1)

    def test_pass_registry_passbase(self) -> None:
        """
        Test if we register a PassBase subclass correctly
        """

        class TestPass2(PassBase):
            def __init__(self) -> None:
                pass

            def call(self, graph_module: torch.fx.GraphModule) -> None:
                pass

        # A PassBase *instance* (not the class) is what gets registered.
        PassRegistry.register("test_pass2")(TestPass2())

        self.assertEqual(len(PassRegistry.get("test_pass2")), 1)

    def test_pass_registry_list(self) -> None:
        """
        Test that a mixed list of a plain callable and a PassBase instance
        is registered under a single name.
        """

        def test_pass1(graph_module: torch.fx.GraphModule) -> None:
            pass

        class TestPass2(PassBase):
            def __init__(self) -> None:
                pass

            def call(self, graph_module: torch.fx.GraphModule) -> None:
                pass

        # Register a list of passes; both entries should be retrievable.
        PassRegistry.register_list(
            pass_name="test_pass3", pass_list=[test_pass1, TestPass2()]
        )
        self.assertEqual(len(PassRegistry.get("test_pass3")), 2)

    def test_pass_manager(self) -> None:
        """
        Tests that the pass manager runs the passes correctly.
        """

        def replace_add_with_mul(gm: torch.fx.GraphModule) -> None:
            # Matches on the string form of the target, presumably because
            # after to_edge() targets are edge-dialect op objects rather than
            # torch.add itself — TODO confirm.
            for node in gm.graph.nodes:
                if node.op == "call_function" and "aten.add.Tensor" in str(node.target):
                    node.target = torch.mul

        def replace_mul_with_div(gm: torch.fx.GraphModule) -> None:
            # Runs after replace_add_with_mul, so it sees the torch.mul
            # targets that pass installed.
            for node in gm.graph.nodes:
                if node.op == "call_function" and node.target == torch.mul:
                    node.target = torch.div

        def f(x: torch.Tensor) -> torch.Tensor:
            y = torch.add(x, x)
            z = torch.add(y, x)
            return z

        # Keep the lowered graph module under its own name rather than
        # rebinding `f`, which is the traced Python function.
        gm = (
            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        pm = PassManager(passes=[replace_add_with_mul, replace_mul_with_div])
        self.assertEqual(len(pm.passes), 2)
        pm(gm)

        # Both passes ran in order, so every add became mul and then div.
        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertEqual(node.target, torch.div)

    def test_pass_manager_invalid_passes(self) -> None:
        """
        Tests that the pass manager detects invalid passes, here a pass that
        leaves a call_method node in the graph.
        """

        def introduce_call_method(gm: torch.fx.GraphModule) -> None:
            # The last node of an fx graph is the output node, so [-2] is the
            # node feeding it; insert a call_method after it and reroute that
            # node's users to the new node.
            node = list(gm.graph.nodes)[-2]
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_method("torch.ops.relu", (torch.randn(2),))
                node.replace_all_uses_with(new_node)

        def f(x: torch.Tensor) -> torch.Tensor:
            y = torch.add(x, x)
            z = torch.add(y, x)
            return z

        gm = (
            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        pm = PassManager(
            passes=[introduce_call_method], run_checks_after_each_pass=True
        )

        # With run_checks_after_each_pass=True the manager validates the graph
        # after the pass and is expected to reject the call_method node.
        with self.assertRaisesRegex(Exception, "call_method"):
            pm(gm)

    def test_pass_metadata(self) -> None:
        """
        Tests that every non-output node in the graph returned by
        ScalarToTensorPass carries a "val" entry in its meta dict.
        """

        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        sample_inputs = (torch.randn(1, 3), torch.randn(1, 3))
        gm = exir.capture(
            f, sample_inputs, exir.CaptureConfig()
        ).exported_program.graph_module

        pass_result = ScalarToTensorPass()(gm)
        self.assertIsNotNone(pass_result)
        new_gm = pass_result.graph_module

        for node in new_gm.graph.nodes:
            # Identify the output node by its opcode, not by its target name.
            if node.op != "output":
                self.assertIn("val", node.meta)