# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa: F401
import functools
import inspect
import os
import random
import unittest
from typing import Callable, Dict, Optional, Tuple, Type
from unittest import skip, skipUnless

import executorch.exir as exir

import executorch.exir.control_flow as control_flow

# @manual=//executorch/extension/pytree:pybindings
import executorch.extension.pytree as pytree
import torch

from executorch.exir import (
    CaptureConfig,
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    memory,
)
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.emit import emit_program
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import (
    DebugPass,
    MemoryPlanningPass,
    to_scratch_op_pass,
    ToOutVarPass,
)
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.tensor import make_tensor_value, TensorSpec
from executorch.exir.tests.control_flow_models import (
    FTCondBasic,
    FTCondDynShape,
    FTMapBasic,
    FTMapDynShape,
)
from executorch.exir.tests.dynamic_shape_models import BatchNormModel

from executorch.exir.tests.transformer import Transformer
from functorch.experimental.control_flow import cond

# Records which pybindings flavor was importable: "lean" when
# portable_lib imports, "aten" when aten_lib imports. The asserts below
# require exactly one of the two to be available.
kernel_mode = None
try:
    from executorch.extension.pybindings.portable_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    kernel_mode = "lean"
except ImportError as e:
    print(e)

try:
    from executorch.extension.pybindings.aten_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    # Having both flavors importable at once is unexpected; fail loudly
    # (AssertionError is not swallowed by the ImportError handler).
    assert kernel_mode is None
    kernel_mode = "aten"
except ImportError as e:
    print(e)

assert kernel_mode is not None

is_aten_mode = kernel_mode == "aten"
is_lean_mode = kernel_mode == "lean"

from torch import nn
from torch.utils import _pytree as torch_pytree

from .exported_module import ExportedModule


# Set RUN_SKIPPED=1 in the environment to run tests guarded by skipUnless.
RUN_SKIPPED = int(os.environ.get("RUN_SKIPPED", "0"))


class ModuleBasic(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleOpsReturnMulti(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        # topk returns (values, indices); only the values are used.
        x, _ = torch.topk(a, 3)
        return x * 2 + b

    def get_random_inputs(self):
        return (torch.randn(10), torch.randn(3))


class ModuleAdd(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.add(x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))


class ModuleFloatAddWithAlpha(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2), random.random())


class ModuleIntAddWithAlpha(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (
            torch.randint(0, 10, (2, 2)),
            torch.randint(0, 10, (2, 2)),
            random.randint(0, 10),
        )


class ModuleContainers(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, d):
        a = d["a"]
        b = d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)


class ToyModelForMemPlanning(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        o = a
        for _ in range(3):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self):
        return (
            torch.randn(10),
            torch.randn(10),
        )


class MemPlanningWithScratchTensor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, a, b):
        o1 = self.linear1(a)
        o2 = self.linear2(b)
        return o1 + o2

    def get_random_inputs(self):
        return (
            torch.randn(10, 4),
            torch.randn(10, 4),
        )


class ModuleOpsReturnTensorList(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        split = torch.ops.aten.tensor_split.sections(x, 3)
        return split[0]

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleReturnInput(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        return (torch.randn(1),)


class ModuleIfElse(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        return (torch.randint(2, [1]) == 0, torch.randn(10))


class ModuleIfElseWithBoolInput(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        return (random.randint(0, 1) == 0, torch.randn(10))


class ModuleWhileIf(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleIfWhile(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            return accum, cnt

        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleContiguousTensor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)


class ModuleInputDynamicShape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        for _ in range(4):
            x = x + x
            x = x * x
        return x

    def get_upper_bound_inputs(self):
        return (torch.randn(10),)

    def get_random_inputs(self):
        # Input length varies between 1 and the upper bound (10) above.
        n = random.randint(1, 10)
        return (torch.randn(n),)


class ModuleIntermediateDynamicShape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * x

        # We should use x[torch.nonzero(x)] ideally, but index op is not supported
        # in the runtime so far.
        x = torch.nonzero(x)
        return x + x

    def get_random_inputs(self):
        return (torch.randint(0, 2, (10,), dtype=torch.float),)


def allclose(lhs, rhs, rtol=1e-5, atol=1e-8):
    r"""
    Unlike torch.allclose which only handles Tensor arguments, allclose handles
    list, tuple, dict and nesting of these as well.

    Tuples and lists are interchangeable as long as lengths match; dicts must
    have identical key sets.

    Raises:
        RuntimeError: if a leaf pair is not (Tensor, Tensor) or the container
            kinds do not match.
    """
    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
        return torch.allclose(lhs, rhs, rtol, atol)
    if isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)):
        return len(lhs) == len(rhs) and all(
            allclose(a, b, rtol, atol) for a, b in zip(lhs, rhs)
        )
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        if set(lhs.keys()) != set(rhs.keys()):
            return False
        return all(allclose(lhs[k], rhs[k], rtol, atol) for k in lhs)
    raise RuntimeError(
        f"Unexpected types: lhs type {type(lhs)}, rhs type {type(rhs)}"
    )


def validate_contiguous_tensors(program):
    """Assert that every tensor value in every execution plan is contiguous.

    Raises:
        AssertionError: naming the first non-contiguous tensor's sizes,
            strides, constant_buffer_idx and allocation_info.
    """

    def _is_contiguous_tensor(tensor: exir.schema.Tensor) -> bool:
        """
        Ensure the tensor is pytorch contiguous (torch.contiguous_format),
        i.e. its dim order is the identity permutation 0, 1, ..., rank - 1,
        since the runtime can not handle non-contiguous tensors so far.
        """
        sizes = tensor.sizes
        dim_order = tensor.dim_order
        assert len(sizes) == len(dim_order)
        return all(i == val for i, val in enumerate(dim_order))

    for execution_plan in program.execution_plan:
        for value in execution_plan.values:
            if isinstance(value.val, exir.schema.Tensor):
                assert _is_contiguous_tensor(
                    value.val
                ), f"Non-contiguous tensor found: size {value.val.sizes} stride {value.val.strides}. constant_buffer_idx {value.val.constant_buffer_idx}. allocation_info {value.val.allocation_info}."


class BoundMethod:
    """Callable that invokes ``callable`` with ``instance`` prepended to the args."""

    def __init__(self, instance, callable):
        self._instance = instance
        self._callable = callable

    def __call__(self, *args, **kwargs):
        # The instance is stored as ``_instance``; the previous lookup of
        # ``self.instance`` raised AttributeError on every call.
        return self._callable(self._instance, *args, **kwargs)


def maketest(
    module_cls: Type[nn.Module],
    niter: int = 10,
    run_executor: bool = True,
    do_tree_flatten: bool = False,
    run_graph_module: bool = True,
    atol: float = 1e-8,
    rtol: float = 1e-5,
    ignore_to_out_var_failure: bool = False,
    allow_non_contiguous_tensor: bool = False,
    method: str = "forward",
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
    capture_config: Optional[CaptureConfig] = None,
    verify_graph: Optional[Callable] = None,
) -> Callable[[unittest.TestCase], None]:
    r"""Returns a TestCase method to test the provided module class and method.

    Args:
        module_cls: The subclass of nn.Module to export.
        niter: The number of random input data sets to test with.
        run_executor: Whether to run the model on the executor. We may want to
            skip running a model through the executor since some kernels are
            not implemented.
        do_tree_flatten: Whether to flatten input and unflatten output.
        run_graph_module: Whether to run the traced and transformed GraphModule.
            One may want to skip this if some custom ops do not have
            implementation in torch.ops but is implemented in the executor.
        atol: Absolute tolerance used in allclose and torch.allclose
        rtol: Relative tolerance used in allclose and torch.allclose
        ignore_to_out_var_failure: Whether to ignore the failure when a
            functional op does not have an out variant.
        allow_non_contiguous_tensor: If false, will validate that the emitted
            program only contains contiguous tensors.
        method: The name of the module_cls method to trace.
        dynamic_memory_planning_mode: The dynamic memory planning mode to use.
        capture_config: Optional capture config forwarded unchanged to
            ExportedModule.export.
        verify_graph: Optional callback invoked as
            ``verify_graph(test_case, graph_module)`` on the exported program's
            graph module before any inputs are run.

    Returns:
        A TestCase method that tests the provided module class and method.
    """

    def wrapper(self: unittest.TestCase) -> None:
        """A TestCase method that traces/exports/tests an nn.Module and method."""
        module = ExportedModule.export(
            module_class=module_cls,
            # testend2end only supports modules with single methods defined
            methods=(method,),
            ignore_to_out_var_failure=ignore_to_out_var_failure,
            dynamic_memory_planning_mode=dynamic_memory_planning_mode,
            capture_config=capture_config,
        )
        if verify_graph:
            verify_graph(self, module.exported_program.graph_module)
        print(f"inputs for tracing: {module.trace_inputs}")

        # Draw all input sets up front so the graph module and the executor
        # are checked against the same data.
        inputs_list = [module.get_random_inputs() for _ in range(niter)]

        # compare the result between the eager module and graph module
        if run_graph_module:
            for inputs in inputs_list:
                with torch.no_grad():
                    # only one method is supported so just grab that single method
                    expected = getattr(module.eager_module, module.methods[0])(*inputs)
                with torch.no_grad():
                    result = module.exported_program.module()(*inputs)
                self.assertTrue(allclose(expected, result, rtol, atol))

        program = module.executorch_program.executorch_program
        pretty_print(program)
        print_program(program, show_meminfo=True, mark_dynamic_shape_tensor=True)
        print(f"mem buffer sizes: {program.execution_plan[0].non_const_buffer_sizes}")
        if not allow_non_contiguous_tensor:
            validate_contiguous_tensors(program)
        self.assertGreaterEqual(len(program.execution_plan[0].non_const_buffer_sizes), 2)
        # We should not enable the following assertion since for some models
        # that simply returning graph input, no mutable memory should be allocated
        # self.assertTrue(all(s > 0 for s in program.program.execution_plan[0].non_const_buffer_sizes[1:]))

        program.version = 0
        buff = module.executorch_program.buffer
        # Check that the magic version number is in the expected place, and
        # follows the expected pattern.
        self.assertRegex(buff[4:8].decode(errors="replace"), r"^ET[0-9][0-9]$")

        if run_executor:
            print("Running on the runtime")
            executorch_module = _load_for_executorch_from_buffer(buff)
            # compare the result between eager module and executor
            for idx, inputs in enumerate(inputs_list):
                with torch.no_grad():
                    expected = getattr(module.eager_module, method)(*inputs)

                if do_tree_flatten:
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_inputs, inputs_spec = pytree.tree_flatten(*inputs)
                    executorch_result = executorch_module.forward([*flatten_inputs])
                    # pyre-fixme[16]: Module `pytree` has no attribute `TreeSpec`.
                    executorch_result_unflatten = pytree.TreeSpec.from_str(
                        program.execution_plan[0].container_meta_type.encoded_out_str
                    ).tree_unflatten(executorch_result)
                    actual = executorch_result_unflatten
                else:
                    actual = executorch_module.forward(inputs)[0]
                is_close = allclose(expected, actual, rtol, atol)
                if not is_close:
                    print(f"Fail for {idx}th inputs: {inputs}")
                    print(f"expected result: {expected}")
                    print(f"actual result: {actual}")
                self.assertTrue(is_close)

    return wrapper


class E2ETest(unittest.TestCase):
    r"""
    When adding a new unittest, call maketest(ModuleName) if possible since
    maketest handles all the boilerplate part. Ideally, we only need define
    a new nn.Module and add one line to call maketest for new end2end test cases.
    """

    # don't run the model through the executor because aten::sin.out is not
    # defined in the executor currently.
    #
    # aten::max.default does not have an out variant. Thus we need set
    # ignore_to_out_var_failure to be True.
    def test_basic(self):
        maketest(ModuleBasic, run_executor=False, ignore_to_out_var_failure=True)(self)

    # Make sure we can handle ops that return multiple values. E.g. topk
    # At one time we can not properly setup TensorSpec for an Fx node
    # returning multiple tensors
    #
    # don't run the model through the executor because aten::topk.values is
    # not defined in the executor currently
    def test_ops_return_multi(self):
        maketest(ModuleOpsReturnMulti, run_executor=False)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_mem_planning_toy_model(self):
        maketest(
            ToyModelForMemPlanning,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO: add ops implementations and turn on 'run_executor'
    def test_mem_planning_scratch_tensor(self):
        maketest(
            MemPlanningWithScratchTensor,
            run_graph_module=False,
            run_executor=False,
            atol=1e-5,
        )(self)

    def test_executorch_forward(self):
        maketest(ModuleAdd)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_containers(self):
        maketest(
            ModuleContainers,
            do_tree_flatten=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # can not run the graph module since the out variance with tensor list out
    # argument returns None rather than tensor list.
    #
    # Can not run in the executor since kernel for tensor splitting is not implemented.
    def test_ops_return_tensorlist(self):
        maketest(ModuleOpsReturnTensorList, run_graph_module=False, run_executor=False)(
            self
        )

    # Failed to produce a graph during tracing w/ dynamo because there are no torch ops
    # test_return_input = maketest(ModuleReturnInput, do_tree_flatten=True)

    # can not run this on the executor because missing the following ops:
    #   aten::select_copy.int_out, aten::eq.Scalar_out
    # TODO(zhxchen17) re-enable these tests.
    # test_control_flow_cond = maketest(ControlFlowCond, run_executor=False)
    # fail to trace with functionalization enabled
    # test_ifelse = maketest(ModuleIfElse)

    # fail to trace with functionalization enabled
    # Fail with error: Missing out variants: {'aten::select', 'aten::_shape_as_tensor', 'aten::tensor_split'}
    # TODO(zhxchen17) re-enable these tests.
    # test_while_0 = maketest(
    #     ControlFlowWhile,
    #     ignore_to_out_var_failure=True,
    #     run_executor=False,
    # )

    # test_while = maketest(ModuleWhile)

    # test_while_if = maketest(ModuleWhileIf)
    # test_if_while = maketest(ModuleIfWhile)
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_contiguous_tensor(self):
        maketest(ModuleContiguousTensor, run_executor=False)(self)


class DynamicModelE2ETest(unittest.TestCase):
    """
    End2end tests for dynamic models. For dynamic models we mean models with
    control flow or dynamic shape.
    """

    @skip("Revisit when unbacked symint is ready")
    def test_intermediate_dynamic_shape(self):
        maketest(
            ModuleIntermediateDynamicShape,
            run_graph_module=False,
            allow_non_contiguous_tensor=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO(shunting): some non constant tensors for transformer are non-contiguous.
    # Ignore for now. Will debug more.
    # NOTE: can not run on runtime since missing these ops: P535190636
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_transformer_encode(self):
        maketest(
            Transformer,
            method="encode",
            allow_non_contiguous_tensor=True,
            run_executor=False,
        )(self)

    # basic test for functorch torch.ops.higher_order.cond
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_basic(self):
        maketest(
            FTCondBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_basic(self):
        maketest(
            FTMapBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_dynshape(self):
        maketest(
            FTCondDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_dynshape(self):
        maketest(
            FTMapDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_batch_norm(self):
        maketest(
            BatchNormModel,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
            verify_graph=BatchNormModel.verify_graph,
            # TODO: lean mode does not have native_batch_norm.out implemented
            # run this on aten mode.
            run_executor=is_aten_mode,
        )(self)