# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is you should put test cases specifically
# for assume_static_by_default=False, aka you want to YOLO make everything as
# dynamic as possible.  If you want to test the more normal situation where
# you assume static by default, put it in a regular test file and
# test_dynamic_shapes will cover both the YOLO and non-YOLO cases.


@torch._dynamo.config.patch(assume_static_by_default=False)
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]
            np_x = x.numpy()
            np_y = y.numpy()
            return {
                "x": x,
                "z": z,
                "a": np_y.sum(),
                "b": xy,
                "c": np_y[0][0] / 68,
                "d": np_x.sum(),
                "e": np_x + np_y,
            }, x + np_y.sum() + z

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        y = torch.ones([2, 2], dtype=torch.int64)
        z = np.int64(12)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertEqual(res1, res2)

    def test_no_recompilations(self):
        # no recompilations if passing on different numpy int values
        def fn(x, y):
            return {"a": x + 1, "b": y / 2}

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for i in range(10):
            opt_fn(x, np.int64(i))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @unittest.expectedFailure  # array scalars decay to 0D arrays
    def test_builtin_max_min(self):
        # test unspecialized primitive max/min
        def fn(x, y, z):
            return z + 1, max(x, y), min(x - 4, y)

        x = np.int64(12)
        y = 10
        z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertTrue(same(res1, res2, relax_numpy_equality=True))

    def test_feed_random_values_into_graph_only(self):
        def fn(shape):
            torch.manual_seed(123)
            x = torch.randn(shape, device="cpu") * random.randint(30, 100)
            return x

        shape = [2, 3]
        random.seed(1)
        res1 = fn(shape)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(shape)
        self.assertTrue(same(res1, res2))

    def test_random_values_with_graph_break(self):
        def fn(x):
            r1 = random.random()
            y = x + random.uniform(10, 20)
            y.sum().item()
            r2 = random.randint(2, 18)  # no graph output in this frame
            y.sum().item()
            return y + r1, r2

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    # Really annoying intersection of specialization and RandomValueSource
    # If we get a RandomValueSource with a single element tensor, we should return a ConstantVariable like other
    # unspects... but if we do, we break the bytecode assumptions and guards will not work as we will be referring
    # to a name from a source that is not there. If we call .item() and take the wrapped_value out, where we do
    # wrapped_value = wrapped_value.item() where we send unspec down to wrap_fx_proxy, this test passes and then
    # some models fail on missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap as
    # it is, this test fails.
    # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up.
    def test_multiple_consecutive_random_calls_before_graph(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=5)
            dim2 = random.randrange(start=0, stop=5)
            dim3 = random.randrange(start=0, stop=5)
            y = torch.rand(dim1, dim2, dim3)
            return x + 2, y

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_compiled_random_calls_are_random(self):
        # For compiled functions with random calls,
        # it should return different values for every iteration.
        # https://github.com/pytorch/pytorch/issues/95425
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return (x + 1) * random.uniform(0, 1)

        res = []
        for _ in range(5):
            res.append(fn(torch.ones(2)))
        for i in range(1, 5):
            self.assertFalse(same(res[i - 1], res[i]))

    def test_random_call_with_while_loop(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=3)
            dim2 = dim1
            while dim1 == dim2:
                dim2 = random.randrange(start=0, stop=3)
            return x * 2

        x = torch.randn(4)
        random.seed(1)
        res1 = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

        random.seed(10)
        res1 = fn(x)
        random.seed(10)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_random_object(self):
        # test argument passing, mutation, reconstruction, state correctness
        def fn(x, rand2):
            r1 = random.randint(1, 9)
            r2 = rand2.randint(1, 9)
            rand3 = random.Random(42)
            r3 = rand3.randint(1, 9)
            y = x + r1 + r2 + r3
            return y, rand2, rand3

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        random.seed(0)
        y_1, rand2_1, rand3_1 = fn(inp, random.Random(12))
        state_1 = random.getstate()
        random.seed(0)
        y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12))
        state_2 = random.getstate()
        self.assertEqual(y_1, y_2)
        self.assertEqual(state_1, state_2)
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_methods(self):
        def fn(x, rand1, rand2, rand3):
            rand1.seed(42)
            rand4 = random.Random(9002)
            rand2.setstate(rand4.getstate())
            r1 = rand1.random()
            r2 = rand2.randint(1, 10)
            r3 = rand3.randrange(10)
            r4 = rand4.uniform(0, 1)
            return x + r1 + r2 + r3 + r4

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        rand1_1 = random.Random(1)
        rand2_1 = random.Random(2)
        rand3_1 = random.Random(3)
        rand1_2 = random.Random(1)
        rand2_2 = random.Random(2)
        rand3_2 = random.Random(3)
        y1 = fn(inp, rand1_1, rand2_1, rand3_1)
        y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2)
        self.assertEqual(y1, y2)
        self.assertEqual(rand1_1.getstate(), rand1_2.getstate())
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_overriden_methods(self):
        # these will result in graph breaks, but we shouldn't crash
        def get_rng():
            rand1 = random.Random(1)
            rand2 = random.Random(2)

            orig_random = rand1.random

            def custom_random():
                return orig_random()

            orig_getstate = rand2.getstate

            def custom_getstate():
                return orig_getstate()

            rand1.random = custom_random
            rand2.getstate = custom_getstate
            return rand1, rand2

        def fn(x, rand1, rand2):
            r1 = rand1.random()
            rand3 = random.Random()
            rand3.setstate(rand2.getstate())
            r2 = rand3.random()
            return x + r1 + r2

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager")
        y1 = fn(inp, *get_rng())
        y2 = opt_fn(inp, *get_rng())
        self.assertEqual(y1, y2)

    def test_builtin_getitem(self):
        # builtin getitem args[0] is python list and args[1] is unspec
        def fn(x, idx):
            return (torch.zeros(idx), x[idx], x[idx:])

        x = list(range(50))
        ref = fn(x, 48)  # 48 is unspecialized
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, 48)
        self.assertTrue(same(ref, res))

    def test_use_and_specialize(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            x = x + y
            if y == 2:
                return x - 1
            else:
                return x + 1

        self.assertTrue(same(fn(torch.tensor([5]), 2), 6))
        self.assertTrue(same(fn(torch.tensor([6]), 2), 7))
        self.assertTrue(same(fn(torch.tensor([5]), 3), 9))
        self.assertTrue(same(fn(torch.tensor([4]), 3), 8))
        self.assertEqual(cnt.frame_count, 2)

    def test_no_recompiles(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            return x + y

        self.assertTrue(same(fn(torch.tensor([5]), 100), 105))
        self.assertTrue(same(fn(torch.tensor([4]), 200), 204))
        self.assertTrue(same(fn(torch.tensor([3]), 300), 303))
        self.assertTrue(same(fn(torch.tensor([2]), 400), 402))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_no_recompiles_prod_backward(self):
        # https://github.com/pytorch/pytorch/issues/120608
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(t):
            return torch.prod(t, 3, keepdim=True)

        input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)]
        for s in input_shapes:
            t1 = torch.randn(s, requires_grad=True)
            h_result = fn(t1)
            grad = torch.ones_like(h_result)
            h_result.backward(grad)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_builtin_functions_on_cuda(self):
        def fn(x, scaler):
            m = torch.nn.ReLU()
            y = m(x) * scaler
            return y

        x = torch.randn([3, 6], device="cuda")
        scaler = 0.23  # 0.23 is unspecialized
        ref = fn(x, scaler)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scaler)
        self.assertTrue(same(ref, res))
        self.assertEqual(ref.device, res.device)

    def test_unspec_float_precision(self):
        def fn(image, scale_factor):
            image = torch.nn.functional.interpolate(
                image[None],
                size=None,
                scale_factor=scale_factor,
                mode="bilinear",
                recompute_scale_factor=True,
                align_corners=False,
            )[0]
            return image.shape

        x = torch.rand([3, 427, 640])
        scale_factor = 1.873536229133606
        ref = fn(x, scale_factor)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scale_factor)
        self.assertTrue(same(ref, res))

    @unittest.expectedFailure  # fails as long as numpy scalars are 0D arrays
    def test_specializing_numpy_float_in_control_flow(self):
        # np.float64 is unspecialized by default,
        # but it should be specialized when used in control flow.
        def fn(x, y):
            if y > 1.0:
                return x + 1
            else:
                return x - 1

        x = torch.rand(4)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        for t in [np.float16, np.float32, np.float64]:
            y = t(1.23)
            ref = fn(x, y)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, res))

    def test_mark_static_inside(self):
        def fn(x):
            torch._dynamo.mark_static(x, 0)
            comptime.assert_static(x.size(0))
            return x + 1

        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True)
        opt_fn(torch.randn(12, 23))

    def test_shape_graph_break(self):
        from torch._dynamo.comptime import comptime

        def fn(x):
            x_shape = x.size()
            comptime.graph_break()
            return x + torch.randn(x_shape)

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    def test_isinstance_symint(self):
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark dynamic will try its best but 0/1
        # specialization is allowed)
        opt_fn(x)

    def test_conv1d_symint_padding(self):
        kernel = torch.randn(1, 1, 4)

        def func(x):
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # passes
        x = torch.randn(1, 1, 249)
        opt_func(x)  # crashes

    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    def test_sym_int_conversion(self):
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    def test_sum_dimlist_spec(self):
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    def test_symbol_guard_limit_before_specialize(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))
        self.assertExpectedInline(cnts.frame_count, """1""")

        cnts.frame_count = 0
        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """3""")

    def test_defaults(self):
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled,
        # have to write the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    def test_prune_torch_check(self):
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
        return ()""",
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    def test_bool_tensor_ctor(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        self.assertTrue(f(torch.empty(8)).item())
        self.assertFalse(f(torch.empty(13)).item())

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)
        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)
        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)
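
    # Editorial sketch (not part of the upstream suite) illustrating the
    # module-level note at the top of this file: with the class-wide
    # assume_static_by_default=False patch, tensor sizes are treated as dynamic
    # on the very first compile, so inputs of different lengths should reuse a
    # single compiled frame. The test name and the expectation of exactly one
    # frame are assumptions of this sketch, mirroring test_no_recompilations.
    def test_dynamic_by_default_sketch(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts)
        def fn(x):
            return x * 2

        # Sizes avoid 0/1, which are always specialized.
        for n in (4, 8, 16):
            fn(torch.randn(n))
        self.assertEqual(cnts.frame_count, 1)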


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()