# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import unittest
from typing import Dict, List, Tuple

import executorch.exir as exir
import executorch.exir.tests.models as models

import torch

from executorch.exir import CaptureConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo
from functorch.experimental.control_flow import cond, map

from parameterized import parameterized
from torch._export.verifier import SpecViolationError
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.testing import FileCheck


class TestTorchDispatchFXTracer(unittest.TestCase):
    """Tests for ``exir.capture`` / ``to_edge`` tracing behavior.

    Each test captures a small function or module and inspects the resulting
    graph module (its generated code, node metadata, or numerical output).
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Register the extra ATen ops used by these tests once per class."""
        register_additional_test_aten_ops()

    def test_simple(self) -> None:
        """The edge graph of ``BasicSinMax`` contains the edge-dialect sin op."""
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code)

    def test_static_control_flow(self) -> None:
        """A Python ``bool`` predicate selects the traced branch at capture time.

        With ``pred=True`` the generated code contains sin followed by max;
        with ``pred=False`` it contains sin and no max after it.
        """

        def f(pred: bool, x: torch.Tensor) -> torch.Tensor:
            if pred:
                return torch.sin(x).max()
            else:
                return torch.sin(x)

        pred = True
        x = torch.randn(100)
        f_true = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_true.code)

        pred = False
        f_false = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_false.code)

    def test_copy(self) -> None:
        """``copy.deepcopy`` of an edge graph module is still a ``GraphModule``."""
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(f, torch.fx.GraphModule)
        g = copy.deepcopy(f)
        self.assertIsInstance(g, torch.fx.GraphModule)

    def test_stacktrace(self) -> None:
        """At least one node of the edge graph carries a ``stack_trace`` entry."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return x + x

        traced_f = (
            exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        # Check that stacktrace is populated and retained (by checking twice)
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )

    def test_ones(self) -> None:
        """``exir.to_edge`` accepts a program whose ``ones`` size is a dynamic dim.

        There are no assertions; the test passes if export and lowering do not
        raise.
        """

        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(x.shape[0])
                return x + y

        ep = torch.export.export(
            M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
        )
        exir.to_edge(ep)

    def test_possible_input_mutation(self) -> None:
        """Capturing an ``out=`` call and lowering to edge raises a spec violation."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.add(torch.ones(5), torch.ones(5), out=x)

        with self.assertRaisesRegex(
            SpecViolationError,
            r"operator .* is not functional",
        ):
            exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge()

    def test_tensor_spec_for_const_tensors(self) -> None:
        """Both ``get_attr`` nodes of a captured ``Linear`` have a ``val`` meta entry."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

            def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
                return (torch.randn(2),)

        model = Module()
        graph_module = (
            exir.capture(model, model.get_random_inputs(), exir.CaptureConfig())
            # torch._ops.aten.t.default
            .to_edge(
                exir.EdgeCompileConfig(_check_ir_validity=False)
            ).exported_program.graph_module
        )
        num_get_attr_node = 0
        num_get_attr_node_with_tensorspec = 0
        for nd in graph_module.graph.nodes:
            if nd.op == "get_attr":
                num_get_attr_node += 1
                if nd.meta.get("val") is not None:
                    num_get_attr_node_with_tensorspec += 1

        self.assertEqual(2, num_get_attr_node)
        self.assertEqual(2, num_get_attr_node_with_tensorspec)

    def test_multiple_returns_spec(self) -> None:
        """The single ``max.dim`` edge node stores a tuple in ``meta["val"]``."""

        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.ops.aten.max.dim(x, 0, False)

        cnt = 0
        module = (
            exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        for node in module.graph.nodes:
            if node.target == exir_ops.edge.aten.max.dim:
                cnt += 1
                self.assertIsInstance(node.meta["val"], tuple)
        self.assertEqual(cnt, 1)

    def test_multiple_returns_pt2_mode(self) -> None:
        """A two-output function keeps both outputs and matches eager results.

        The output node's ``meta["val"]`` must be a list of length 2, and both
        results of the graph module must be close to the eager results.
        """

        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            a = x * x
            b = x + a
            return a, b

        inputs = (torch.ones(1, 2, 3),)
        orig_res = f(*inputs)
        module = (
            exir.capture(
                f,
                inputs,
                exir.CaptureConfig(),
            )
            .to_edge()
            .exported_program.graph_module
        )
        new_res = module(*inputs)
        for node in module.graph.nodes:
            if node.op == "output":
                self.assertIsInstance(node.meta["val"], list)
                self.assertEqual(len(node.meta["val"]), 2)

        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_dynamo_capture_scalar_outputs(self) -> None:
        """``dynamo_trace`` of a function returning ``x.item()`` does not raise.

        There are no assertions on the returned graph module or guards.
        """

        def f(x: torch.Tensor) -> float:
            return x.item()

        gm, guards = dynamo_trace(
            f,
            (torch.ones(1),),
            False,
            "real",
            ExirDynamoConfig(),
        )

    # pyre-ignore
    @parameterized.expand([("stock_tensor",)])
    def test_embedding_dynamic_shape(self, input_type: str) -> None:
        """Capture ``x + x`` on an int64 input with dynamic shapes enabled.

        ``input_type`` comes from the parameterization and is not read by the
        body. There are no assertions; the graph is printed after lowering.
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x

        example_input = torch.ones(10, dtype=torch.int64)
        m = Module()
        gm = (
            exir.capture(
                m,
                (example_input,),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                ),
            )
            .to_edge()
            .exported_program.graph_module
        )

        print(gm.graph)

    def test_dynamic_shape(self) -> None:
        """Every placeholder and call_function node has ``val`` in its meta."""

        def forward(x: torch.Tensor) -> torch.Tensor:
            x = x.view(x.shape[0] - 1, -1)
            return torch.cat([x, x])

        gm = (
            exir.capture(
                forward,
                (torch.ones(3, 2, dtype=torch.int64),),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                    _dynamo_config=ExirDynamoConfig(assume_static_by_default=True),
                ),
                # sym_size is not reg op
            )
            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
            .exported_program.graph_module
        )

        for node in gm.graph.nodes:
            if node.op in ("placeholder", "call_function"):
                self.assertIn("val", node.meta)

    def test_dynamo_frontend_container_input(self) -> None:
        """Under ``using_dynamo(True)``, a nested-tuple input matches eager output."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ) -> torch.Tensor:
                a = x[0]
                b = x[1]
                cum = 0
                for i in b:
                    cum += i.sum()
                return a.cos() + cum.sin()

        with using_dynamo(True):
            inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),)
            gm = exir.capture(Module(), inp, exir.CaptureConfig())
            self.assertTrue(torch.allclose(Module()(*inp), gm(*inp)))

    # TODO (tmanlaibaatar) remove this test
    def test_pt2_mode_with_dynamo_config(self) -> None:
        """Slicing off the last row of a 4x5 input yields 3 rows after capture."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return x[: x.shape[0] - 1]

        inp = (torch.randn(4, 5),)
        prog = exir.capture(
            f,
            inp,
            # missing dispatch key
        ).to_edge()
        # assertEqual, not assertTrue: assertTrue(value, 3) treats 3 as the
        # failure message and passes for any non-zero leading dimension.
        self.assertEqual(prog(torch.randn(4, 5)).shape[0], 3)

    def test_input_container_type(self) -> None:
        """A list-typed input and dict-typed output survive capture + to_edge."""

        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            # pyre-ignore
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

        # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule,
        #  Set[torch._guards.Guard]]` into 2 values.
        gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic")
        prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge()

        self.assertEqual(prog(*inp), f(*inp))

    def test_aot_buffer_mutation(self) -> None:
        """AOT capture of a buffer-mutating module.

        First, capturing with an input that requires grad must raise
        ``RuntimeError``. Then a fresh module is captured with plain inputs and
        both of its outputs must be close to those of a fresh eager module.
        """

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "_bin_num_examples",
                    torch.empty([42]).fill_(
                        0.0,
                    ),
                )

            def forward(self, x, y, z):
                self._bin_num_examples.index_copy_(
                    dim=0,
                    index=y,
                    source=z,
                )
                self._bin_num_examples.index_add_(
                    dim=0, index=torch.arange(4), source=x
                )
                return self._bin_num_examples - 1, x * z

        model = Module()
        example_inputs = (
            torch.randn(4, requires_grad=True),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Found a graph input that requires gradients, and received a mutation.",
        ):
            _ = exir.capture(
                model,
                example_inputs,
                exir.CaptureConfig(
                    enable_aot=True,
                ),
            )

        # Note that model._bin_num_examples is mutated during exir.capture
        # We need to create a new_model
        new_model = Module()
        example_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        ep = exir.capture(
            new_model,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        )

        test_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(2.1),
        )
        graph_outputs = ep(*test_inputs)
        eager_outputs = Module()(*test_inputs)
        self.assertEqual(len(graph_outputs), 2)
        self.assertEqual(len(eager_outputs), 2)
        self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))
        self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1]))

    def test_assume_constant_by_default_prop(self) -> None:
        """With ``assume_static_by_default``, some node has a concrete int dim."""

        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 3:
                return x.cos()
            return x.sin()

        dynamo_config = ExirDynamoConfig(assume_static_by_default=True)
        capture_config = exir.CaptureConfig(
            enable_dynamic_shape=True, _dynamo_config=dynamo_config
        )
        captured = exir.capture(
            foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config
        ).exported_program.graph_module
        found = False
        for node in captured.graph.nodes:
            # at least one input needs to have concrete dims
            if "val" in node.meta:
                fake_val = node.meta["val"]
                for dim in fake_val.shape:
                    if is_concrete_int(dim):
                        found = True

        self.assertTrue(found)

    def test_aot_config(self) -> None:
        """With ``enable_aot=True`` the buffer is a placeholder, not ``get_attr``.

        The graph must contain no ``get_attr`` node and exactly two
        placeholders, and the first input node of each operand of the
        ``add.Tensor`` call must together equal the set of placeholders seen so
        far. The captured program is finally lowered with ``to_edge``.
        """

        class FooWithBuffer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.zeros(42))

            def forward(self, x):
                return x.cos() + self.buffer.sum()

        capture_config = exir.CaptureConfig(enable_aot=True)
        captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
        captured_gm = captured_ep.exported_program.graph_module

        placeholder_nodes = set()
        print(captured_gm.graph)
        for node in captured_gm.graph.nodes:
            self.assertNotEqual(node.op, "get_attr")
            if node.op == "placeholder":
                placeholder_nodes.add(node)
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                # make sure the placeholders are used
                arg_0, arg_1 = node.args
                self.assertEqual(
                    placeholder_nodes,
                    {
                        next(iter(arg_0._input_nodes)),
                        next(iter(arg_1._input_nodes)),
                    },
                )

        self.assertEqual(len(placeholder_nodes), 2)
        captured_ep.to_edge()

    def test_export_unlift(self) -> None:
        """An AOT-captured, unlifted module with a buffer matches eager output."""

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x.cos() + self.buffer.sin()

        ep = exir.capture(
            Foo(),
            (torch.ones(6, 4),),
            exir.CaptureConfig(enable_aot=True, _unlift=True),
        )

        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_container_unlift(self) -> None:
        """Unlifted AOT capture with a single nested-tuple argument matches eager."""

        class FooContainerInputOutput(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x[0][0].cos() + x[0][1].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4), torch.ones(6, 4)),)
        ep = exir.capture(
            FooContainerInputOutput(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp)))

    def test_export_container_input_unlift(self) -> None:
        """Unlifted AOT capture with two tuple arguments matches eager output."""

        class FooContainerInputOutputV2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x, y):
                return x[0].cos() + y[0].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4),), (torch.ones(6, 4),))
        ep = exir.capture(
            FooContainerInputOutputV2(),
            inp,
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp)))

    def test_export_cond(self) -> None:
        """Unlifted AOT capture of ``cond`` whose true branch reads a submodule buffer."""

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.cos()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def forward(self, x):
                def true_fn(x):
                    return x.cos() + self.a().sum()

                def false_fn(x):
                    return x.sin()

                return cond(x.shape[0] > 4, true_fn, false_fn, [x])

        inp = torch.ones(6, 4)
        ep = exir.capture(
            Foo(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_cond_map(self) -> None:
        """Unlifted AOT capture of ``cond`` nested in ``map`` matches eager.

        The captured program is compared against a fresh eager module for both
        a ``True`` and a ``False`` tensor predicate.
        """

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.sum()

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def inner(self, x, pred):
                def true_fn(x):
                    return x + x + self.a()

                def false_fn(x):
                    return x * x - self.a()

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred) + self.a()

                return map(body, xs, pred)

        inp = torch.randn(3, 2, 1)
        ep = exir.capture(
            Module(),
            (torch.tensor(True), inp),
            CaptureConfig(enable_aot=True, _unlift=True),
        )

        inp_test = torch.randn(3, 2, 1)
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(True), inp_test),
                Module()(torch.tensor(True), inp_test),
            )
        )
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(False), inp_test),
                Module()(torch.tensor(False), inp_test),
            )
        )