# Owner(s): ["module: ProxyTensor"]

from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch._dynamo
import unittest
import warnings
import operator
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._decomp import decomposition_table
from torch.fx.experimental.symbolic_shapes import (
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
    guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.common_device_type import ops
import torch.testing._internal.optests as optests
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch import nn
import torch._functorch.config
import re

import functools
import itertools

aten = torch.ops.aten

HAS_CUDA = torch.cuda.is_available()


def strip_end(s, suffix):
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    else:
        return s


def show_guards(gm):
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
    return "\n".join(
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
    )


def process_failures():
    """
    Takes a file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of opinfo xfails
    """
    f = open('pytest_failures')
    failures = f.readlines()
    failures = [i.strip() for i in failures]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")


USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)


def _create_new_input(x):
    if not isinstance(x, torch.Tensor):
        return x
    if x.dtype != torch.float:
        return x + 1
    if x.is_leaf:
        return torch.rand_like(x, requires_grad=x.requires_grad)
    else:
        return torch.rand_like(x)

"""
Delays a cos() call on the wrapped tensor until the tensor is actually used.
Simulates how a CommTensor is used.
"""
class UnwrapTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            ret = e
            if isinstance(e, UnwrapTensor):
                ret = e._tensor.cos()

            return ret

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)

class TestGenericProxyTensor(TestCase):
    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # Also, torch.ones() doesn't show up in the trace.
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
        # so our mode never sees it - it goes directly to the BackendSelect key.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_dispatch=True) clears caches properly.
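        # (Presumably the eager run under enable_python_dispatcher below populates
        # dispatcher caches; the subsequent pre_dispatch trace must not be confused
        # by them and should still record matmul rather than a decomposed mm.)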
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_dispatch=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)
        a = torch.ones(4, 4)
        b = torch.ones(4, 4)
        c = torch.ones(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()
        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.clone().detach().requires_grad_(True)
        a_tmp = a1.clone().detach().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally; that inner trace
        # shouldn't be traced by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that nested make_fx calls don't leak factory functions into the
        # outer graph. Verify that `make_fx` itself does not leak its execution.
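        # (Note: unlike the cases above, the nested trace below calls make_fx
        # directly rather than going through get_isolated_graphmodule.)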
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    # See https://github.com/pytorch/pytorch/issues/97541
    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False); x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
        # returns a ProxyTensor.
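        # (If the mode were not re-entrant, the zeros() created inside the
        # decomposition would be a plain tensor and the copy_ would be missing
        # from the traced graph; the expected output below asserts both ops are
        # captured.)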
354 out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2)) 355 self.assertExpectedInline(out.code, """\ 356 357 358 359def forward(self, x_1): 360 zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 361 copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None 362 return copy_ 363 """) 364 365 def test_make_fx_reentrant_dispatch(self): 366 def f(x): 367 return torch.ops.aten.norm.Scalar(x, 2.0) 368 369 def norm_decomp(x, p=2.0): 370 if p != 2.0: 371 raise RuntimeError("can't handle with p != 2") 372 return torch.sqrt(torch.sum(torch.square(x))) 373 374 decomp = {torch.ops.aten.norm.Scalar: norm_decomp} 375 376 traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3)) 377 378 for n in traced.graph.nodes: 379 self.assertTrue("square" not in str(n.target)) 380 self.assertTrue("norm" not in str(n.target)) 381 382 @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") 383 def test_resnet18_backward_trace(self): 384 mod = torchvision.models.resnet18() 385 386 # An old version of this test called the module directly. This works 387 # for tracing_mode == "real", but for fake tensors, we also have to 388 # ensure that the parameters and buffers get wrapped in fake tensors 389 # because free fake tensors are not supported. Fortunately functional_call 390 # does precisely this for us. 391 def f(x, params, buffers): 392 for p in params.values(): 393 p.grad = None 394 loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() 395 # I could have done this with the functional API, but there is 396 # plenty of exercising this; I want to show mutating API still 397 # works 398 loss.backward() 399 return [p.grad for p in params.values()] 400 401 inp = torch.randn(3, 3, 250, 250) 402 self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())]) 403 404 def test_varargs(self): 405 def f(*args): 406 return sum(args) 407 408 self._test(f, [torch.randn(2), torch.randn(2)]) 409 410 def test_proxy_tensor(self): 411 def f_grad(x): 412 val = x.cos().cos().sum() 413 return torch.autograd.grad(val, x) 414 415 def f_backward(x): 416 val = x.cos().cos().sum() 417 val.backward() 418 return x.grad 419 420 for f in [f_grad, f_backward]: 421 self._test(f, [torch.randn(3, requires_grad=True)]) 422 423 def test_pickle_issue89626(self): 424 import pickle 425 x = torch.randn(2) 426 make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x) 427 pickle.dumps(x) 428 429 def test_inplace_metadata(self): 430 def f(x): 431 x = x.clone() 432 x.unsqueeze_(-1) 433 assert x.shape[-1] == 1 434 return x 435 436 self._test(f, [torch.randn(5)]) 437 438 def test_mode_tracing_factory_function(self): 439 def f(x): 440 return x + torch.randn(x.shape) 441 442 # default behavior should trace factory functions 443 traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) 444 self.assertTrue( 445 any( 446 node.target == aten.randn.default 447 for node in traced.graph.nodes 448 ) 449 ) 450 451 def test_pre_dispatch_functionalization(self): 452 def f(x): 453 a = FunctionalTensorMode(pre_dispatch=True) 454 with a: 455 x_unwrapped = FunctionalTensor.to_functional(x) 456 y = torch.matmul(x_unwrapped, x_unwrapped) 457 y = y + x_unwrapped 458 y.mul_(5) 459 y_unwrapped = torch._from_functional_tensor(y.elem) 460 return y_unwrapped 461 462 from torch._dispatch.python import enable_python_dispatcher 463 464 with enable_python_dispatcher(): 465 inp = torch.randn(4, 4) 466 gm = 
make_fx(f, pre_dispatch=True)(inp) 467 468 # TODO actually not decompose 469 self.assertExpectedInline(gm.code.strip(), """\ 470def forward(self, x_1): 471 matmul = torch.ops.aten.matmul.default(x_1, x_1) 472 add = torch.ops.aten.add.Tensor(matmul, x_1); matmul = x_1 = None 473 mul = torch.ops.aten.mul.Tensor(add, 5); add = None 474 return mul""") 475 476 def test_pre_dispatch_functionalization_view_op(self): 477 def f(x): 478 a = FunctionalTensorMode(pre_dispatch=True) 479 with a: 480 x_unwrapped = FunctionalTensor.to_functional(x) 481 y = torch.matmul(x_unwrapped, x_unwrapped) 482 x_unwrapped = x_unwrapped.transpose(1, 0) 483 y = y + x_unwrapped 484 y = y.view(2, 8) 485 y_unwrapped = torch._from_functional_tensor(y.elem) 486 return y_unwrapped 487 488 from torch._dispatch.python import enable_python_dispatcher 489 490 with enable_python_dispatcher(): 491 inp = torch.randn(4, 4) 492 gm = make_fx(f, pre_dispatch=True)(inp) 493 494 # TODO actually not decompose 495 self.assertExpectedInline(gm.code.strip(), """\ 496def forward(self, x_1): 497 matmul = torch.ops.aten.matmul.default(x_1, x_1) 498 transpose = torch.ops.aten.transpose.int(x_1, 1, 0); x_1 = None 499 add = torch.ops.aten.add.Tensor(matmul, transpose); matmul = transpose = None 500 view = torch.ops.aten.view.default(add, [2, 8]); add = None 501 return view""") 502 503 def test_val_metadata_mutation(self): 504 def f(x): 505 y = x.clone() 506 y.unsqueeze_(0) 507 return y 508 509 traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True)) 510 self.assertEqual([ 511 tuple(node.meta['val'].shape) 512 for node in traced.graph.nodes 513 if 'val' in node.meta 514 ], [(3,), (3,), (1, 3)]) 515 516 def test_make_fx_overloads(self): 517 def f(x): 518 return x.cos() + torch.randn(x.shape) 519 520 traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) 521 522 self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload) 523 for node in traced.graph.nodes if node.op == 'call_function')) 524 525 def test_tensor_constants(self): 526 def f(): 527 val = torch.tensor(float('inf')) 528 return torch.full((100, 100), val) 529 530 self._test(f, []) 531 532 def test_allclose(self): 533 def f(a, b): 534 return torch.allclose(a, b) 535 536 def test_f(): 537 make_fx(f, tracing_mode=self.tracing_mode)( 538 torch.zeros(3), torch.zeros(3) 539 ) 540 541 if self.tracing_mode != "real": 542 self.assertRaises(DataDependentOutputException, test_f) 543 else: 544 self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) 545 546 def test_constant_proxy_tensor_mut(self): 547 def f(): 548 val = torch.tensor(float(1)) 549 val.add_(2) 550 return torch.full((100, 100), val) 551 552 g = make_fx(f, tracing_mode=self.tracing_mode)() 553 self.assertEqual(g(), f()) 554 # In case we mutated shared state in the g graph! 
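        # (Calling g() and f() a second time guards against the constant tensor
        # baked into the traced graph having been mutated by the first call.)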
555 self.assertEqual(g(), f()) 556 557 def test_constant_unbind(self): 558 def f(): 559 val = torch.tensor([2]) 560 r, = torch.unbind(val, 0) 561 return r.item() 562 563 g = make_fx(f, tracing_mode=self.tracing_mode)() 564 self.assertEqual(g(), f()) 565 566 def test_constant_blowup(self): 567 def f(): 568 val = torch.tensor([2]) 569 blowup = val.repeat(1000) 570 return bool(blowup.sum().item() == 2) 571 572 def test_f(): 573 make_fx(f, tracing_mode=self.tracing_mode)() 574 575 self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) 576 577 def test_constant_random(self): 578 def f(): 579 val = torch.tensor([2.0]) 580 val.normal_() 581 return bool(val.item() == 2.1) 582 583 def test_f(): 584 make_fx(f, tracing_mode=self.tracing_mode)() 585 586 self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) 587 588 def test_decomposition_interpreter(self): 589 def fn(x): 590 return torch.nn.functional.silu(x) 591 592 x = torch.rand((4, 4)) 593 fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x) 594 595 found_silu = False 596 for n in fx_module.graph.nodes: 597 if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default: 598 found_silu = True 599 600 self.assertTrue(found_silu) 601 602 new_graph = torch.fx.Graph() 603 silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]} 604 DecompositionInterpreter( 605 fx_module, 606 new_graph=new_graph, 607 decomposition_table=silu_decomp_table, 608 ).run(x) 609 610 decomposed_module = torch.fx.GraphModule(fx_module, new_graph) 611 612 for n in decomposed_module.graph.nodes: 613 self.assertTrue(n.target != torch.ops.aten.silu) 614 self.assertTrue(n.target != torch.ops.aten.silu.default) 615 616 self.assertEqual(fx_module(x), decomposed_module(x)) 617 618 def test_make_fx_model_fwd_bwd(self): 619 class Foo(torch.nn.Module): 620 def __init__(self) -> None: 621 super().__init__() 622 self.linear = torch.nn.Linear(5, 5) 623 624 def forward(self, x): 625 return self.linear(x).relu() 626 627 model = Foo() 628 629 def f(x, params): 630 out = torch.func.functional_call(model, params, x).sum() 631 out.backward() 632 return list(params.values()) 633 input = torch.randn(3, 5, requires_grad=True) 634 params = dict(model.named_parameters()) 635 fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params) 636 # fx may change the order of parameters in list, so using set() to compare 637 self.assertTrue( 638 torch.allclose(fx_f(input, params)[0], f(input, params)[0]) 639 or 640 torch.allclose(fx_f(input, params)[0], f(input, params)[1]) 641 ) 642 self.assertTrue( 643 torch.allclose(fx_f(input, params)[1], f(input, params)[0]) 644 or 645 torch.allclose(fx_f(input, params)[1], f(input, params)[1]) 646 ) 647 648 def test_make_fx_model_double_param(self): 649 class Emformer(torch.nn.Module): 650 def __init__( 651 self, 652 input_dim: int = 256, 653 ) -> None: 654 super().__init__() 655 656 self.layer_norm = torch.nn.LayerNorm(input_dim) 657 658 def forward(mod_self, x): # noqa: B902 659 self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) 660 y = mod_self.layer_norm(x) 661 self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) 662 z = mod_self.layer_norm(y) 663 return z 664 665 666 gm = make_fx(Emformer())(torch.randn(16, 1, 256)) 667 ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'} 668 self.assertEqual(len(ops), 2) 669 670 671 def test_make_fx_model_fwd_bwd_wgtupdate(self): 672 class Foo(torch.nn.Module): 673 def 
__init__(self) -> None: 674 super().__init__() 675 self.linear = torch.nn.Linear(5, 5) 676 677 def forward(self, x): 678 return self.linear(x).relu() 679 680 model = Foo() 681 682 def f(args, params, buffers): 683 for p in params.values(): 684 p.grad = None 685 if not isinstance(args, Iterable): 686 args = [args] 687 params_and_buffers = {**params, **buffers} 688 out = torch.func.functional_call(model, params_and_buffers, args) 689 out.sum().backward() 690 return [p - 1e-4 * p.grad for p in params.values()] 691 692 input = torch.randn(3, 5, requires_grad=True) 693 params = dict(model.named_parameters()) 694 buffers = dict(model.named_buffers()) 695 fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers) 696 # fx may change the order of parameters in list, so using set() to compare 697 # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03 698 self.assertTrue( 699 torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03) 700 or 701 torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03) 702 ) 703 self.assertTrue( 704 torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03) 705 or 706 torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03) 707 ) 708 709 def test_trace_subclasses(self): 710 def f1(x): 711 x = UnwrapTensor(x) 712 y = x * 2 713 return y 714 715 def f2(x): 716 wrapped = UnwrapTensor(x) 717 y = x * wrapped 718 return y 719 720 inp = [torch.randn(5)] 721 self._test(f1, inp) 722 self._test(f2, inp) 723 724 def test_partial_decomp(self): 725 def f(a, b, c): 726 x = torch.addmm(a, b, c) 727 y = torch.addmm(a, b, c, beta=2, alpha=1) 728 return x + y 729 inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)] 730 fx_g = make_fx(f)(*inps) 731 732 def addmm(a, b, c, beta=1, alpha=1): 733 if beta == 1 and alpha == 1: 734 return NotImplemented 735 return beta * a + alpha * (b @ c) 736 737 decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps) 738 739 self.assertEqual(fx_g(*inps), decomposed_fx(*inps)) 740 self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2) 741 self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1) 742 743 def test_decomp_of_capture(self): 744 val = torch.randn(5) 745 746 def f(x): 747 return x.t() + val.t() 748 749 def nop(x): 750 return x.cos() 751 752 traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5)) 753 self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0) 754 755 756 @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') 757 def test_amp_cache(self): 758 layer = torch.nn.Conv2d(3, 3, 3).cuda() 759 760 def f(x, w): 761 return torch.nn.functional.conv2d(x, w, stride=layer.stride) 762 763 inp = torch.randn(4, 3, 10, 10, device='cuda') 764 with torch.autocast('cuda'): 765 out_graph = make_fx(f)(inp, layer.weight).graph 766 out_graph2 = make_fx(f)(inp, layer.weight).graph 767 768 self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes)) 769 for a, b in zip(out_graph.nodes, out_graph2.nodes): 770 self.assertEqual(a.op, b.op) 771 772 def test_strides(self): 773 def f(x): 774 self.assertTrue(x.is_contiguous()) 775 self.assertFalse(x.is_contiguous(memory_format=torch.channels_last)) 776 x = x.permute(0, 3, 1, 2) 777 self.assertFalse(x.is_contiguous()) 778 
self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) 779 return x 780 make_fx(f)(torch.randn(2, 3, 4, 5)) 781 782 def f(x): 783 self.assertTrue(x.is_contiguous()) 784 y = x[:, 1] 785 self.assertFalse(y.is_contiguous()) 786 y = x[:, ::2] 787 self.assertFalse(y.is_contiguous()) 788 return x.cos() 789 790 make_fx(f)(torch.randn(2, 3, 4, 5)) 791 792 def test_pr_86917(self): 793 # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344 794 def f(a, b): 795 return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10) 796 797 self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)]) 798 799class TestGenericProxyTensorReal(TestGenericProxyTensor): 800 tracing_mode = "real" 801 802 803class TestGenericProxyTensorFake(TestGenericProxyTensor): 804 tracing_mode = "fake" 805 806 807class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): 808 tracing_mode = "symbolic" 809 810 811del TestGenericProxyTensor 812 813 814class TestRealProxyTensor(TestCase): 815 def test_error_on_data_dependent_ops(self): 816 def f(): 817 x = torch.randn([]) 818 y = torch.randn([]) 819 assert torch.allclose(x * y, y * x) 820 z = float(x) 821 z2 = float(y) 822 823 # Smoke tests 824 make_fx(f, _error_on_data_dependent_ops=False)() 825 make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)() 826 827class TestFakeProxyTensor(TestCase): 828 def test_issue82547(self): 829 x = nn.Parameter(torch.randn(3, 3)) 830 831 def f(): 832 return torch.ops.aten.t.default(x) 833 self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")()) 834 835 class A(torch.Tensor): 836 pass 837 838 x = A(torch.randn(3, 3)) 839 self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")()) 840 841 def test_use_fake_and_tensor(self): 842 def f(x, y): 843 z = torch.tensor([2.0, 3.0]) 844 return x + y + z 845 846 g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2)) 847 x, y = torch.randn(2), torch.randn(2) 848 self.assertEqual(g(x, y), f(x, y)) 849 850 def test_free_fake(self): 851 def f(x): 852 return torch.add(x, y) 853 854 with FakeTensorMode() as fake_mode: 855 y = torch.randn(2) 856 make_fx(f, tracing_mode="real")(torch.randn(2)) 857 858 def test_fused_adam(self): 859 # See https://github.com/pytorch/pytorch/issues/99356 860 params = [torch.randn(10, 10) for _ in range(10)] 861 grads = [torch.randn(10, 10) for _ in range(10)] 862 exp_avgs = [torch.randn(10, 10) for _ in range(10)] 863 exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] 864 max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] 865 state_steps = [torch.tensor(0) for _ in range(10)] 866 867 def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps): 868 (new_params, _, _, _, _) = aten._fused_adam.default( 869 params, 870 grads, 871 exp_avgs, 872 exp_avg_sqs, 873 max_exp_avg_sqs, 874 state_steps, 875 lr=0.1, 876 beta1=0.9, 877 beta2=0.999, 878 weight_decay=0.01, 879 eps=1e-8, 880 amsgrad=False, 881 maximize=False, 882 ) 883 884 for p, new_p in zip(params, new_params): 885 p.copy_(new_p) 886 887 return params 888 889 gm = make_fx(fused_adam, tracing_mode='fake')( 890 params, 891 grads, 892 exp_avgs, 893 exp_avg_sqs, 894 max_exp_avg_sqs, 895 state_steps, 896 ) 897 ensure_ops_have_val = [aten._fused_adam.default, operator.getitem] 898 for n in gm.graph.nodes: 899 if n.op == "call_function" and n.target in ensure_ops_have_val: 900 self.assertIn('val', n.meta) 901 902 def 
test_alias(self): 903 def f(x): 904 return torch.ops.aten.alias(x) 905 906 r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip() 907 # NB: this should not have a detach call 908 self.assertExpectedInline(r, """\ 909def forward(self, x_1): 910 alias = torch.ops.aten.alias.default(x_1); x_1 = None 911 return alias""") 912 913 def test_meta(self): 914 def f(x): 915 a = x.cos() 916 b = torch.var_mean(a, dim=0) 917 c = b * 2 918 return c 919 920 out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5)) 921 for n in out.graph.nodes: 922 if n.op == 'output': 923 continue 924 self.assertTrue('val' in n.meta) 925 926def _get_node(fx_g, cond): 927 for n in fx_g.graph.nodes: 928 if cond(n): 929 return n 930 raise AssertionError 931 932def _get_free_symbols(shape_env): 933 vars = tuple(shape_env.var_to_val.keys()) 934 return len([var for var in vars if var not in shape_env.replacements]) 935 936def _trace(f, *args): 937 inps = [torch.randn(arg) for arg in args] 938 return make_fx(f, tracing_mode="symbolic")(*inps) 939 940# TODO: Need to test the guards themselves specifically as well 941class TestSymbolicTracing(TestCase): 942 def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True): 943 """ 944 Tests fn traced with trace_inputs against test_inputs 945 Also returns shape env 946 """ 947 trace_inputs = [torch.randn(shape) for shape in trace_inputs] 948 traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs) 949 for input in test_inputs: 950 input = [torch.randn(shape) for shape in input] 951 rx, ry = traced_f(*input), fn(*input) 952 if assert_eq: 953 self.assertEqual(rx, ry) 954 return traced_f 955 956 957 def test_debug_interpreter(self): 958 import torch.library 959 from torch.library import Library 960 961 foo = Library("foo", "DEF") # noqa: TOR901 962 foo.define("foo(Tensor self) -> Tensor") 963 964 # Operator where meta and cpu disagree on strides 965 @torch.library.impl(foo, "foo", "CPU") 966 def foo_cpu(x): 967 return x.clone().T 968 969 @torch.library.impl(foo, "foo", "Meta") 970 def foo_meta(x): 971 return x.clone() 972 973 def f(x): 974 return torch.ops.foo.foo.default(x) 975 976 gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2)) 977 from torch._functorch.compilers import DebugInterpreter 978 979 interp = DebugInterpreter(gm) 980 981 # input mismatch is caught (indicates guard problem) 982 self.assertRaisesRegex( 983 AssertionError, r"3 != 1", 984 lambda: interp.run(torch.randn(3, 3).T), 985 ) 986 987 # Catch the incorrect meta 988 self.assertRaisesRegex( 989 AssertionError, r"\(3, 1\) != \(1, 3\)", 990 lambda: interp.run(torch.randn(3, 3)) 991 ) 992 993 def test_int_input(self): 994 def f(x, y): 995 return x.view(y) 996 997 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip() 998 self.assertExpectedInline(r, """\ 999def forward(self, x_1, y_1): 1000 view = torch.ops.aten.view.default(x_1, [y_1]); x_1 = y_1 = None 1001 return view""") 1002 1003 def test_resize_from_zero(self): 1004 def f(x, y): 1005 x.resize_(y.size(0)) 1006 1007 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip() 1008 self.assertExpectedInline(r, """\ 1009def forward(self, x_1, y_1): 1010 sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None 1011 resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = resize_ = None 1012 return None""") 1013 1014 def test_broadcast_shapes(self): 1015 def f(x, y): 1016 return torch.functional.broadcast_shapes(x.size(), y.size()[0]) 1017 1018 r = 
str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip() 1019 self.assertExpectedInline(r, """\ 1020def forward(self, x_1, y_1): 1021 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None 1022 sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None 1023 return (sym_size_int, sym_size_int_1)""") 1024 1025 def test_deduped_shape(self): 1026 def f(s0, s1, x, y): 1027 return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0]) 1028 1029 x = torch.empty(3, 1) 1030 y = torch.empty(5) 1031 from torch.fx.experimental.symbolic_shapes import ShapeEnv 1032 shape_env = ShapeEnv() 1033 1034 with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode: 1035 x = fake_mode.from_tensor(x) 1036 y = fake_mode.from_tensor(y) 1037 r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip() 1038 self.assertExpectedInline(r, """\ 1039def forward(self, s0_1, s1_1, x_1, y_1): 1040 empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False) 1041 return ((s0_1, s1_1), empty)""") 1042 1043 def test_non_deduped_shape(self): 1044 def f(x, y): 1045 return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0]) 1046 1047 x = torch.empty(3, 1) 1048 y = torch.empty(5) 1049 from torch.fx.experimental.symbolic_shapes import ShapeEnv 1050 shape_env = ShapeEnv() 1051 1052 with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode: 1053 x = fake_mode.from_tensor(x) 1054 y = fake_mode.from_tensor(y) 1055 r = str(make_fx(f, tracing_mode="real")(x, y).code).strip() 1056 self.assertExpectedInline(r, """\ 1057def forward(self, x_1, y_1): 1058 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None 1059 sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None 1060 empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False) 1061 return ((sym_size_int, sym_size_int_1), empty)""") 1062 1063 def test_unary(self): 1064 def f(x): 1065 assert x.shape[0] < 20 1066 return x.cos() 1067 test_inputs = [] 1068 test_inputs.append([(2, 5)]) 1069 test_inputs.append([(6, 8)]) 1070 gm = self._test_dynamic(f, [(3, 4)], test_inputs) 1071 self.assertTrue(eval_guards(gm, torch.randn(4, 5))) 1072 self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}") 1073 self.assertFalse(eval_guards(gm, torch.randn(25, 5))) 1074 self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""") 1075 1076 def test_repeat_interleave(self): 1077 def f(src_tokens, beam_size_src): 1078 return src_tokens.repeat_interleave(beam_size_src.size(0), 0) 1079 1080 prompt_size = 64 1081 vocab_size = 64 1082 batch_size = 4 1083 src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size)) 1084 gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5)) 1085 self.assertEqual(len(gm.shape_env.guards), 0) 1086 1087 def test_non_symint_size_spec(self): 1088 # this isn't really a proxy tensor test, but it's the most convenient 1089 # way to get a fake tensor with symbolic sizes 1090 def f(x): 1091 torch._C._non_sym_sizes(x) 1092 return x + 1 1093 1094 x = torch.randn(2, 3) 1095 make_fx(f, tracing_mode="symbolic")(x) 1096 1097 # https://github.com/pytorch/pytorch/issues/108195 1098 def test_symbolic_repeat_interleave(self): 1099 def f(y, x): 1100 return y.repeat_interleave(x, dim=1) 1101 1102 y = torch.tensor([[1, 2], [3, 4]]) 1103 x = torch.tensor([2, 3]) 1104 r = str(make_fx(f, 
tracing_mode="symbolic")(y, x).code).strip() 1105 self.assertExpectedInline(r, """\ 1106def forward(self, y_1, x_1): 1107 repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1); x_1 = None 1108 index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave); y_1 = repeat_interleave = None 1109 return index_select""") 1110 1111 def test_mod_gcd_unbacked(self): 1112 def f(_a, _b, _stride): 1113 a = _a.item() 1114 b = _b.item() 1115 stride = _stride.item() 1116 torch._check_is_size(a) 1117 torch._check_is_size(b) 1118 torch._check_is_size(stride) 1119 ta = torch.randn(a * stride) 1120 tb = torch.randn(b * stride) 1121 r = torch.cat([ta, tb]) 1122 return r.view(a + b, stride) 1123 1124 _a = torch.tensor(30) 1125 _b = torch.tensor(20) 1126 _stride = torch.tensor(10) 1127 r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip() 1128 self.assertExpectedInline(r, """\ 1129def forward(self, _a_1, _b_1, _stride_1): 1130 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1); _a_1 = None 1131 _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1); _b_1 = None 1132 _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1); _stride_1 = None 1133 mul = _local_scalar_dense * _local_scalar_dense_2 1134 randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False); mul = None 1135 mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2 1136 randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False); mul_1 = None 1137 cat = torch.ops.aten.cat.default([randn, randn_1]); randn = randn_1 = None 1138 add = _local_scalar_dense + _local_scalar_dense_1; _local_scalar_dense = _local_scalar_dense_1 = None 1139 view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]); cat = add = _local_scalar_dense_2 = None 1140 return view""") 1141 1142 def test_cumsum_unbacked(self): 1143 def f(x): 1144 y = x.item() 1145 z = torch.randn((3, y, 3)) 1146 return z.cumsum(0) 1147 1148 r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip() 1149 self.assertExpectedInline( 1150 r, """\ 1151def forward(self, x_1): 1152 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None 1153 randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None 1154 cumsum = torch.ops.aten.cumsum.default(randn, 0); randn = None 1155 return cumsum""" # noqa: B950 1156 ) 1157 1158 1159 def test_repeat_interleave_unbacked_output_size(self): 1160 def f(x, y): 1161 s = x.sum().item() 1162 return y.repeat_interleave(x, dim=0, output_size=s) 1163 1164 r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip() 1165 self.assertExpectedInline( 1166 r, """\ 1167def forward(self, x_1, y_1): 1168 sum_1 = torch.ops.aten.sum.default(x_1) 1169 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1); sum_1 = None 1170 repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense); x_1 = _local_scalar_dense = None 1171 index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave); y_1 = repeat_interleave = None 1172 return index_select""" # noqa: B950 1173 ) 1174 1175 def test_arange_unbacked_output_size(self): 1176 def f(x): 1177 return torch.arange(0, x) 1178 1179 r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip() 1180 self.assertExpectedInline( 
1181 r, """\ 1182def forward(self, x_1): 1183 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None 1184 arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None 1185 return arange""" # noqa: B950 1186 ) 1187 1188 def test_adv_index_batch(self): 1189 def f(src_tokens): 1190 bsz, src_len = src_tokens.size()[:2] 1191 start_step = src_tokens.shape[1] 1192 beam_size = 1 1193 generate_size = 64 1194 max_len = src_len + generate_size 1195 tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0) 1196 tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0) 1197 return tokens 1198 1199 prompt_size = 64 1200 vocab_size = 64 1201 batch_size = 4 1202 src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size)) 1203 gm = make_fx(f, tracing_mode="symbolic")(src_tokens) 1204 # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and 1205 # 1 ok) 1206 self.assertEqual(len(gm.shape_env.guards), 1) 1207 1208 @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') 1209 def test_cpu_scalar_cuda(self): 1210 # Extracted from wave2vec2 1211 def f(a, b): 1212 return (a * b) @ b 1213 1214 r = str( 1215 make_fx(f, tracing_mode="symbolic")( 1216 torch.tensor(1.0), torch.randn(2, 2, device='cuda') 1217 ).code 1218 ).strip() 1219 self.assertExpectedInline(r, """\ 1220def forward(self, a_1, b_1): 1221 mul = torch.ops.aten.mul.Tensor(a_1, b_1); a_1 = None 1222 mm = torch.ops.aten.mm.default(mul, b_1); mul = b_1 = None 1223 return mm""") 1224 1225 def test_binary_broadcast(self): 1226 def f(a, b): 1227 c = a * b 1228 return c 1229 1230 test_inputs = [] 1231 test_inputs.append([(1, 5), (3, 1)]) 1232 test_inputs.append([(1, 4), (4, 1)]) 1233 shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env 1234 assert len(shape_env.guards) == 0 1235 1236 def test_multiply_shape(self): 1237 def f(a): 1238 return torch.empty(a.shape[0] * 2) 1239 1240 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() 1241 self.assertExpectedInline(r, """\ 1242def forward(self, a_1): 1243 sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None 1244 mul = sym_size_int * 2; sym_size_int = None 1245 empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None 1246 return empty""") 1247 1248 def test_item(self): 1249 def f(a): 1250 r = a.item() 1251 return r * a 1252 1253 r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip() 1254 self.assertExpectedInline(r, """\ 1255def forward(self, a_1): 1256 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1) 1257 mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None 1258 return mul""") 1259 1260 def test_tensor_symfloat(self): 1261 def f(a): 1262 r = torch.tensor(a.size(0) ** 2.0) 1263 assert r.dtype is torch.float 1264 return r 1265 1266 gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2)) 1267 r = str(gm.code).strip() 1268 # NB: this specializes, which is fine, the point is to make sure the 1269 # dtype inference is correct 1270 self.assertExpectedInline(r, """\ 1271def forward(self, a_1): 1272 _tensor_constant0 = self._tensor_constant0 1273 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 1274 return lift_fresh_copy""") 1275 self.assertEqual(gm._tensor_constant0, torch.tensor(4.0)) 1276 1277 def test_item_to_constructor(self): 1278 def 
f(a): 1279 r = a.item() 1280 return torch.empty(r) 1281 1282 r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip() 1283 self.assertExpectedInline( 1284 r, """\ 1285def forward(self, a_1): 1286 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None 1287 empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None 1288 return empty""" # noqa: B950 1289 ) 1290 1291 1292 def test_setitem_symint(self): 1293 # from moco 1294 # https://github.com/pytorch/pytorch/issues/101939 1295 def f(x): 1296 x[0] = x.size(0) 1297 return x 1298 1299 r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip() 1300 self.assertExpectedInline( 1301 r, """\ 1302def forward(self, x_1): 1303 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 1304 scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None 1305 select = torch.ops.aten.select.int(x_1, 0, 0) 1306 copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = copy_ = None 1307 return x_1""" # noqa: B950 1308 ) 1309 1310 def test_dynamic_pointwise_scalar(self): 1311 def f(gravity, mask): 1312 gravity[mask, 0] = gravity[mask, 0] * -1 1313 1314 r = str(make_fx(f, tracing_mode="symbolic")( 1315 torch.randn((12, 4)), 1316 torch.randint(0, 2, (12,), dtype=torch.bool) 1317 ).code).strip() 1318 self.assertExpectedInline(r, """\ 1319def forward(self, gravity_1, mask_1): 1320 select = torch.ops.aten.select.int(gravity_1, 1, 0) 1321 index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None 1322 mul = torch.ops.aten.mul.Tensor(index, -1); index = None 1323 select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None 1324 index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = index_put_ = None 1325 return None""") 1326 1327 def test_reflect_r_over_x(self): 1328 def reflect_R_over_x(R): 1329 reflect = torch.eye(3, device=R.device) 1330 reflect[0, 0] = -1 1331 return reflect @ R @ reflect 1332 1333 def f(crop_camera, mask): 1334 crop_camera[mask] = reflect_R_over_x(crop_camera[mask]) 1335 1336 r = str(make_fx(f, tracing_mode="symbolic")( 1337 torch.randn((12, 3, 3)), 1338 torch.randint(0, 2, (12,), dtype=torch.bool) 1339 ).code).strip() 1340 self.assertExpectedInline(r, """\ 1341def forward(self, crop_camera_1, mask_1): 1342 index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1]) 1343 eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False) 1344 _tensor_constant0 = self._tensor_constant0 1345 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None 1346 select = torch.ops.aten.select.int(eye, 0, 0) 1347 select_1 = torch.ops.aten.select.int(select, 0, 0); select = None 1348 copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = copy_ = None 1349 sym_size_int = torch.ops.aten.sym_size.int(index, 0) 1350 expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3]) 1351 view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None 1352 sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1) 1353 sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2) 1354 expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None 1355 view_1 = 
torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None 1356 bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None 1357 view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None 1358 mul_4 = sym_size_int * 3 1359 view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None 1360 mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None 1361 view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None 1362 index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None 1363 return None""") # noqa: B950 1364 1365 def test_unbacked_slice(self): 1366 def f(x, m): 1367 x = x[m] 1368 return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)] 1369 1370 make_fx(f, tracing_mode="symbolic")( 1371 torch.randn((12, 3, 3)), 1372 torch.randint(0, 2, (12,), dtype=torch.bool) 1373 ) 1374 1375 @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") 1376 def test_unbacked_batch_resnet(self): 1377 mod = torchvision.models.resnet18() 1378 1379 def f(x, mask, params, buffers): 1380 for p in itertools.chain([x, mask], params.values(), buffers.values()): 1381 for s in p.shape: 1382 guard_int(s) 1383 x = x[mask] 1384 torch._check(x.shape[0] >= 1) 1385 for p in params.values(): 1386 p.grad = None 1387 return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() 1388 1389 make_fx(f, tracing_mode="symbolic")( 1390 torch.randn(3, 3, 250, 250), 1391 torch.randint(0, 2, (3,), dtype=torch.bool), 1392 dict(mod.named_parameters()), 1393 dict(mod.named_buffers()), 1394 ) 1395 1396 def test_boolean_index(self): 1397 def f(images, handedness, valid): 1398 images = images[valid] 1399 handedness = handedness[valid] 1400 right_hand_mask = handedness == 1 1401 images[right_hand_mask] = images[right_hand_mask].flip(-1) 1402 1403 r = str(make_fx(f, tracing_mode="symbolic")( 1404 torch.randint(0, 256, (512, 1, 96, 96)), 1405 torch.randint(0, 1, (512,)), 1406 torch.randint(0, 2, (512,), dtype=torch.bool) 1407 ).code).strip() 1408 self.assertExpectedInline(r, """\ 1409def forward(self, images_1, handedness_1, valid_1): 1410 index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None 1411 index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None 1412 eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None 1413 index_2 = torch.ops.aten.index.Tensor(index, [eq]) 1414 flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None 1415 index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = index_put_ = None 1416 return None""") 1417 1418 def test_neg_shape(self): 1419 def f(a): 1420 return torch.empty(-a.shape[0] + 10) 1421 1422 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip() 1423 self.assertExpectedInline(r, """\ 1424def forward(self, a_1): 1425 sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None 1426 neg = -sym_size_int; sym_size_int = None 1427 add = neg + 10; neg = None 1428 empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None 1429 return empty""") 1430 1431 def test_unbacked_unification(self): 1432 def f(x, y): 1433 z = torch.zeros(x.item()) 1434 return z + y 1435 1436 r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip() 1437 
self.assertExpectedInline(r, """\ 1438def forward(self, x_1, y_1): 1439 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None 1440 zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None 1441 add = torch.ops.aten.add.Tensor(zeros, y_1); zeros = y_1 = None 1442 return add""") # noqa: B950 1443 1444 def test_reshape_divisibility_unbacked(self): 1445 def f(x): 1446 i0 = x.item() 1447 r = torch.zeros(i0, 4, 20) 1448 r = r.transpose(2, 1) 1449 return r.reshape(-1, 80) 1450 make_fx(f, tracing_mode="symbolic")(torch.tensor(24)) 1451 1452 def test_view_divisibility_unbacked(self): 1453 def f(x): 1454 i0 = x.item() 1455 r = torch.zeros(i0, 192) 1456 return r.view(12, -1, 192) 1457 make_fx(f, tracing_mode="symbolic")(torch.tensor(24)) 1458 1459 @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') 1460 def test_view_divisibility_unbacked_relatively_prime(self): 1461 # See https://github.com/pytorch/pytorch/issues/123651 1462 def f(x): 1463 i0 = x.item() 1464 torch._check_is_size(i0) 1465 # To trigger the original issue, the max bound has to 1466 # be chosen such that 448 / 447 < 2 (which it is.) 1467 torch._check(i0 <= 448) 1468 return torch.zeros(256 * i0).view(-1, 447) 1469 make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda")) 1470 1471 def test_unbacked_unify_guard(self): 1472 def f(x, y): 1473 z = torch.zeros(x.item()) 1474 torch._check(z.size(0) == y.size(0)) # refines i0 = s0 1475 if z.size(0) == 4: 1476 return y * 2 1477 else: 1478 return y + 2 1479 1480 r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip() 1481 self.assertExpectedInline(r, """\ 1482def forward(self, x_1, y_1): 1483 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None 1484 zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = zeros = None 1485 add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None 1486 return add""") # noqa: B950 1487 1488 @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') 1489 @unittest.expectedFailure 1490 def test_unbacked_unify_guard_transitivity(self): 1491 def f(x1, x2, y): 1492 z1 = torch.zeros(x1.item()) 1493 z2 = torch.zeros(x2.item()) 1494 torch._check(z1.size(0) == z2.size(0)) # refines i0 = i1 1495 torch._check(z2.size(0) == y.size(0)) # refines i0 = s0 1496 if z1.size(0) == 4: 1497 return y * 2 1498 else: 1499 return y + 2 1500 1501 gm = make_fx(f, tracing_mode="symbolic")( 1502 torch.tensor(10, device="cuda"), 1503 torch.tensor(10, device="cuda"), 1504 torch.randn(10, device="cuda") 1505 ) 1506 insert_deferred_runtime_asserts(gm, gm.shape_env, "test") 1507 gm.recompile() 1508 r = str(gm.code).strip() 1509 # self.assertExpectedInline( 1510 # r, """""" # noqa: B950 1511 # ) 1512 1513 @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') 1514 def test_unbacked_unify_dependency_violation(self): 1515 def f(x1, x2, x3, y): 1516 z1 = x1.item() 1517 torch._check(z1 // 9 == 1) 1518 z2 = x2.item() 1519 z3 = x3.item() 1520 torch._check(z1 == z2 + z3) 1521 return y * 2 1522 if z2 + z3 == z1: 1523 return y * 2 1524 else: 1525 return y + 3 1526 1527 # NB: inputs are done as CUDA to ensure they aren't queried to be 1528 # backed 1529 1530 gm = make_fx(f, tracing_mode="symbolic")( 1531 torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"), 1532 torch.tensor(5, device="cuda"), torch.randn(1, device="cuda") 1533 ) 1534 
insert_deferred_runtime_asserts(gm, gm.shape_env, "test") 1535 gm.recompile() 1536 self.assertEqual(gm( 1537 torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"), 1538 torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")), 1539 torch.tensor([2.0], device="cuda") 1540 ) 1541 with self.assertRaises(RuntimeError): 1542 gm( 1543 torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"), 1544 torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda") 1545 ) 1546 1547 1548 def test_split_unbacked_sizes(self): 1549 def f(lengths, values): 1550 # tolist not directly supported atm 1551 sizes = [lengths[i].item() for i in range(lengths.size(0))] 1552 for s in sizes: 1553 # TODO(avik): no assertion generated with torch._check_is_size? 1554 torch._constrain_as_size(s) 1555 return torch.split(values, sizes) 1556 1557 r = str(make_fx(f, tracing_mode="symbolic")( 1558 torch.tensor([2, 3, 4]), 1559 torch.randn(9) 1560 ).code).strip() 1561 self.assertExpectedInline(r, """\ 1562def forward(self, lengths_1, values_1): 1563 select = torch.ops.aten.select.int(lengths_1, 0, 0) 1564 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select); select = None 1565 select_1 = torch.ops.aten.select.int(lengths_1, 0, 1) 1566 _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None 1567 select_2 = torch.ops.aten.select.int(lengths_1, 0, 2); lengths_1 = None 1568 _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2); select_2 = None 1569 sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size = None 1570 sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1); sym_constrain_range_for_size_1 = None 1571 sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2); sym_constrain_range_for_size_2 = None 1572 split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]); values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None 1573 getitem = split_with_sizes[0] 1574 getitem_1 = split_with_sizes[1] 1575 getitem_2 = split_with_sizes[2]; split_with_sizes = None 1576 return (getitem, getitem_1, getitem_2)""") # noqa: B950 1577 1578 def test_invalidate_nonzero(self): 1579 ok = False 1580 1581 def f(a): 1582 nonlocal ok 1583 b = a.clone() 1584 x = b.nonzero() 1585 x1 = b.nonzero() 1586 x2 = b.nonzero() 1587 assert x1.shape[0] == x2.shape[0] 1588 ok = True 1589 b.normal_() 1590 y = b.nonzero() 1591 try: 1592 bool(x1.shape[0] == y.shape[0]) 1593 self.fail("didn't raise exception") 1594 except GuardOnDataDependentSymNode: 1595 pass 1596 1597 make_fx(f, tracing_mode="symbolic")(torch.randn(4)) 1598 1599 @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) 1600 def test_invalidate_nonzero_propagate_real_tensors(self): 1601 def f(a): 1602 b = a.clone() 1603 x = b.nonzero() 1604 x1 = b.nonzero() 1605 x2 = b.nonzero() 1606 assert x1.shape[0] == x2.shape[0] 1607 b.normal_() 1608 y = b.nonzero() 1609 # Because you're not actually going to generate exactly zero with 1610 # normal_ lol 1611 assert x1.shape[0] == y.shape[0] 1612 1613 make_fx(f, tracing_mode="symbolic")(torch.randn(4)) 1614 1615 def test_sqrt_size(self): 1616 def f(a): 1617 return a / a.size(-1) ** 0.5 1618 1619 r = str(make_fx(f, 
tracing_mode="symbolic")(torch.empty(4)).code).strip() 1620 self.assertExpectedInline(r, """\ 1621def forward(self, a_1): 1622 sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) 1623 sym_float = torch.sym_float(sym_size_int); sym_size_int = None 1624 pow_1 = sym_float ** 0.5; sym_float = None 1625 div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None 1626 return div""") 1627 1628 def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self): 1629 1630 class Bar(torch.nn.Module): 1631 def __init__(self) -> None: 1632 super().__init__() 1633 1634 def forward(self, x): 1635 return x + 1 1636 1637 class Foo(torch.nn.Module): 1638 def __init__(self) -> None: 1639 super().__init__() 1640 self.bar = Bar() 1641 1642 def forward(self, x): 1643 return x + self.bar(x) 1644 1645 gm = make_fx(Foo())(torch.randn(4, 4)) 1646 for node in gm.graph.nodes: 1647 self.assertTrue("nn_module_stack" not in node.meta) 1648 1649 foo = Foo() 1650 1651 def functional_call(*args, **kwargs): 1652 with stateless._reparametrize_module(foo, {}): 1653 return foo(*args, **kwargs) 1654 1655 functional_call._orig_mod = foo 1656 1657 gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4)) 1658 found = False 1659 for node in gm_with_stack.graph.nodes: 1660 if "nn_module_stack" in node.meta: 1661 if len(node.meta["nn_module_stack"]) == 1: 1662 self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"])) 1663 found = True 1664 elif len(node.meta["nn_module_stack"]) == 2: 1665 self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"])) 1666 found = True 1667 else: 1668 # there can be at most 2 level 1669 self.assertTrue(False) 1670 1671 self.assertTrue(found) 1672 1673 gm_without_stack = make_fx(functional_call)(torch.randn(4, 4)) 1674 for node in gm_without_stack.graph.nodes: 1675 self.assertTrue("nn_module_stack" not in node.meta) 1676 1677 def test_symint_to_tensor(self): 1678 def f(a): 1679 return a / a.shape[0] 1680 1681 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() 1682 self.assertExpectedInline(r, """\ 1683def forward(self, a_1): 1684 sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) 1685 div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None 1686 return div""") 1687 1688 r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip() 1689 self.assertExpectedInline(r, """\ 1690def forward(self, a_1): 1691 sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) 1692 sym_float = torch.sym_float(sym_size_int); sym_size_int = None 1693 div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None 1694 return div""") 1695 1696 def test_cat(self): 1697 def f(a, b): 1698 val = torch.mul(a, b) 1699 out = torch.cat([val, val]) 1700 if out.shape[0] * out.shape[1] > 20: 1701 out = out.cos() 1702 return out 1703 1704 test_inputs = [] 1705 test_inputs.append([(1, 5), (6, 1)]) 1706 test_inputs.append([(1, 4), (3, 1)]) 1707 gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) 1708 self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1))) 1709 self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1))) 1710 self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""") 1711 1712 def test_new_empty(self): 1713 def f(a, b): 1714 return a.new_empty(b.shape[0], b.shape[1] * 2) 1715 1716 self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], 
        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        # I think I messed up writing this test case originally: I expected to
        # hit an error case, but the code here works in both eager and tracing
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        f(a)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
    return zeros""")  # noqa: B950

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # Checks that the input expr has been updated even though the constraint
        # happened afterwards
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

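    # x.size(0) + x puts the SymInt on the left-hand side, so this exercises the
    # reflected ("r") arithmetic path on the Tensor operand.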
    def test_rmethod(self):
        def f(x):
            return x.size(0) + x
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )

    def test_guard_upperbound_range_refinement(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")

    def test_guard_lowerbound_range_refinement(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")

    def test_guard_upperbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
        self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")

    def test_guard_lowerbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
        self.assertExpectedInline(
            show_guards(tensor),
            """\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))

    def _assert_no_guards(self, fx_g, free_symbols):
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # NB: Numbers are carefully chosen to prevent duck shaping from applying

        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

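        # 7 == 4 + 3: the size of cat([c, d]) has to line up with a + b, and the assertion
        # below checks that this is resolved with two free symbols and no guards.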
        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")

    @torch.fx.experimental._config.patch(translation_validation=True)
    def test_constant_specialization(self):
        def f(t):
            assert t.shape[0] == 10
            return t

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
        self.assertExpectedInline(show_guards(tensor), """""")


make_fx_failures = {
    # unknown
    xfail('allclose'),
    xfail('equal'),
    # empty
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    skip('empty_permuted'),
    # flaky
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),  # flaky, probably just a precision issue

    # data-dependent control flow
    skip('item'),
    xfail('cov'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    xfail('quantile'),
    xfail('nanquantile'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),
    # segfaults
    skip('block_diag'),

    # AssertionError: Tensor-likes are not close!
    skip('empty_strided', '', device_type='cpu'),
}

only_real_tensor_failures = {
    xfail('narrow'),
}

only_fake_tensor_failures = {
    xfail('narrow'),
}

fake_tensor_failures = {
    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),
}

symbolic_tensor_failures = {
    xfail('combinations', ''),
    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition

    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...

    # many complex operators have incorrect striding, metadata
    xfail('fft.fft', ''),
    xfail('fft.hfft2', ''),
    xfail('fft.hfft', ''),
    xfail('fft.hfftn', ''),
    xfail('fft.ifft', ''),
    xfail('fft.ihfft2', ''),
    xfail('fft.ihfft', ''),
    xfail('fft.ihfftn', ''),
    xfail('fft.irfft2', ''),
    xfail('fft.irfft', ''),
    xfail('fft.irfftn', ''),
    xfail('fft.rfft2', ''),
    xfail('fft.rfft', ''),
    xfail('fft.rfftn', ''),
    xfail('stft', '')
}
symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm')  # Segfault??
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)

inplace_symbolic_tensor_failures = {
    # bugs
    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
}

out_symbolic_tensor_failures = {
    # Cast error details: Unable to cast (...) to Tensor
    #
    # This happens because the test is set up to call the out variant using the `out` kwarg:
    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
    #
    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
    # way of calling its out variant.
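    #
    # Illustrative sketch of the two conventions (op and argument names below are made up):
    #   torch.some_op(arg1, arg2, out=(out1, out2, out3))             # torch op: single `out=` kwarg
    #   torch.ops.aten.some_op.out(arg1, arg2, out1=..., out2=...)    # aten out overload: one kwarg per output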
    xfail('_batch_norm_with_update', ''),
    xfail('_native_batch_norm_legit', ''),
    xfail('angle', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('fft.fft2', ''),
    xfail('fft.fftn', ''),
    xfail('fft.ifft2', ''),
    xfail('fft.ifftn', ''),
    xfail('gather', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('lu', ''),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),
    xfail('triangular_solve', ''),

    # SymIntArrayRef expected to contain only concrete integers
    xfail('ones', ''),
    xfail('randn', ''),
    xfail('zeros', ''),

    # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
    xfail('index_reduce', 'prod'),
    xfail('index_reduce', 'mean'),
    xfail('index_reduce', 'amax'),
    xfail('index_reduce', 'amin'),
}

out_symbolic_tensor_segfaults = {
    skip('nanmean', ''),
}

out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)

# Copies inputs to inplace operations to avoid inplace modifications
# to leaves requiring gradient
def _get_safe_inplace(inplace_variant):
    @functools.wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)

    return _fn

def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to the first 100 inputs so symbolic tracing tests don't take too long
    count = 100
    if out:
        count = 5
    for sample_input in itertools.islice(sample_inputs_itr, count):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if out:
            expected = fn(*args, **kwargs)
            kwargs['out'] = expected

        try:
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
                                  randomize_data=True)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")


def skipIfNameMatches(pattern):
    """
    Decorator to skip a test if its name matches the given pattern.
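
    Example (illustrative; the pattern and test name are made up):

        @skipIfNameMatches(r"test_flaky_.*")
        def test_flaky_thing(self):
            ...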
2121 """ 2122 def decorator(test_func): 2123 def wrapper(*args, **kwargs): 2124 if re.match(pattern, test_func.__name__): 2125 raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'") 2126 return test_func(*args, **kwargs) 2127 return wrapper 2128 return decorator 2129 2130# Auto functionalize shouldn't work with make_fx directly 2131filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"] 2132 2133@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo") 2134class TestProxyTensorOpInfo(TestCase): 2135 @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) 2136 @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures)) 2137 def test_make_fx_exhaustive(self, device, dtype, op): 2138 _test_make_fx_helper(self, device, dtype, op, "real") 2139 2140 @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) 2141 @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', 2142 make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures)) 2143 def test_make_fx_fake_exhaustive(self, device, dtype, op): 2144 _test_make_fx_helper(self, device, dtype, op, "fake") 2145 2146 @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) 2147 @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', 2148 make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) 2149 def test_make_fx_symbolic_exhaustive(self, device, dtype, op): 2150 _test_make_fx_helper(self, device, dtype, op, "symbolic") 2151 2152 @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) 2153 @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace', 2154 make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures) 2155 def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op): 2156 if not op.get_inplace(): 2157 self.skipTest("No inplace variable for this op") 2158 _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True) 2159 2160 @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) 2161 @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out', 2162 make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures) 2163 def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op): 2164 if not op.supports_out: 2165 self.skipTest("Op doesn't support out") 2166 _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True) 2167 2168 2169only_for = ("cpu") 2170instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for) 2171 2172 2173if __name__ == '__main__': 2174 run_tests() 2175