1# Owner(s): ["oncall: quantization"] 2# Copied from pytorch/test/fx/test_subgraph_rewriter.py 3 4import os 5import sys 6 7import torch 8from torch.fx import symbolic_trace, subgraph_rewriter 9from torch.fx.annotate import annotate 10# Make the helper files in test/ importable 11from torch.fx.experimental.rewriter import RewritingTracer 12 13pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14sys.path.append(pytorch_test_dir) 15from torch.testing._internal.jit_utils import JitTestCase 16 17if __name__ == '__main__': 18 raise RuntimeError("This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_fx.py TESTNAME\n\n" 20 "instead.") 21 22class TestSubgraphRewriter(JitTestCase): 23 24 def test_subgraph_rewriter_preserves_logic(self): 25 class M(torch.nn.Module): 26 def forward(self, x): 27 val = torch.neg(x) + torch.relu(x) 28 return torch.add(val, val) 29 30 def pattern(x): 31 return torch.neg(x) + torch.relu(x) 32 33 def comparison(x): 34 val = torch.neg(x) + torch.relu(x) 35 return torch.add(val, val) 36 37 traced = symbolic_trace(M()) 38 comparison_fn = symbolic_trace(comparison) 39 40 x = torch.rand(1, 3) 41 42 # Replace `pattern` with the same pattern (shouldn't change 43 # the underlying logic) 44 subgraph_rewriter.replace_pattern(traced, pattern, pattern) 45 46 traced.graph.lint() 47 48 ref_output = comparison_fn(x) 49 test_output = traced.forward(x) 50 self.assertEqual(ref_output, test_output) 51 52 def test_subgraph_rewriter_with_oneliner_pattern(self): 53 class M(torch.nn.Module): 54 def forward(self, x): 55 val = torch.neg(x) 56 return torch.add(val, val) 57 58 def pattern(x): 59 return torch.neg(x) 60 61 def replacement(x): 62 return torch.relu(x) 63 64 def comparison(x): 65 val = torch.relu(x) 66 return torch.add(val, val) 67 68 traced = symbolic_trace(M()) 69 comparison_fn = symbolic_trace(comparison) 70 71 x = torch.rand(1, 3) 72 73 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 74 75 traced.graph.lint() 76 77 ref_output = comparison_fn(x) 78 test_output = traced.forward(x) 79 self.assertEqual(ref_output, test_output) 80 81 def test_subgraph_rewriter_single_pattern_match(self): 82 class M(torch.nn.Module): 83 def forward(self, x): 84 val = torch.neg(x) + torch.relu(x) 85 return torch.add(val, val) 86 87 def pattern(x): 88 return torch.neg(x) + torch.relu(x) 89 90 def replacement(x): 91 return torch.relu(x) 92 93 def comparison(x): 94 val = torch.relu(x) 95 return torch.add(val, val) 96 97 traced = symbolic_trace(M()) 98 comparison_fn = symbolic_trace(comparison) 99 100 x = torch.rand(1, 3) 101 102 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 103 104 traced.graph.lint() 105 106 ref_output = comparison_fn(x) 107 test_output = traced.forward(x) 108 self.assertEqual(ref_output, test_output) 109 110 def test_subgraph_rewriter_multiple_pattern_match(self): 111 class M(torch.nn.Module): 112 def forward(self, x, w1, w2): 113 m1 = torch.cat([w1, w2]).sum() 114 m2 = torch.cat([w1, w2]).sum() 115 return x + torch.max(m1) + torch.max(m2) 116 117 def pattern(w1, w2): 118 return torch.cat([w1, w2]).sum() 119 120 def replacement(w1, w2): 121 return torch.stack([w1, w2]) 122 123 def comparison(x, w1, w2): 124 m1 = torch.stack([w1, w2]) 125 m2 = torch.stack([w1, w2]) 126 return x + torch.max(m1) + torch.max(m2) 127 128 traced = symbolic_trace(M()) 129 comparison_fn = symbolic_trace(comparison) 130 131 x = torch.rand(1, 3) 132 w1 = torch.rand(1, 3) 133 w2 = torch.rand(1, 3) 134 135 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 136 137 traced.graph.lint() 138 139 ref_outs = comparison_fn(x, w1, w2) 140 test_outs = traced.forward(x, w1, w2) 141 self.assertEqual(ref_outs, test_outs) 142 143 def test_subgraph_rewriter_graph_argument_order(self): 144 class M(torch.nn.Module): 145 def forward(self, x, y): 146 return torch.mm(x, y) 147 148 def pattern(x, y): 149 return torch.mm(x, y) 150 151 def comparison(x, y): 152 return torch.mm(x, y) 153 154 traced = symbolic_trace(M()) 155 comparison_fn = symbolic_trace(comparison) 156 157 x = torch.randn(3, 4) 158 y = torch.randn(4, 5) 159 160 subgraph_rewriter.replace_pattern(traced, pattern, pattern) 161 162 traced.graph.lint() 163 164 ref_outs = comparison_fn(x, y) 165 test_outs = traced.forward(x, y) 166 self.assertEqual(ref_outs, test_outs) 167 168 def test_subgraph_rewriter_correct_output_replacement(self): 169 class M(torch.nn.Module): 170 def forward(self, x, y): 171 val = torch.neg(y) + torch.relu(x) 172 return torch.add(val, val) 173 174 def pattern(x): 175 return torch.relu(x) 176 177 def replacement(x): 178 return torch.neg(x) 179 180 def comparison(x, y): 181 val = torch.neg(y) + torch.neg(x) 182 return torch.add(val, val) 183 184 traced = symbolic_trace(M()) 185 comparison_fn = symbolic_trace(comparison) 186 187 x = torch.randn(4, 4) 188 y = torch.randn(4, 4) 189 190 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 191 192 traced.graph.lint() 193 194 ref_outs = comparison_fn(x, y) 195 test_outs = traced.forward(x, y) 196 self.assertEqual(ref_outs, test_outs) 197 198 def test_subgraph_rewriter_traced_as_callable(self): 199 class M(torch.nn.Module): 200 def forward(self, x): 201 val = torch.neg(x) + torch.relu(x) 202 return torch.add(val, val) 203 204 class Pattern(torch.nn.Module): 205 def forward(self, x): 206 return torch.neg(x) + torch.relu(x) 207 208 class Replacement(torch.nn.Module): 209 def forward(self, x): 210 return torch.sigmoid(x) 211 212 def comparison(x): 213 val = torch.sigmoid(x) 214 return torch.add(val, val) 215 216 traced = symbolic_trace(M()) 217 traced_pattern = symbolic_trace(Pattern()) 218 traced_replacement = symbolic_trace(Replacement()) 219 comparison_fn = symbolic_trace(comparison) 220 221 x = torch.randn(3, 4) 222 223 subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement) 224 225 traced.graph.lint() 226 227 ref_outs = comparison_fn(x) 228 test_outs = traced.forward(x) 229 self.assertEqual(ref_outs, test_outs) 230 231 def test_subgraph_rewriter_pattern_is_entire_graph(self): 232 class M(torch.nn.Module): 233 def forward(self, x): 234 a = torch.neg(x) 235 return torch.add(a, a) 236 237 def pattern(x): 238 a = torch.neg(x) 239 return torch.add(a, a) 240 241 def replacement(x): 242 a = torch.sigmoid(x) 243 return torch.cat([a, a]) 244 245 traced = symbolic_trace(M()) 246 comparison_fn = symbolic_trace(replacement) 247 248 x = torch.randn(3, 4) 249 250 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 251 252 traced.graph.lint() 253 254 ref_outs = comparison_fn(x) 255 test_outs = traced.forward(x) 256 self.assertEqual(ref_outs, test_outs) 257 258 def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self): 259 class M(torch.nn.Module): 260 def forward(self, x): 261 y = torch.relu(x) 262 return torch.neg(y) - y 263 264 def pattern(x): 265 return torch.relu(x) 266 267 def replacement(x): 268 return torch.sigmoid(x) 269 270 def comparison(x): 271 y = torch.sigmoid(x) 272 return torch.neg(y) - y 273 274 traced = symbolic_trace(M()) 275 comparison_fn = symbolic_trace(comparison) 276 277 x = torch.randn(3, 4) 278 279 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 280 281 traced.graph.lint() 282 283 ref_outs = comparison_fn(x) 284 test_outs = traced.forward(x) 285 self.assertEqual(ref_outs, test_outs) 286 287 def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self): 288 class M(torch.nn.Module): 289 def forward(self, x, w1, w2, b1, b2): 290 m0 = torch.cat([w1, w2]) 291 m1 = torch.cat([w1, w2]) 292 m2 = torch.cat([x, b2]) 293 t0 = torch.addmm(b1, m1, m2.t()) 294 t1 = torch.sum(w1, 1) 295 t2 = torch.addmm(b1, m1, m2.t()) 296 return torch.sum(t1), torch.sum(t2) 297 298 def pattern(x, w1, w2, b1, b2): 299 m1 = torch.cat([w1, w2]) 300 m2 = torch.cat([x, b2]) 301 return torch.addmm(b1, m1, m2.t()) 302 303 def replacement(x, w1, w2, b1, b2): 304 return torch.cat([x, w1, w2]) 305 306 traced = symbolic_trace(M()) 307 308 # Result should be [] since no matches can be found 309 res = subgraph_rewriter.replace_pattern(traced, pattern, replacement) 310 311 traced.graph.lint() 312 313 self.assertEqual(res, []) 314 315 def test_subgraph_rewriter_placeholder_matching(self): 316 """ 317 This tests that a placeholder Node can be matched to a Node with 318 a different number of input Nodes. In the example below, the 319 original traced Module looks like this: 320 opcode target args kwargs 321 ------------- ---------------------------------------------------------- ------------------------ -------- 322 placeholder x () {} 323 call_function <built-in function add> (x, 3) {} 324 call_method dequantize (add,) {} 325 call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {} 326 call_method to (sigmoid, torch.float16) {} 327 output output (to,) {} 328 while the pattern we want to match looks like this: 329 opcode target args kwargs 330 ------------- ---------------------------------------------------------- ------------------------ -------- 331 placeholder x () {} 332 call_method dequantize (x,) {} 333 call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {} 334 call_method to (sigmoid, torch.float16) {} 335 output output (to,) {} 336 Here, we want to be able to match the original graph's 337 `call_function.add` Node with the pattern graph's 338 `plaeholder.x` Node. 339 Credit to Jerry Zhang (GitHub: jerryzh168) for this test case 340 """ 341 class M(torch.nn.Module): 342 def __init__(self) -> None: 343 super().__init__() 344 self.dtype = torch.float16 345 346 def forward(self, x): 347 x += 3 348 x = x.dequantize() 349 x = torch.sigmoid(x) 350 dtype = self.dtype 351 x = x.to(dtype) 352 return x 353 354 def pattern(x): 355 x = x.dequantize() 356 x = torch.sigmoid(x) 357 x = x.to(torch.float16) 358 return x 359 360 def replacement(x): 361 return x 362 363 def comparison(x): 364 return x + 3 365 366 traced = symbolic_trace(M()) 367 comparison_fn = symbolic_trace(comparison) 368 369 x = torch.randn(3, 4) 370 371 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 372 373 traced.graph.lint() 374 375 ref_outs = comparison_fn(x) 376 test_outs = traced.forward(x) 377 self.assertEqual(ref_outs, test_outs) 378 379 def test_subgraph_rewriter_replaces_referenced_submodules(self): 380 class M(torch.nn.Module): 381 def __init__(self) -> None: 382 super().__init__() 383 self.sigmoid = torch.nn.Sigmoid() 384 self.submod = torch.nn.ReLU() 385 386 def forward(self, x): 387 x = x + 1 388 return self.submod(self.sigmoid(x)) 389 390 class Pattern(torch.nn.Module): 391 def __init__(self) -> None: 392 super().__init__() 393 self.sigmoid = torch.nn.Sigmoid() 394 self.submod = torch.nn.ReLU() 395 396 def forward(self, x): 397 return self.submod(self.sigmoid(x)) 398 399 class Replacement(torch.nn.Module): 400 def __init__(self) -> None: 401 super().__init__() 402 self.id = torch.nn.Identity() 403 self.submod = torch.nn.ReLU() 404 405 def forward(self, x): 406 return self.submod(self.id(x)) 407 408 class Comparison(torch.nn.Module): 409 def __init__(self) -> None: 410 super().__init__() 411 self.id = torch.nn.Identity() 412 self.submod = torch.nn.ReLU() 413 414 def forward(self, x): 415 x = x + 1 416 return self.submod(self.id(x)) 417 418 traced = symbolic_trace(M()) 419 comparison = Comparison() 420 421 x = torch.randn(3, 4) 422 423 subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) 424 425 traced.graph.lint() 426 427 ref_outs = comparison(x) 428 test_outs = traced.forward(x) 429 self.assertEqual(ref_outs, test_outs) 430 431 traced.get_submodule("id") 432 with self.assertRaisesRegex(AttributeError, "has no attribute"): 433 traced.get_submodule("sigmoid") 434 435 submod = traced.get_submodule("submod") 436 self.assertEqual(type(submod), torch.nn.ReLU) 437 438 def test_subgraph_rewriter_annotations_int(self): 439 440 class M1(torch.nn.Module): 441 def forward(self, x): 442 y: int = x 443 return torch.add(x, y) 444 445 class M2(torch.nn.Module): 446 def forward(self, x): 447 y = annotate(x, int) 448 return torch.add(x, y) 449 450 ast_rewriter = RewritingTracer() 451 graph = ast_rewriter.trace(M1()) 452 453 module = M2() 454 symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) 455 for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): 456 if n.op == 'placeholder': 457 assert n.type == int 458 assert m.type == int 459 460 def test_subgraph_writer_replace_consecutive_submodules(self): 461 462 def f(x): 463 x = torch.sigmoid(x) 464 x = torch.sigmoid(x) 465 return torch.sigmoid(x) 466 467 def pattern(x): 468 return torch.sigmoid(x) 469 470 def replacement(x): 471 return torch.exp(x) 472 473 def comparison(x): 474 x = torch.exp(x) 475 x = torch.exp(x) 476 return torch.exp(x) 477 478 traced = symbolic_trace(f) 479 comparison_fn = symbolic_trace(comparison) 480 481 x = torch.randn(3, 4) 482 483 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 484 485 traced.graph.lint() 486 487 ref_outs = comparison_fn(x) 488 test_outs = traced.forward(x) 489 self.assertEqual(ref_outs, test_outs) 490