# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa: F401
import functools
import inspect
import os
import random
import unittest
from typing import Callable, Dict, Optional, Tuple, Type
from unittest import skip, skipUnless

import executorch.exir as exir

import executorch.exir.control_flow as control_flow

# @manual=//executorch/extension/pytree:pybindings
import executorch.extension.pytree as pytree
import torch

from executorch.exir import (
    CaptureConfig,
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    memory,
)
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.emit import emit_program
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import (
    DebugPass,
    MemoryPlanningPass,
    to_scratch_op_pass,
    ToOutVarPass,
)
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.tensor import make_tensor_value, TensorSpec
from executorch.exir.tests.control_flow_models import (
    FTCondBasic,
    FTCondDynShape,
    FTMapBasic,
    FTMapDynShape,
)
from executorch.exir.tests.dynamic_shape_models import BatchNormModel

from executorch.exir.tests.transformer import Transformer
from functorch.experimental.control_flow import cond

kernel_mode = None  # either aten mode or lean mode
try:
    from executorch.extension.pybindings.portable_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    kernel_mode = "lean"
except ImportError as e:
    print(e)
    pass

try:
    from executorch.extension.pybindings.aten_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    assert kernel_mode is None
    kernel_mode = "aten"
except ImportError as e:
    print(e)
    pass

assert kernel_mode is not None

is_aten_mode = kernel_mode == "aten"
is_lean_mode = kernel_mode == "lean"

from torch import nn
from torch.utils import _pytree as torch_pytree

from .exported_module import ExportedModule


RUN_SKIPPED = int(os.environ.get("RUN_SKIPPED", "0"))


class ModuleBasic(nn.Module):
    def __init__(self):
        super(ModuleBasic, self).__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleOpsReturnMulti(nn.Module):
    def __init__(self):
        super(ModuleOpsReturnMulti, self).__init__()

    def forward(self, a, b):
        x, y = torch.topk(a, 3)
        return x * 2 + b

    def get_random_inputs(self):
        return (torch.randn(10), torch.randn(3))


class ModuleAdd(nn.Module):
    def __init__(self):
        super(ModuleAdd, self).__init__()

    def forward(self, x, y):
        return torch.add(x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))


class ModuleFloatAddWithAlpha(nn.Module):
    def __init__(self):
        super(ModuleFloatAddWithAlpha, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2), random.random())


class ModuleIntAddWithAlpha(nn.Module):
    def __init__(self):
        super(ModuleIntAddWithAlpha, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        return (
            torch.randint(0, 10, (2, 2)),
            torch.randint(0, 10, (2, 2)),
            random.randint(0, 10),
        )


class ModuleContainers(nn.Module):
    def __init__(self):
        super(ModuleContainers, self).__init__()

    def forward(self, d):
        a = d["a"]
        b = d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)


class ToyModelForMemPlanning(nn.Module):
    def __init__(self):
        super(ToyModelForMemPlanning, self).__init__()

    def forward(self, a, b):
        o = a
        for i in range(3):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self):
        return (
            torch.randn(10),
            torch.randn(10),
        )


class MemPlanningWithScratchTensor(nn.Module):
    def __init__(self):
        super(MemPlanningWithScratchTensor, self).__init__()
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, a, b):
        o1 = self.linear1(a)
        o2 = self.linear2(b)
        return o1 + o2

    def get_random_inputs(self):
        return (
            torch.randn(10, 4),
            torch.randn(10, 4),
        )


class ModuleOpsReturnTensorList(nn.Module):
    def __init__(self):
        super(ModuleOpsReturnTensorList, self).__init__()

    def forward(self, x):
        split = torch.ops.aten.tensor_split.sections(x, 3)
        return split[0]

    def get_random_inputs(self):
        return (torch.randn(100),)


class ModuleReturnInput(nn.Module):
    def __init__(self):
        super(ModuleReturnInput, self).__init__()

    def forward(self, x):
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        return (torch.randn(1),)


class ModuleIfElse(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        return (torch.randint(2, [1]) == 0, torch.randn(10))


class ModuleIfElseWithBoolInput(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        return (random.randint(0, 1) == 0, torch.randn(10))


class ModuleWhileIf(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect a
            # tensor to be mixed with a scalar for the sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleIfWhile(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            return accum, cnt

        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))


class ModuleContiguousTensor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)


class ModuleInputDynamicShape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        for i in range(4):
            x = x + x
            x = x * x
        return x

    def get_upper_bound_inputs(self):
        return (torch.randn(10),)

    def get_random_inputs(self):
        n = random.randint(1, 10)
        return (torch.randn(n),)


class ModuleIntermediateDynamicShape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * x

        # Ideally we would use x[torch.nonzero(x)], but the index op is not
        # supported in the runtime so far.
        x = torch.nonzero(x)
        return x + x

    def get_random_inputs(self):
        return (torch.randint(0, 2, (10,), dtype=torch.float),)


def allclose(lhs, rhs, rtol=1e-5, atol=1e-8):
    r"""
    Unlike torch.allclose, which only handles Tensor arguments, allclose handles
    lists, tuples, dicts, and arbitrary nesting of these as well.
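
    Example (illustrative sketch):
        allclose({"a": torch.ones(2)}, {"a": torch.ones(2)}) recurses through
        the matching dict structure, compares the leaf tensors with
        torch.allclose, and returns True here.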
416    """
417    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
418        return torch.allclose(lhs, rhs, rtol, atol)
419    if isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)):
420        return len(lhs) == len(rhs) and all(
421            allclose(a, b, rtol, atol) for a, b in zip(lhs, rhs)
422        )
423    if isinstance(lhs, dict) and isinstance(rhs, dict):
424        lhs_keys = set(lhs.keys())
425        rhs_keys = set(rhs.keys())
426        if lhs_keys != rhs_keys:
427            return False
428        return all(allclose(lhs[k], rhs[k], rtol, atol) for k in lhs)
429    else:
430        raise RuntimeError(
431            f"Unexpected types: lhs type {type(lhs)}, rhs type {type(rhs)}"
432        )
433
434
def validate_contiguous_tensors(program):
    def _is_contiguous_tensor(tensor: exir.schema.Tensor):
        """
        Ensure the tensor is PyTorch contiguous (torch.memory_format=torch.contiguous_format),
        since the runtime cannot handle non-contiguous tensors so far.
        """
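        # Note (illustrative): for a contiguous tensor, dim_order is the
        # identity permutation (0, 1, ..., N-1). A channels_last 4-D tensor,
        # for example, would typically have dim_order (0, 2, 3, 1) and be
        # rejected by this check.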
        sizes = tensor.sizes
        dim_order = tensor.dim_order
        assert len(sizes) == len(dim_order)
        for i, val in enumerate(dim_order):
            if i != val:
                return False
        return True

    for execution_plan in program.execution_plan:
        for value in execution_plan.values:
            if isinstance(value.val, exir.schema.Tensor):
                assert _is_contiguous_tensor(
                    value.val
                ), f"Non-contiguous tensor found: size {value.val.sizes} stride {value.val.strides}. constant_buffer_idx {value.val.constant_buffer_idx}. allocation_info {value.val.allocation_info}."


class BoundMethod(object):
    def __init__(self, instance, callable):
        self._instance = instance
        self._callable = callable

    def __call__(self, *args, **kwargs):
        return self._callable(self._instance, *args, **kwargs)


def maketest(
    module_cls: Type[nn.Module],
    niter: int = 10,
    run_executor: bool = True,
    do_tree_flatten: bool = False,
    run_graph_module: bool = True,
    atol: float = 1e-8,
    rtol: float = 1e-5,
    ignore_to_out_var_failure: bool = False,
    allow_non_contiguous_tensor: bool = False,
    method: str = "forward",
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
    capture_config=None,
    verify_graph: Optional[Callable] = None,
) -> Callable[[unittest.TestCase], None]:
481    r"""Returns a TestCase method to test the provided module class and method.
482
483    Args:
484        module_cls: The subclass of nn.Module to export.
485        niter: The number of random input data sets to test with.
486        run_executor: Whether to run the model on the executor. We may want to
487            skip running a model thru executor since some kernels are not
488            implemented.
489        do_tree_flatten: Whether to flatten input and unflatten output.
490        run_graph_module: Whether to run the traced and transformed GraphModule.
491            One may want to skip this if some custom ops do not have
492            implementation in torch.ops but is implemented in the executor.
493        atol: Absolute tolerance used in allclose and torch.allclose
494        rtol: Relative tolerance used in allclose and torch.allclose
495        ignore_to_out_var_failure: Whether to ignore the failue when a
496            functional op does not have an out variant.
497        allow_non_contiguous_tensor: If false, will validate that the emitted
498            program only contains contiguous tensors.
499        method: The name of the module_cls method to trace.
500        dynamic_memory_planning_mode: The dynamic memory planning mode to use.
501
502    Returns:
503        A TestCase method that tests the provided module class and method.
504    """

    def wrapper(self: unittest.TestCase) -> None:
        """A TestCase method that traces/exports/tests an nn.Module and method."""
        module = ExportedModule.export(
            module_class=module_cls,
            # test_end2end only supports modules with a single method defined
            methods=(method,),
            ignore_to_out_var_failure=ignore_to_out_var_failure,
            dynamic_memory_planning_mode=dynamic_memory_planning_mode,
            capture_config=capture_config,
        )
        if verify_graph:
            verify_graph(self, module.exported_program.graph_module)
        print(f"inputs for tracing: {module.trace_inputs}")

        # compare the results between the eager module and the graph module
        inputs_list = [module.get_random_inputs() for _ in range(niter)]

        if run_graph_module:
            for inputs in inputs_list:
                with torch.no_grad():
                    # only one method is supported so just grab that single method
                    expected = getattr(module.eager_module, module.methods[0])(*inputs)
                with torch.no_grad():
                    result = module.exported_program.module()(*inputs)
                self.assertTrue(allclose(expected, result, rtol, atol))

        program = module.executorch_program.executorch_program
        pretty_print(program)
        print_program(program, show_meminfo=True, mark_dynamic_shape_tensor=True)
        print(f"mem buffer sizes: {program.execution_plan[0].non_const_buffer_sizes}")
        if not allow_non_contiguous_tensor:
            validate_contiguous_tensors(program)
        self.assertTrue(len(program.execution_plan[0].non_const_buffer_sizes) >= 2)
        # We should not enable the following assertion, since for some models
        # that simply return the graph input, no mutable memory should be allocated.
        # self.assertTrue(all(s > 0 for s in program.program.execution_plan[0].non_const_buffer_sizes[1:]))

        program.version = 0
        buff = module.executorch_program.buffer
        # Check that the magic version number is in the expected place, and
        # follows the expected pattern.
        self.assertRegex(buff[4:8].decode(errors="replace"), r"^ET[0-9][0-9]$")

        if run_executor:
            print("Running on the runtime")
            executorch_module = _load_for_executorch_from_buffer(buff)
            # compare the results between the eager module and the executor
            for idx, inputs in enumerate(inputs_list):
                with torch.no_grad():
                    expected = getattr(module.eager_module, method)(*inputs)

                if do_tree_flatten:
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_inputs, inputs_spec = pytree.tree_flatten(*inputs)
                    executorch_result = executorch_module.forward([*flatten_inputs])
                    # pyre-fixme[16]: Module `pytree` has no attribute `TreeSpec`.
                    executorch_result_unflatten = pytree.TreeSpec.from_str(
                        program.execution_plan[0].container_meta_type.encoded_out_str
                    ).tree_unflatten(executorch_result)
                    actual = executorch_result_unflatten
                else:
                    actual = executorch_module.forward(inputs)[0]
                is_close = allclose(expected, actual, rtol, atol)
                if not is_close:
                    print(f"Fail for {idx}th inputs: {inputs}")
                    print(f"expected result: {expected}")
                    print(f"actual result: {actual}")
                self.assertTrue(is_close)

    return wrapper

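# Usage sketch (illustrative, mirroring the patterns in this file): the callable
# returned by maketest can either be bound as a TestCase attribute, e.g.
#
#     class MyE2ETest(unittest.TestCase):  # hypothetical test class
#         test_add = maketest(ModuleAdd, niter=5)
#
# or invoked directly inside a test method, as the tests below do:
#
#     def test_executorch_forward(self):
#         maketest(ModuleAdd)(self)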

class E2ETest(unittest.TestCase):
579    r"""
580    When adding a new unittest, call maketest(ModuleName) if possible since
581    maketest handles all the boilterplate part. Ideally, we only need define
582    a new nn.Module and add one line to call maketest for new end2end test cases.
583    """
    # Don't run the model through the executor because aten::sin.out is not
    # defined in the executor currently.
    #
    # aten::max.default does not have an out variant, so we need to set
    # ignore_to_out_var_failure to True.
    def test_basic(self):
        maketest(ModuleBasic, run_executor=False, ignore_to_out_var_failure=True)(self)

    # Make sure we can handle ops that return multiple values, e.g. topk.
    # At one time we could not properly set up a TensorSpec for an FX node
    # returning multiple tensors.
    #
    # Don't run the model through the executor because aten::topk.values is not
    # defined in the executor currently.
    def test_ops_return_multi(self):
        maketest(ModuleOpsReturnMulti, run_executor=False)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_mem_planning_toy_model(self):
        maketest(
            ToyModelForMemPlanning,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO: add ops implementations and turn on 'run_executor'
    def test_mem_planning_scratch_tensor(self):
        maketest(
            MemPlanningWithScratchTensor,
            run_graph_module=False,
            run_executor=False,
            atol=1e-5,
        )(self)

    def test_executorch_forward(self):
        maketest(ModuleAdd)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_containers(self):
        maketest(
            ModuleContainers,
            do_tree_flatten=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # Cannot run the graph module since the out variant with a tensor list out
    # argument returns None rather than a tensor list.
    #
    # Cannot run in the executor since the kernel for tensor splitting is not implemented.
    def test_ops_return_tensorlist(self):
        maketest(ModuleOpsReturnTensorList, run_graph_module=False, run_executor=False)(
            self
        )

    # Failed to produce a graph during tracing w/ dynamo because there are no torch ops
    # test_return_input = maketest(ModuleReturnInput, do_tree_flatten=True)

    # Cannot run this on the executor because the following ops are missing:
    #   aten::select_copy.int_out, aten::eq.Scalar_out
    # TODO(zhxchen17) re-enable these tests.
    # test_control_flow_cond = maketest(ControlFlowCond, run_executor=False)
    # fail to trace with functionalization enabled
    # test_ifelse = maketest(ModuleIfElse)

    # fail to trace with functionalization enabled
    # Fail with error: Missing out variants: {'aten::select', 'aten::_shape_as_tensor', 'aten::tensor_split'}
    # TODO(zhxchen17) re-enable these tests.
    # test_while_0 = maketest(
    #     ControlFlowWhile,
    #     ignore_to_out_var_failure=True,
    #     run_executor=False,
    # )

    # test_while = maketest(ModuleWhile)

    # test_while_if = maketest(ModuleWhileIf)
    # test_if_while = maketest(ModuleIfWhile)
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_contiguous_tensor(self):
        maketest(ModuleContiguousTensor, run_executor=False)(self)


class DynamicModelE2ETest(unittest.TestCase):
671    """
672    End2end tests for dynamic models. For dynamic models we mean models with
673    control flow or dynamic shape.
674    """

    @skip("Revisit when unbacked symint is ready")
    def test_intermediate_dynamic_shape(self):
        maketest(
            ModuleIntermediateDynamicShape,
            run_graph_module=False,
            allow_non_contiguous_tensor=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO(shunting): some non-constant tensors for the transformer are non-contiguous.
    # Ignore for now. Will debug more.
    # NOTE: cannot run on the runtime since these ops are missing: P535190636
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_transformer_encode(self):
        maketest(
            Transformer,
            method="encode",
            allow_non_contiguous_tensor=True,
            run_executor=False,
        )(self)

    # basic test for functorch torch.ops.higher_order.cond
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_basic(self):
        maketest(
            FTCondBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_basic(self):
        maketest(
            FTMapBasic,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_dynshape(self):
        maketest(
            FTCondDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_dynshape(self):
        maketest(
            FTMapDynShape,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
                enable_functionalization=False,  # TODO enable functionalization
            ),
        )(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_batch_norm(self):
        maketest(
            BatchNormModel,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
            verify_graph=BatchNormModel.verify_graph,
            # TODO: lean mode does not have native_batch_norm.out implemented,
            # so only run the executor in aten mode.
            run_executor=is_aten_mode,
        )(self)