1# Owner(s): ["oncall: fx"] 2 3import random 4 5import torch 6from torch.fx import symbolic_trace 7from torch.fx.experimental.proxy_tensor import make_fx 8from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops 9from torch.testing._internal.common_utils import run_tests, TestCase 10 11 12banned_ops = get_CSE_banned_ops() 13P_default = CSEPass(banned_ops=banned_ops) 14 15 16def check(self, f, t, delta, check_val=True, graph_input=False, P=None): 17 """ 18 check if the CSE modified graph of ``f`` 19 1) has delta less nodes, and 20 2) do not reduce the number of nodes further on a second pass, and 21 3) modified returned is true only if the number of nodes decreases. 22 23 Args: 24 f: function to be checked 25 t: tensor to be passed to f 26 delta: an integer >= -1. 27 If delta = -1, it only checks if the new graph has less or equal number of nodes 28 check_val: if True, check if the output of f is correct 29 graph_input: True is f is type GraphModule 30 P: the pass to use. If None, use P_default 31 """ 32 if graph_input: 33 fx_g = f 34 else: 35 fx_g = make_fx(f)(t) 36 37 if P is None: 38 P = P_default 39 40 res = P(fx_g) 41 new_g = res.graph_module 42 new_graph = new_g.graph 43 modified = res.modified 44 45 # the number of nodes decrease/ or stay the same 46 old_num_nodes = len(fx_g.graph.nodes) 47 new_num_nodes = len(new_graph.nodes) 48 49 assert ( 50 new_num_nodes < old_num_nodes 51 ) == modified, "modified should be True if the number of nodes decrease" 52 53 if delta == -1: 54 self.assertTrue( 55 old_num_nodes >= new_num_nodes, 56 (f"number of nodes increased {old_num_nodes}, {new_num_nodes}"), 57 ) 58 else: 59 self.assertTrue( 60 old_num_nodes == new_num_nodes + delta, 61 ( 62 f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" 63 ), 64 ) 65 66 # a second pass should not reduce more nodes 67 res = P(new_g) 68 pass_2_graph = res.graph_module.graph 69 pass_2_num_nodes = len(pass_2_graph.nodes) 70 self.assertTrue( 71 pass_2_num_nodes == new_num_nodes, 72 ( 73 f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" 74 ), 75 ) 76 77 # check correctness 78 if check_val: 79 true_result = fx_g(t) 80 our_result = new_g(t) 81 if true_result is None: # both return None 82 self.assertTrue( 83 our_result is None, f"true result is None, CSE result is {our_result}" 84 ) 85 else: # results returned are the same 86 self.assertTrue( 87 torch.all(true_result == our_result), 88 (f"results are different {true_result}, {our_result}"), 89 ) # check results are the same 90 91 92class TestCSEPass(TestCase): 93 def test_nochange(self): 94 def f(x): 95 a = x + 1 96 b = x + a 97 a = x 98 d = x + a 99 return b + d 100 101 t = torch.randn(2, 2) 102 check(self, f, t, 0) 103 104 def test_empty(self): 105 def f(x): 106 pass 107 108 t = torch.randn(2, 2) 109 check(self, f, t, 0) 110 111 def test_immutable_list_type(self): 112 def f(x): 113 a = x.sum(dim=1) 114 b = x.sum(dim=1) 115 c = x.sum() 116 d = x.sum() 117 return a + b + c + d 118 119 t = torch.randn(2, 2) 120 check(self, f, t, 2) 121 122 def test_immutable_list_multiple_entries(self): 123 def f(x): 124 a = x.sum(dim=[0, 1]) 125 b = x.sum(dim=[0, 1]) 126 c = x.sum(dim=1) 127 d = x.sum(dim=1) 128 return a + b + c + d 129 130 t = torch.randn(2, 2) 131 check(self, f, t, 2) 132 133 def test_simple(self): 134 def f(x): 135 a = x.cos() 136 b = x.cos() 137 c = a + a 138 d = b + b 139 return c + d 140 141 t = torch.randn(2, 2) 142 check(self, f, t, 2) 143 144 def test_simple_2(self): 145 def f(x): 146 a = x.cos().sin() 147 b = x.cos().sin() 148 c = a + a 149 d = b + b 150 return c + d 151 152 t = torch.randn(1) 153 check(self, f, t, 3) 154 155 def test_two_args_default(self): 156 def f(x): 157 a = x.sum(dim=1) 158 b = x.sum(dim=1, keepdim=False) 159 c = x.sum(dim=1, keepdim=False) 160 d = x.sum(dim=1) 161 return a + b + c + d 162 163 t = torch.randn(2, 2) 164 check(self, f, t, 3) 165 166 def test_two_args(self): 167 def f(x): 168 a = x.sum(dim=1) 169 b = x.sum(dim=1, keepdim=True) 170 c = x.sum(dim=1, keepdim=True) 171 d = x.sum(dim=1) 172 return a + b + c + d 173 174 t = torch.randn(2, 2) 175 check(self, f, t, 2) 176 177 def test_simple_multiple_same_ops(self): 178 def f(x): 179 a = x.sum() 180 b = x.sum() 181 c = x.sum() 182 d = x.sum() 183 return a + b + c + d 184 185 t = torch.randn(2, 2) 186 check(self, f, t, 3) 187 188 def test_nested_immutable_list_type(self): 189 def f(x): 190 a = torch.cat((x, x)) 191 b = torch.cat((x, x)) 192 return a + b 193 194 t = torch.randn(2, 2) 195 check(self, f, t, 1) 196 197 def test_kwarg(self): 198 def f(x): 199 a = torch.ones_like(x) 200 b = torch.ones_like(x) 201 return a + b 202 203 t = torch.randn(2, 2) 204 check(self, f, t, 1) 205 206 """ 207 Generate function with random ops and check if the result is the same 208 """ 209 210 def test_random(self): 211 def f(x): 212 vals = [x] 213 ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu] 214 for _ in range(100): 215 new_val = random.choice(ops)(random.choice(vals)) 216 vals.append(new_val) 217 return vals[-1] 218 219 fx_g = symbolic_trace(f) 220 fx_g.graph.eliminate_dead_code() 221 fx_g.recompile() 222 t = torch.randn(2, 2) 223 224 for _ in range(30): 225 check(self, fx_g, t, -1, graph_input=True) 226 227 """ 228 Test that banned list ban ops as expected. 229 """ 230 231 def test_banned_list(self): 232 def f(x): 233 a = x + 1 234 b = x + 1 235 return a + b 236 237 t = torch.randn(2, 2) 238 P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add]) 239 check(self, f, t, 0, P=P_ban_add) # check that add is banned 240 check(self, f, t, 1) # check that add is not banned by default 241 242 def test_rand_like(self): 243 def f(x): 244 a = torch.rand_like(x) 245 b = torch.rand_like(x) 246 return a + b 247 248 t = torch.randn(2, 2) 249 check(self, f, t, 0, check_val=False) 250 251 def test_rand_n(self): 252 def f(x): 253 a = torch.randn(4) 254 b = torch.randn(4) 255 return a + b 256 257 t = torch.randn(2, 2) 258 check(self, f, t, 0, check_val=False) 259 260 261if __name__ == "__main__": 262 run_tests() 263