""" PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_functionalization_with_native_python_assertion) """ # Owner(s): ["oncall: export"] import math import operator import unittest from re import escape from typing import List, Set import torch from functorch.experimental.control_flow import cond from torch._dynamo.eval_frame import is_dynamo_supported from torch._export.non_strict_utils import ( _fakify_script_objects, _gather_constant_attrs, ) from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch._export.passes.replace_set_grad_with_hop_pass import ( _is_set_grad_enabled_node, _is_set_grad_enabled_sub_mod, ) from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import ( get_view_copy_of_view_op, is_view_op, ReplaceViewOpsWithViewCopyOpsPass, ) from torch._export.utils import ( node_inline_, nodes_count, nodes_filter, nodes_map, sequential_split, ) from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import export from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) from torch.export._remove_effect_tokens_pass import _remove_effect_tokens from torch.export.passes import move_to_device_pass from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupport from torch.library import _scoped_library, impl from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, skipIfTorchDynamo, TestCase, ) from torch.testing._internal.torchbind_impls import init_torchbind_implementations from torch.utils import _pytree as pytree def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int: count = 0 for node in graph.nodes: if node.op == "call_function" and node.target == target: count += 1 return count class _AddOperatorSupport(OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: return node.op == "call_function" and node.target in {operator.add} class _AtenAddOperatorSupport(OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor} def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]: return [{n.name for n in p.nodes} for p in partitions] def _get_output_names(gm: torch.fx.GraphModule) -> List[str]: output_node = next(n for n in gm.graph.nodes if n.op == "output") args = pytree.tree_leaves(output_node.args) # if isinstance(args, tuple) and len(args) == 1: # args = args[0] return [str(arg) for arg in args] class ModelsWithScriptObjectAttr: class Simple(torch.nn.Module): def __init__(self) -> None: super().__init__() self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) class SimpleWithAttrInContainer(torch.nn.Module): def __init__(self) -> None: super().__init__() self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) self.pytree_attr2 = [ torch.classes._TorchScriptTesting._Foo(1, 2), { torch.classes._TorchScriptTesting._Foo(3, 4), }, {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)}, ] class NestedWithAttrInContainer(torch.nn.Module): def __init__(self) -> None: super().__init__() self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) self.pytree_attr2 = [ torch.classes._TorchScriptTesting._Foo(1, 2), 
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()

    class MoreNestedWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.NestedWithAttrInContainer()


def _set_grad_enabled_tests():
    from torch.export._trace import _export

    class SetGradOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            torch._C._set_grad_enabled(True)
            c = x.sin().sum()
            torch._C._set_grad_enabled(False)
            d = c + 1
            torch._C._set_grad_enabled(True)
            e = d - 1
            return d, e

    class SetGradCtxManager(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c = x.sin().sum()
            with torch.no_grad():
                d = c + 1
            with torch.enable_grad():
                e = d - 1
            return d, e

    class SetGradCtxManagerMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.no_grad():
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.enable_grad():
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args, ambient_grad_enabled=True):
        with torch.set_grad_enabled(ambient_grad_enabled):
            return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,)),
            (x,),
        ),
        "ctx_manager_under_no_grad": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,), False),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep_no_grad": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False),
            (x,),
        ),
        "op": (SetGradOp(), _get_predispatch_module(SetGradOp(), (x,)), (x,)),
        "op_under_no_grad": (
            SetGradOp(),
            _get_predispatch_module(SetGradOp(), (x,), False),
            (x,),
        ),
    }


def _with_autocast_tests():
    from torch.export._trace import _export

    class WithAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    class WithAutocastOpMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    class SplitAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args):
        return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            WithAutocastOp(),
            _get_predispatch_module(WithAutocastOp(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            WithAutocastOpMultiDep(),
            _get_predispatch_module(WithAutocastOpMultiDep(), (x,)),
(x,), ), "ctx_manager_split": ( SplitAutocastOp(), _get_predispatch_module(SplitAutocastOp(), (x,)), (x,), ), } def _sequential_split_inline_tests(): from torch.export._trace import _export class Simple(torch.nn.Module): def forward(self, x): x = x + 1 c = x.sin().sum() d = c + 1 e = d - 1 return d, e class MultiDep(torch.nn.Module): def forward(self, x1, x2): x1 = x1 + 1 x2 = x2 + 1 c1 = x1.sin() c2 = x2.cos() d1 = c1 + 1 d2 = c2 + 1 e1 = d1 - 1 e2 = d2 - 1 return d1, d2, e1, e2 def _get_predispatch_module(mod, args): return _export(mod, args, pre_dispatch=True).module() def _insert_dilimiter_nodes(gm: torch.fx.GraphModule, step: int = 1): insert_locs = [] for i, node in enumerate( nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function") ): if i % step == 0: insert_locs.append(node) for i, node in enumerate(insert_locs): with gm.graph.inserting_before(node): gm.graph.call_function( torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {} ) return gm x = torch.randn(2, 2) simple = _get_predispatch_module(Simple(), (x,)) simple1 = _get_predispatch_module(Simple(), (x,)) multi_dep = _get_predispatch_module(MultiDep(), (x, x.sin())) multi_dep1 = _get_predispatch_module(MultiDep(), (x, x.sin())) return { "simple_step1": (_insert_dilimiter_nodes(simple1, 1), (x,)), "simple_step2": (_insert_dilimiter_nodes(simple, 2), (x,)), "multi_dep_step2": (_insert_dilimiter_nodes(multi_dep, 2), (x, x.sin())), "multi_dep_step3": (_insert_dilimiter_nodes(multi_dep1, 3), (x, x.sin())), } @skipIfTorchDynamo("recursively running dynamo on export is unlikely") @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") class TestPasses(TestCase): def setUp(self): super().setUp() self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests() self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests() self.WITH_AUTOCAST_TESTS = _with_autocast_tests() init_torchbind_implementations() def tearDown(self): self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear() self.SET_GRAD_ENABLED_TESTS.clear() self.WITH_AUTOCAST_TESTS.clear() super().tearDown() def test_runtime_assert_one_dim(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): return x.cos() x = torch.zeros(2, 2, 3) dim1_x = torch.export.Dim("dim1_x", min=2, max=6) ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}}) with self.assertRaisesRegex( RuntimeError, escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), ): ep.module()(torch.zeros(2, 7, 3)) self.assertEqual( ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)) ) def test_runtime_assert_multiple_dims(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, y): return x.cos().sum() + y.sin().sum() x = torch.zeros(4, 2, 3) y = torch.zeros(5, 5, 5) dim1_x = torch.export.Dim("dim1_x", min=2, max=6) dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3) ep = torch.export.export( M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}} ) with self.assertRaisesRegex( RuntimeError, escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), ): ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) with self.assertRaisesRegex( RuntimeError, escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"), ): ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, 
            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)
        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x = torch.export.Dim("dim0_x", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for
        # case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.ones(3, 1, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_runtime_assert_some_inps_not_used(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return y.cos().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)
        dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
        )

        with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for
        # case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_view_to_view_copy(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                z = x.view(x.shape)
                return z.cos().sum()

        x = torch.zeros(4, 2, 3)

        ep = export(M(), (x,))
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)

        ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)

    def test_functionalization_with_view_copy(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x + 4
                y.add_(4)
                z = y.view(y.shape)
                return x.cos() + z.cos()

        x = torch.zeros(4, 2, 3)
        foo = Module()
        ep = export(foo, (x,))._transform_do_not_use(
            ReplaceViewOpsWithViewCopyOpsPass()
        )
        # After this pass, there shouldn't be any view nodes in the graph
        self.assertTrue(
            count_call_function(ep.graph, torch.ops.aten.view.default) == 0
        )
        self.assertTrue(
            count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0
        )

    def test_views_op_having_view_copy(self) -> None:
        schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
        aten_schemas = [s[6:] for s in schemas if s.startswith("aten::")]

        for aten_schema in aten_schemas:
            val = aten_schema.split(".")
            assert len(val) <= 2
            name = ""
            overload = ""
            if len(val) == 1:
                name = val[0]
                overload = "default"
            else:
                name, overload = val[0], val[1]

            op_overload = getattr(getattr(torch.ops.aten, name), overload)
            if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
                self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)

        inp = torch.randn(2, 3)
        orig_res = m(inp)
        ep_res = ep.module()(inp)

        without_token_ep = _remove_effect_tokens(ep)
        without_token_ep.verifier().check(without_token_ep)
        without_token_res = without_token_ep.module()(inp)

        self.assertTrue(torch.allclose(orig_res, ep_res))
        self.assertTrue(torch.allclose(orig_res, without_token_res))

    def test_remove_effect_token_kwargs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(
                    foo=self.attr, x=x
                )
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(foo=self.attr, x=y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)
        without_token_ep = _remove_effect_tokens(ep)
        self.assertExpectedInline(
            without_token_ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2]; with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1]; with_effects_1 = None
    return (getitem_3, getitem_4)""",  # noqa: B950
        )

    def test_fakify_script_objects(self):
        for m in [
            ModelsWithScriptObjectAttr.Simple(),
            ModelsWithScriptObjectAttr.SimpleWithAttrInContainer(),
            ModelsWithScriptObjectAttr.NestedWithAttrInContainer(),
            ModelsWithScriptObjectAttr.MoreNestedWithAttrInContainer(),
        ]:
            constant_attrs = _gather_constant_attrs(m)
            fake_mode = FakeTensorMode(
                shape_env=ShapeEnv(tracked_fakes=[]),
                allow_non_fake_inputs=True,
            )
            with _fakify_script_objects(m, (), {}, fake_mode) as (
                patched_mod,
                _,
                _,
                fake_constant_attrs,
                fake_to_real,
            ):
                self.assertEqual(len(fake_constant_attrs), len(constant_attrs))
                for fake_obj, fqn in fake_constant_attrs.items():
                    self.assertEqual(constant_attrs[fake_to_real[fake_obj]], fqn)

    # TODO: _gather_constants doesn't recursively look into the pytree containers.
    @unittest.expectedFailure
    def test_fakify_script_objects_properly_handle_containers(self):
        m = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()
        constant_attrs = _gather_constant_attrs(m)
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(tracked_fakes=[]),
            allow_non_fake_inputs=True,
        )
        with _fakify_script_objects(m, (), {}, fake_mode) as (
            patched_mod,
            _,
            _,
            fake_constant_attrs,
            fake_to_real,
        ):
            self.assertTrue("attr" in fake_constant_attrs.values())
            self.assertTrue("pytree_attr2" in fake_constant_attrs.values())

    def test_runtime_assert_inline_constraints_for_item(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.item()
                torch._check(b >= 2)
                torch._check(b <= 5)
                return b

        x = torch.tensor([2])
        mod = M()
        ep = export(mod, (x,))

        with self.assertRaisesRegex(
            RuntimeError, r"Runtime assertion failed for expression u[\d+] \<\= 5"
        ):
            ep.module()(torch.tensor([6]))

        new_inp = torch.tensor([5])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.nonzero()
                torch._check(b.shape[0] >= 3)
                torch._check(b.shape[0] <= 5)
                return b

        x = torch.tensor([2, 1, 2, 3, 5, 0])
        mod = M()
        dim0_x = torch.export.Dim("dim0_x")
        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})

        num_assert = count_call_function(
            ep.graph, torch.ops.aten._assert_scalar.default
        )
        self.assertEqual(num_assert, 2)
        num_constrain_range = count_call_function(
            ep.graph, torch.ops.aten.sym_constrain_range.default
        )
        self.assertEqual(num_constrain_range, 0)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u[\d+] \>\= 3",
        ):
            ep.module()(torch.tensor([1, 1, 0, 0, 0]))

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u[\d+] \<\= 5",
        ):
            ep.module()(torch.ones(6))

        new_inp = torch.tensor([1, 1, 1, 1])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported")
    @unittest.expectedFailure  # TODO(pianpwk): add back runtime asserts to subgraphs
    def test_runtime_assert_inline_constraints_for_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x, y):
                def true_fn(x, y):
                    b = x.item()
                    torch._check(b >= 2)
                    torch._check(b <= 5)
                    return x - b

                def false_fn(x, y):
                    c = y.item()
                    torch._check(c >= 2)
                    torch._check(c <= 5)
                    return y - c

                ret = cond(pred, true_fn, false_fn, [x, y])
                return ret

        x = torch.tensor([2])
        y = torch.tensor([5])
        mod = M()
        ep = export(mod, (torch.tensor(True), x, y))

        with self.assertRaisesRegex(
            RuntimeError, "is outside of inline constraint \\[2, 5\\]."
        ):
            ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

    def test_math_ops(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return (
                    torch.tensor([math.ceil(x.item())]),
                    torch.tensor([math.floor(x.item())]),
                )

        func = Module()
        x = torch.randn(1, dtype=torch.float32)
        ep = torch.export.export(func, args=(x,))
        _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)

    def test_predispatch_set_grad(self):
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    submod_4 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None
    getitem = add_1[0]; add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
""",
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    submod_4 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None
    getitem = add_1[0]; add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
""",
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    submod_3 = self.submod_1
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1); submod_3 = sum_1 = None
    getitem = add_1[0]; add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
""",
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_5 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
    getitem = sum_1[0]; sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1); submod_6 = None
    getitem_1 = sub[0]; sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
""",
        )
self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep"] _check_node_users_in_the_same_graph(mod) self.assertEqual(mod_orig(*args), mod(*args)) self.assertExpectedInline( mod.code.strip("\n"), """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.aten.sin.default(add) sum_1 = torch.ops.aten.sum.default(sin); sin = None cos = torch.ops.aten.cos.default(add); add = None sum_2 = torch.ops.aten.sum.default(cos); cos = None submod_3 = self.submod_1 wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1, sum_2); submod_3 = sum_1 = sum_2 = None add_1 = wrap_with_set_grad_enabled[0] add_2 = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None sub = torch.ops.aten.sub.Tensor(add_1, 1) sub_1 = torch.ops.aten.sub.Tensor(add_2, 1) return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) """, # noqa: B950 ) mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS[ "ctx_manager_multi_dep_no_grad" ] _check_node_users_in_the_same_graph(mod) self.assertEqual(mod_orig(*args), mod(*args)) self.assertExpectedInline( mod.code.strip("\n"), """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_5 = self.submod_1 wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None sum_1 = wrap_with_set_grad_enabled[0] sum_2 = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None add_2 = torch.ops.aten.add.Tensor(sum_2, 1); sum_2 = None submod_6 = self.submod_3 wrap_with_set_grad_enabled_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1, add_2); submod_6 = None sub = wrap_with_set_grad_enabled_1[0] sub_1 = wrap_with_set_grad_enabled_1[1]; wrap_with_set_grad_enabled_1 = None return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) """, # noqa: B950 ) def test_sequential_split(self): for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values(): set_grad_counts = nodes_count(gm.graph.nodes, _is_set_grad_enabled_node) new_gm = sequential_split(gm, _is_set_grad_enabled_node) new_set_grad_counts = nodes_count( new_gm.graph.nodes, _is_set_grad_enabled_sub_mod ) self.assertEqual(set_grad_counts, new_set_grad_counts) self.assertEqual(gm(*args), new_gm(*args)) def test_sequential_split_graph(self): gm, args = self.SEQUENTIAL_SPLIT_INLINE_TESTS["multi_dep_step2"] new_gm = sequential_split(gm, _is_set_grad_enabled_node) self.assertEqual(gm(*args), new_gm(*args)) self.assertExpectedInline( new_gm.code.strip("\n"), """\ def forward(self, x1, x2): x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec) submod_1 = self.submod_1(x1, x2); x1 = x2 = None getitem = submod_1[0] getitem_1 = submod_1[1]; submod_1 = None submod_2 = self.submod_2(getitem, getitem_1); getitem = getitem_1 = None getitem_2 = submod_2[0] getitem_3 = submod_2[1]; submod_2 = None submod_3 = self.submod_3(getitem_2, getitem_3); getitem_2 = getitem_3 = None getitem_4 = submod_3[0] getitem_5 = submod_3[1]; submod_3 = None submod_4 = self.submod_4(getitem_4, getitem_5) getitem_6 = submod_4[0] getitem_7 = submod_4[1]; submod_4 = None return pytree.tree_unflatten((getitem_4, getitem_5, getitem_6, getitem_7), self._out_spec) """, ) self.assertExpectedInline( new_gm.submod_1.code.strip("\n"), """\ def forward(self, x1, x2): _set_grad_enabled = 
    _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
    add = torch.ops.aten.add.Tensor(x1, 1); x1 = None
    add_1 = torch.ops.aten.add.Tensor(x2, 1); x2 = None
    return (add, add_1)
""",
        )
        self.assertExpectedInline(
            new_gm.submod_2.code.strip("\n"),
            """\
def forward(self, add, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    cos = torch.ops.aten.cos.default(add_1); add_1 = None
    return (sin, cos)
""",
        )
        self.assertExpectedInline(
            new_gm.submod_3.code.strip("\n"),
            """\
def forward(self, sin, cos):
    _set_grad_enabled_2 = torch._C._set_grad_enabled(True); _set_grad_enabled_2 = None
    add_2 = torch.ops.aten.add.Tensor(sin, 1); sin = None
    add_3 = torch.ops.aten.add.Tensor(cos, 1); cos = None
    return (add_2, add_3)
""",
        )

    def test_predispatch_autocast(self):
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    getitem = sum_1[0]; sum_1 = None
    submod_5 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, getitem); submod_5 = getitem = None
    getitem_1 = add_1[0]; add_1 = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, getitem_1); submod_6 = None
    getitem_2 = sub[0]; sub = None
    return pytree.tree_unflatten((getitem_1, getitem_2), self._out_spec)
""",
        )
        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return (sum_1,)
""",
        )
        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return (add_1,)
""",
        )
        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return (sub,)
""",
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    sum_1 = wrap_with_autocast[0]
    sum_2 = wrap_with_autocast[1]; wrap_with_autocast = None
    submod_5 = self.submod_2
    wrap_with_autocast_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, sum_1, sum_2); submod_5 = sum_1 = sum_2 = None
    add_1 = wrap_with_autocast_1[0]
    add_2 = wrap_with_autocast_1[1]; wrap_with_autocast_1 = None
    submod_6 = self.submod_3
    wrap_with_autocast_2 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, add_1, add_2); submod_6 = None
    sub = wrap_with_autocast_2[0]
    sub_1 = wrap_with_autocast_2[1]; wrap_with_autocast_2 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    cos = torch.ops.aten.cos.default(add); add = None
    sum_2 = torch.ops.aten.sum.default(cos); cos = None
    return (sum_1, sum_2)
""",
        )
        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1, sum_2):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1); sum_2 = None
    return (add_1, add_2)
""",
        )
        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1, add_2):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1); add_2 = None
    return (sub, sub_1)
""",
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_split"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    getitem = sum_1[0]; sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
    submod_5 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_5, add_1); submod_5 = None
    getitem_1 = sub[0]; sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
""",
        )
        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return (sum_1,)
""",
        )
        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return (sub,)
""",
        )

    def test_inline_(self):
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            before_str = gm.print_readable(print_output=False)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            nodes_map(
                new_gm.graph.nodes,
                lambda node: node_inline_(node) if node.op == "call_module" else node,
            )
            after_inline_str = new_gm.print_readable(print_output=False)
            self.assertEqual(before_str, after_inline_str)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_remove_auto_functionalized_pass(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
y) -> Tensor") @impl(lib, "custom_mutator", "Meta") def custom_mutator_meta( x: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: return torch.empty_like(x) @impl(lib, "custom_mutator", "CompositeExplicitAutograd") def custom_mutator( x: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: return x + y.add_(1) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.state = torch.nn.Buffer(torch.zeros(1)) def forward(self, x): return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state) mod = M() x = torch.randn([3, 3]) ep = export(mod, (x,)) inplace_ep = unsafe_remove_auto_functionalized_pass(ep) nodes = inplace_ep.graph.nodes for node in nodes: if node.op == "call_function": self.assertFalse(node.target is auto_functionalized) self.assertFalse(node.target is operator.getitem) for spec in inplace_ep.graph_signature.output_specs: self.assertFalse("getitem" in spec.arg.name) def test_remove_auto_functionalized_pass_tuple(self) -> None: with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib: lib.define( "custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)" ) @impl(lib, "custom_mutator_tuple", "Meta") def custom_mutator_tuple_meta( x: torch.Tensor, y: torch.Tensor, ): return (torch.empty_like(x), torch.empty_like(x)) @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd") def custom_mutator_tuple( x: torch.Tensor, y: torch.Tensor, ): return (x, x + y.add_(1)) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.state = torch.nn.Buffer(torch.zeros(1)) def forward(self, x): return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple( x, self.state ) mod = M() x = torch.randn([3, 3]) ep = export(mod, (x,)) inplace_ep = unsafe_remove_auto_functionalized_pass(ep) graph_text = str(inplace_ep.graph) self.assertExpectedInline( graph_text, """\ graph(): %b_state : [num_users=2] = placeholder[target=b_state] %x : [num_users=1] = placeholder[target=x] %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\ default](args = (%x, %b_state), kwargs = {}) %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {}) %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {}) return (b_state, getitem_3, getitem_4)""", ) @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_to_device_pass(self): class Model(torch.nn.Module): def __init__(self, size=4, h_dim=10): super().__init__() self.rnn = torch.nn.GRU(size, h_dim, batch_first=True) def forward(self, x): _, states = self.rnn(x) return states # move the exported program from cpu to cuda:0 mod = Model() example_inputs = (torch.rand(1, 10, 4),) ep = export(mod, example_inputs) location = torch.device("cuda:0") ep = move_to_device_pass(ep, location=location) gm = ep.module() test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),) outputs = gm(*test_inputs) self.assertEqual(outputs.device, torch.device("cuda:0")) # move it back to cpu location = "cpu" ep = move_to_device_pass(ep, location=location) gm = ep.module() test_inputs = (torch.rand(1, 10, 4).to("cpu"),) outputs = gm(*test_inputs) self.assertEqual(outputs.device, torch.device("cpu")) # move it to cuda:0 again location = {"cpu": "cuda:0"} ep = move_to_device_pass(ep, location=location) gm = ep.module() test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),) outputs = gm(*test_inputs) self.assertEqual(outputs.device, torch.device("cuda:0")) if 
__name__ == "__main__": run_tests()