# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import typing
import unittest
from contextlib import contextmanager
from typing import List, Optional, Tuple

import executorch.exir as exir
import executorch.exir.schema as schema
import executorch.exir.tests.models as models
import pytest
import torch
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir._serialize._program import deserialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.emit import emit_program  # noqa
from executorch.exir.error import InternalError
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.print_program import pretty_print, print_program  # noqa
from executorch.exir.schema import (
    Bool,
    DelegateCall,
    Double,
    EValue,
    ExecutionPlan,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    KernelTypes,
    MoveCall,
    Null,
    OptionalTensorList,
    Program,
    String,
    Tensor,
)
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.models import Mul
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from functorch.experimental import control_flow
from torch import nn
from torch.export import Dim, export


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper method to make it easier to cleanly torch.export() a method on a
    module that is not `forward`.

    TODO(suo): upstream this to torch.export.wrapper.
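
    Usage sketch (hypothetical names; `MyModule` and its `forward_b` method are
    stand-ins for any module method other than `forward`):

        model = MyModule()
        with patch_forward(model, model.forward_b):
            ep = torch.export.export(model, example_inputs)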
""" # Save the original method original_method = obj.forward # Patch the method obj.forward = new_method.__get__(obj, obj.__class__) try: yield finally: # Restore the original method obj.forward = original_method class TestEmit(unittest.TestCase): @classmethod def setUpClass(cls) -> None: register_additional_test_aten_ops() def setUp(self) -> None: self.compile_config = EdgeCompileConfig(_check_ir_validity=False) def check_tensor_buffer_loc( self, value_index: int, values: List[EValue], exp_buffer_idx: int, exp_mem_id: Optional[int], exp_mem_offset: Optional[int], ) -> None: value = typing.cast(schema.Tensor, values[value_index].val) self.assertIsInstance(value, schema.Tensor) self.assertEqual(value.data_buffer_idx, exp_buffer_idx) if not value.allocation_info: self.assertIsNone(exp_mem_id) self.assertIsNone(exp_mem_offset) else: self.assertEqual(value.allocation_info.memory_id, exp_mem_id) assert value.allocation_info self.assertEqual(value.allocation_info.memory_offset, exp_mem_offset) def count_node(self, graph_module: torch.fx.GraphModule, opname: str) -> int: return [ node.target._overloadpacket._qualified_op_name for node in graph_module.graph.nodes if node.op == "call_function" ].count(opname) def run_dce(self, graph_module: torch.fx.GraphModule) -> None: for submodule in graph_module.modules(): self.assertIsInstance(submodule, torch.fx.GraphModule) typing.cast(torch.fx.GraphModule, submodule).graph.eliminate_dead_code() def check_value_types(self, values: List[EValue]) -> None: for value in values: self.assertTrue(type(value.val) in KernelTypes.__args__) def count_move_instructions(self, program: Program) -> int: instructions = program.execution_plan[0].chains[0].instructions assert instructions is not None res = 0 for instr in instructions: if isinstance(instr.instr_args, MoveCall): res += 1 return res def test_basic_api(self) -> None: class Foo(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x * y + x f = Foo() program = ( to_edge( export( f, (torch.ones(3, 2), torch.zeros(3, 2)), ) ) .to_executorch() .executorch_program ) exec_plan = program.execution_plan[0] ops = exec_plan.operators for op in ops: self.assertEqual(op.overload, "out") self.assertEqual(ops[0].name, "aten::mul") self.assertEqual(ops[1].name, "aten::add") self.assertEqual(len(exec_plan.inputs), 2) self.assertEqual(len(exec_plan.outputs), 1) self.assertEqual(exec_plan.inputs[0], 0) self.assertEqual(exec_plan.outputs[0], 3) def test_basic_end_to_end(self) -> None: f = models.BasicSinMax() program = ( to_edge(export(f, f.get_random_inputs())).to_executorch().executorch_program ) exec_plan = program.execution_plan[0] ops = exec_plan.operators for op in ops: self.assertIn(op.overload, {"out", "unary_out"}) self.assertEqual(ops[0].name, "aten::sin") self.assertEqual(len(exec_plan.inputs), 1) self.assertEqual(len(exec_plan.outputs), 1) self.assertEqual(exec_plan.inputs[0], 0) self.assertEqual(exec_plan.outputs[0], 1) @pytest.mark.skip(reason="Test not working on OSS") def test_nested_return(self) -> None: class Foo(torch.nn.Module): def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: return ( torch.tensor(1), torch.tensor(2), [torch.sin(x).max(), torch.cos(x).max()], ) f = Foo() x = (torch.randn(100),) program = to_edge(export(f, x)).to_executorch().executorch_program exec_plan = program.execution_plan[0] self.assertEqual(len(exec_plan.outputs), 4) self.assertEqual(len(exec_plan.inputs), 1) self.assertEqual( 
            program.execution_plan[0].container_meta_type.encoded_out_str,
            "T3#1#1#2($,$,L2#1#1($,$))",
        )

        self.assertEqual(
            program.execution_plan[0].container_meta_type.encoded_inp_str,
            "T2#1#0(T1#1($),D0())",
        )

    def test_constant_output(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return [((1, 3, 1.2), True, [x + x, x * x], None)]

        ep = torch.export.export(M(), (torch.ones(2, 3),))
        res = ep.module()(torch.ones(2, 3))
        self.assertEqual(res[0][0], (1, 3, 1.2))
        program = to_edge(ep).to_executorch().executorch_program
        outputs = program.execution_plan[0].outputs
        self.assertEqual(len(outputs), 7)
        self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
        self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
        self.assertEqual(
            program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
        )
        self.assertEqual(
            program.execution_plan[0].values[outputs[3]].val.bool_val, True
        )
        self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)

    def test_int_list_input(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y, x + x, x + y + z

        ep = torch.export.export(M(), (torch.ones(2, 3), 2, True))
        ep.module()(torch.ones(2, 3), 2, True)
        program = to_edge(ep).to_executorch().executorch_program
        inputs = program.execution_plan[0].inputs
        self.assertEqual(len(inputs), 3)
        self.assertEqual(program.execution_plan[0].values[inputs[1]].val.int_val, 2)
        self.assertEqual(program.execution_plan[0].values[inputs[2]].val.bool_val, True)

    def test_inplace_ops(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = torch.sin(x)
                z = y.view(100)
                torch.relu_(z)
                return z.max()

        f = Foo()

        inputs = (torch.ones((10, 10)),)
        edge = to_edge(export(f, inputs))

        removed_ops = ["aten::relu_", "aten::view"]
        expected_ops = [
            "aten::sin",
            "aten::relu",
            "aten::max",
            # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
            "executorch_prim::et_view",
        ]

        for opname in removed_ops:
            self.assertEqual(
                self.count_node(edge.exported_program().graph_module, opname), 0
            )
        for opname in expected_ops:
            # et_view appears as call_function with target = memory.view in graph
            if opname != "executorch_prim::et_view":
                self.assertTrue(
                    self.count_node(edge.exported_program().graph_module, opname) >= 1
                )

        program = edge.to_executorch().executorch_program
        for opname in removed_ops:
            self.assertTrue(
                all(op.name != opname for op in program.execution_plan[0].operators)
            )
        for opname in expected_ops:
            self.assertTrue(
                any(op.name == opname for op in program.execution_plan[0].operators)
            )

    def test_operators_unique(self) -> None:
        class OpRepeatedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x
                    y = z + self.b
                return y

        model = OpRepeatedModule()
        inputs = (torch.ones(2, 2),)
        program = to_edge(export(model, inputs)).to_executorch().executorch_program

        self.assertEqual(len(program.execution_plan[0].operators), 2)

    def test_list_type(self) -> None:
        """Tests that the types of lists are correctly found"""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.permute(x, (2, 0, 1))

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(2, 3, 5),)))
            .to_executorch()
            .executorch_program
        )
        exir.print_program.pretty_print(program)

        deboxed_int_list = []
        for item in program.execution_plan[0].values[5].val.items:  # pyre-ignore[16]
            deboxed_int_list.append(
                program.execution_plan[0].values[item].val.int_val  # pyre-ignore[16]
            )

        self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))

    def test_kwargs1(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                batch1 = torch.randn(10, 3, 4)
                batch2 = torch.randn(10, 4, 5)
                return torch.addbmm(x, batch1, batch2, alpha=2, beta=3)

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(3, 5),))).to_executorch().executorch_program
        )
        # The value for beta should appear before alpha
        self.assertEqual(program.execution_plan[0].values[12].val, Int(3))
        self.assertEqual(program.execution_plan[0].values[13].val, Int(2))

    def test_kwargs2(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                values = torch.randn(3, 2)
                return torch.searchsorted(x, values, side="right", right=True)

        f = Foo()

        x, _ = torch.sort(torch.randn(3, 4))
        program = to_edge(export(f, (x,))).to_executorch().executorch_program
        # The value for right should appear before side
        self.assertEqual(program.execution_plan[0].values[6].val, Bool(False))
        self.assertEqual(program.execution_plan[0].values[7].val, Bool(True))
        self.assertEqual(program.execution_plan[0].values[8].val, String("right"))
        self.assertEqual(program.execution_plan[0].values[9].val, Null())

    def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None:
        instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args

        if isinstance(instr_args, KernelCall) or isinstance(instr_args, DelegateCall):
            self.assertEqual(len(instr_args.args), expected_len)
        else:
            self.assertTrue(False)

    def test_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = y.clone()
                return torch.mul(x, y, out=z)

        f = Foo()

        program = (
            to_edge(export(f, (torch.ones(3), torch.ones(3))))
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    def test_model_out(self) -> None:
        class Module_out(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 * torch.ones(2, 2, dtype=torch.int32)
                self.b = 2 * torch.ones(2, 2, dtype=torch.int32)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = x.clone()
                torch.mul(self.a, x, out=z)
                y = x.clone()
                torch.add(z, self.b, alpha=2, out=y)
                return y

        model_out = Module_out()

        inputs = (torch.ones(2, 2, dtype=torch.int32),)

        # Trace to FX Graph.
        program = to_edge(export(model_out, inputs)).to_executorch().executorch_program

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 2)
        self._assertCallLength(program, 0, 4)
        self._assertCallLength(program, 1, 5)

    def test_stacktrace(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.mul(x, torch.randn(3, 2))

        def g(x: torch.Tensor) -> torch.Tensor:
            return torch.sin(f(x))

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Foo()

        x = (torch.randn(3, 2),)
        exec_prog = to_edge(export(h, x)).to_executorch(
            exir.ExecutorchBackendConfig(emit_stacktrace=True)
        )
        program = exec_prog.executorch_program

        # Check the mul operator's stack trace contains f -> g -> h
        self.assertTrue(
            "return torch.mul(x, torch.randn(3, 2))"
            in program.execution_plan[0]  # pyre-ignore[16]
            .chains[0]
            .stacktrace[1]
            .items[-1]
            .context
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward"
        )

        # Check the sin operator's stack trace contains g -> h
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward"
        )

    def test_stacktrace_off(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.mul(x, torch.randn(3, 2))

        f = Foo()

        class Goo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(f(x))

        g = Goo()

        class Hoo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Hoo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(h, x)).to_executorch().executorch_program

        # Check the stacktrace is None since we did not specify to get the stacktrace
        self.assertTrue(program.execution_plan[0].chains[0].stacktrace is None)

    def test_positional_argument_default_value(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
                z = torch.ones(6, 2)
                return torch.ops.aten.cat.out((x, n), out=z)

        f = Foo()

        x = torch.randn(3, 2)
        program = (
            to_edge(export(f, (x, x)))
            # .to_edge(self.compile_config)  # TODO(larryliu): fix cat
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    @pytest.mark.skip(reason="Test not working on OSS")
    def test_emit_multiple_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.topk(x, 5)

        f = Foo()

        x = (torch.randn(10),)
        program = to_edge(export(f, x)).to_executorch().executorch_program
        self._assertCallLength(program, 0, 8)

    def test_emit_layout(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.ones_like(x)

        f = Foo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(f, x)).to_executorch().executorch_program

        vals = program.execution_plan[0].values
        for val in vals:
            v = val.val
            if isinstance(v, Tensor):
                self.assertEqual(v.layout, 0)

    def test_optional_tensor_list(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = torch.nonzero(x)
                torch._constrain_as_size(a.shape[0], min=1)
                b = torch.ops.aten.index.Tensor(x, [a])
                return b

        f = Foo()

        x = (torch.triu(torch.ones(2, 2)),)
        program = (
            to_edge(
                export(f, x),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )

        self.assertTrue(
            isinstance(program.execution_plan[0].values[3].val, OptionalTensorList)
        )
        self._assertCallLength(program, 0, 3)
        self._assertCallLength(program, 1, 4)

    def test_optional_float_list(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, scale_factor=2)

        x = (torch.randn(1, 1, 2, 2),)
        program = to_edge(export(M(), x)).to_executorch().executorch_program
        self.assertIsInstance(
            program.execution_plan[0].values[-1].val, schema.OptionalTensorList
        )

    def test_emit_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, pred, x):
                def true_fn(y: torch.Tensor) -> torch.Tensor:
                    y = y + y
                    y = torch.mm(y, y)
                    return y

                def false_fn(y: torch.Tensor) -> torch.Tensor:
                    return torch.mm(y, y)

                ret = control_flow.cond(pred, true_fn, false_fn, [x])
                return ret

        module = to_edge(export(M(), (torch.tensor(True), torch.ones(2, 2))))
        program = module.to_executorch().executorch_program

        num_mm = 0
        num_add = 0
        num_other = 0
        for inst in program.execution_plan[0].chains[0].instructions:
            if not isinstance(inst.instr_args, KernelCall):
                continue

            op = (
                program.execution_plan[0]
                .operators[inst.instr_args.op_index]  # pyre-ignore[16]
                .name
            )

            if "mm" in op:
                num_mm += 1
            elif "add" in op:
                num_add += 1
            else:
                num_other += 1

        self.assertEqual(num_mm, 2)
        self.assertEqual(num_add, 1)
        self.assertEqual(num_other, 0)

    def test_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        program = module.to_executorch().executorch_program

        op_table = program.execution_plan[0].operators
        # The first two operators at the beginning of a map program should be sym_size
        # and select_copy, which is what we verify here. The first operator generates
        # the number of iterations, and the second slices the input tensor to produce
        # the tensor on which this iteration operates.
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[0]
                .instr_args.op_index
            ].name,
            "aten::sym_size",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[1]
                .instr_args.op_index
            ].name,
            "aten::select_copy",
        )
        # The last three instructions in the map sub-program are:
        # - a call to the custom op that appends the output of this iteration to the
        #   accumulator tensor,
        # - an increment of the iteration count, and
        # - a check of whether we have completed all the iterations.
        # We check here that all three of these have been generated.
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-5]
                .instr_args.op_index
            ].name,
            "executorch_prim::et_copy_index",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-4]
                .instr_args.op_index
            ].name,
            "executorch_prim::add",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-3]
                .instr_args.op_index
            ].name,
            "executorch_prim::eq",
        )
        # The last two instructions in the overall program check whether we should jump
        # back to the beginning of the loop, and reset the iteration counter if we
        # fall through.
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[-2].instr_args,
                JumpFalseCall,
            )
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-1]
                .instr_args.op_index
            ].name,
            "executorch_prim::sub",
        )

    def test_load_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        _load_for_executorch_from_buffer(module.to_executorch().buffer)

    def test_run_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        buffer = module.to_executorch().buffer
        loaded_model = _load_for_executorch_from_buffer(buffer)
        outputs = loaded_model(inputs)[0]
        self.assertTrue(torch.allclose(outputs, f(*inputs)))

    def test_dim_order(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        program = (
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )

        addmm_found = False
        for inst in program.execution_plan[0].chains[0].instructions:
            kernel = inst.instr_args
            if isinstance(kernel, KernelCall):
                op_id = kernel.op_index
                op = program.execution_plan[0].operators[op_id]
                if op.name == "aten::addmm":
                    addmm_found = True
                    args = kernel.args
                    bias_id = args[0]
                    act_id = args[1]
                    weight_id = args[2]
                    bias_dim_order = [0]
                    act_dim_order = [0, 1]
                    weight_dim_order = [0, 1]
                    bias_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[bias_id].val
                    )
                    act_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[act_id].val
                    )
                    weight_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[weight_id].val
                    )
                    self.assertTrue(bias_tensor.dim_order == bias_dim_order)
                    self.assertTrue(act_tensor.dim_order == act_dim_order)
                    self.assertTrue(weight_tensor.dim_order == weight_dim_order)
        self.assertTrue(addmm_found)

    def test_non_const_buffer_sizes(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                b = 3 + 1
                return x + b

        f = Add()

        edge_program_manager = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            )
        )
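        # Constant-fold the graph before memory planning. With `3 + 1` folded into a
        # single constant, the planner should need less non-constant (activation)
        # buffer space than in the unfolded program built below; the assertion at the
        # end of this test checks exactly that.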
        edge_program_manager._edge_programs["forward"] = constant_prop_pass(
            edge_program_manager.exported_program()
        )
        non_const_buffer_size_with_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )

        edge_program_manager = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            )
        )
        non_const_buffer_size_without_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )
        self.assertTrue(
            non_const_buffer_size_with_const_prop_pass[1]
            < non_const_buffer_size_without_const_prop_pass[1]
        )

    # Plans can't be compared directly with __eq__ because of the plan names, and the
    # data_buffer_idx in tensor values.
    def _compare_execution_plans(
        self, plan_single: ExecutionPlan, plan_merged: ExecutionPlan
    ) -> None:
        self.assertEqual(
            plan_single.container_meta_type,
            plan_merged.container_meta_type,
        )
        self.assertEqual(
            plan_single.inputs,
            plan_merged.inputs,
        )
        self.assertEqual(
            plan_single.outputs,
            plan_merged.outputs,
        )
        self.assertEqual(
            plan_single.chains,
            plan_merged.chains,
        )
        self.assertEqual(
            plan_single.operators,
            plan_merged.operators,
        )
        self.assertEqual(
            plan_single.non_const_buffer_sizes,
            plan_merged.non_const_buffer_sizes,
        )
        self.assertEqual(
            len(plan_single.values),
            len(plan_merged.values),
        )
        for i in range(0, len(plan_single.values)):
            single_val = plan_single.values[i].val
            merged_val = plan_merged.values[i].val
            if isinstance(single_val, Tensor):
                # The constant buffer index might differ, as the constant buffer is
                # shared between plans.
                self.assertTrue(isinstance(merged_val, Tensor))
                self.assertEqual(single_val.storage_offset, merged_val.storage_offset)
                self.assertEqual(single_val.scalar_type, merged_val.scalar_type)
                self.assertEqual(single_val.sizes, merged_val.sizes)
                self.assertEqual(single_val.dim_order, merged_val.dim_order)
                self.assertEqual(single_val.requires_grad, merged_val.requires_grad)
                self.assertEqual(single_val.layout, merged_val.layout)
                self.assertEqual(single_val.allocation_info, merged_val.allocation_info)
                self.assertEqual(single_val.shape_dynamism, merged_val.shape_dynamism)
            else:
                self.assertEqual(single_val, merged_val)

    def test_emit_memory_format_valid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                contiguous = x.to(
                    dtype=torch.float32, memory_format=torch.contiguous_format
                )
                preserve = x.to(
                    dtype=torch.float32, memory_format=torch.preserve_format
                )
                return contiguous + preserve

        # Should succeed at exporting a model with a legal memory format
        # (contiguous, preserve).
        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        try:
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        except Exception:
            self.fail("Failed to export model with legal memory format")

    def test_emit_memory_format_invalid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32, memory_format=torch.channels_last)

        # Failure is expected when exporting a model with an illegal memory format
        # (channels_last) while not using dim_order.
        model = SimpleLinear()
        inputs = (torch.ones(10, 5, 2, 1),)
        with self.assertRaises(InternalError):
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(
                    _check_ir_validity=False, _skip_dim_order=True
                ),
            ).to_executorch()

        # Success if we use dim_order.
        to_edge(
            export(model, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _skip_dim_order=False
            ),
        ).to_executorch()

    def test_emit_multiple_entry_points(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear2(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward_relu",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "forward_sigmoid",
        )
        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # Sum of the entry points minus 1, because the merged program still has only
        # one reserved spot.
        self.assertEqual(
            len(merged_program.constant_buffer),
            len(program_sigmoid._emitter_output.program.constant_buffer)
            + len(program_relu._emitter_output.program.constant_buffer)
            - 1,
        )

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )

    def test_emit_weight_deduplication(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(export(model, inputs)).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(export(model, inputs)).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # Weights are shared between the entry points, so the merged program should
        # deduplicate everything.
        self.assertEqual(len(merged_program.constant_buffer), 3)

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )
    def test_emit_execution_plans_sorted(self) -> None:
        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def a(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def b(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def c(self, x: torch.Tensor) -> torch.Tensor:
                return x

        model = Simple()
        inputs = (torch.ones(10, 5),)

        def make_program(
            fn,
            inputs,
        ) -> "ExecutorchProgramManager":
            return to_edge(
                export(
                    WrapperModule(fn),
                    inputs,
                )
            ).to_executorch()

        program_a = make_program(model.a, inputs)
        program_b = make_program(model.b, inputs)
        program_c = make_program(model.c, inputs)

        exir_input = {
            "b": program_b.exported_program(),
            "c": program_c.exported_program(),
            "a": program_a.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 3)

        self.assertEqual(merged_program.execution_plan[0].name, "a")
        self.assertEqual(merged_program.execution_plan[1].name, "b")
        self.assertEqual(merged_program.execution_plan[2].name, "c")

        # Create a second program equivalent to the first, but with the inputs in a
        # different order. Python dicts are insertion ordered, so this exercises the
        # sorting of execution plans by name.
        exir_input2 = {
            "a": program_b.exported_program(),
            "b": program_c.exported_program(),
            "c": program_a.exported_program(),
        }
        merged_program2 = emit_program(exir_input2, False).program
        self.assertEqual(
            merged_program2.execution_plan[0], merged_program.execution_plan[0]
        )
        self.assertEqual(
            merged_program2.execution_plan[1], merged_program.execution_plan[1]
        )
        self.assertEqual(
            merged_program2.execution_plan[2], merged_program.execution_plan[2]
        )

    def test_upper_bound_memory_planning_respect_input_constraints(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, k: torch.Tensor) -> torch.Tensor:
                k = torch.cat((k, torch.ones(1, 4)))
                return k

        func = Foo()

        k = torch.rand(2, 4)
        dim0_k = Dim("dim0_k", max=3)
        dynamic_shapes = {"k": {0: dim0_k}}
        captured = export(
            func,
            (k,),
            dynamic_shapes=dynamic_shapes,
        )
        edge = to_edge(captured)

        config = exir.ExecutorchBackendConfig(
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            memory_planning_pass=MemoryPlanningPass(
                # allow_lifetime_and_storage_overlap: bool = False,
                alloc_graph_input=True,
                alloc_graph_output=False,
            ),
        )
        exe_prog = edge.to_executorch(config)
        program = exe_prog._emitter_output.program
        exir.print_program.pretty_print(
            exe_prog._emitter_output.program.execution_plan
        )

        execution_plan = program.execution_plan[0]
        self.check_tensor_buffer_loc(0, execution_plan.values, 0, 1, 0)
        self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)

    def test_emit_prims(self) -> None:
        tensor_output = torch.rand(1, 4)
        tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)]

        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.x: int = 3
                self.y = 2

            def get_ints(self) -> Tuple[int, int]:
                return (self.x, self.y)

            def get_str(self) -> str:
                return "foo"

            def get_tensor(self) -> torch.Tensor:
                return tensor_output

            def get_tensor_list(self) -> List[torch.Tensor]:
                return tensor_list_output

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = Simple()
        inputs = (torch.ones(10, 5),)
        program = to_edge(export(model, inputs)).to_executorch()
        exir_input = {
            "forward": program.exported_program(),
        }
        getters = {}
        getters["get_ints"] = model.get_ints()
        getters["get_str"] = model.get_str()
        getters["get_tensor"] = model.get_tensor()
        getters["get_tensor_list"] = model.get_tensor_list()

        merged_program = emit_program(exir_input, False, getters).program
        self.assertEqual(len(merged_program.execution_plan), 5)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "get_ints",
        )
        self.assertEqual(
            merged_program.execution_plan[2].name,
            "get_str",
        )
        self.assertEqual(
            merged_program.execution_plan[3].name,
            "get_tensor",
        )
        self.assertEqual(
            merged_program.execution_plan[4].name,
            "get_tensor_list",
        )

        # No instructions in a getter.
        self.assertEqual(
            len(merged_program.execution_plan[1].chains[0].instructions),
            0,
        )
        # 2 outputs for the flattened tuple.
        self.assertEqual(
            len(merged_program.execution_plan[1].outputs),
            2,
        )
        # The outputs are 0 and 1 in the values table.
        self.assertEqual(
            merged_program.execution_plan[1].outputs,
            [0, 1],
        )
        # Value 0 is 3.
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[0].val.int_val,
            3,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[1].val.int_val,
            2,
        )
        self.assertEqual(
            len(merged_program.execution_plan[2].outputs),
            1,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[2].values[0].val.string_val,
            "foo",
        )
        self.assertEqual(len(merged_program.execution_plan[3].outputs), 1)
        self.assertEqual(len(merged_program.execution_plan[4].outputs), 2)

        merged_program = to_edge(
            export(model, inputs), constant_methods=getters
        ).to_executorch()
        executorch_module = _load_for_executorch_from_buffer(merged_program.buffer)
        self.assertTrue(
            torch.allclose(
                executorch_module.run_method("get_tensor", [])[0], tensor_output
            )
        )
        model_output = executorch_module.run_method("get_tensor_list", [])
        for i in range(len(tensor_list_output)):
            self.assertTrue(torch.allclose(model_output[i], tensor_list_output[i]))

    def test_emit_debug_handle_map(self) -> None:
        mul_model = Mul()
        program_mul = to_edge(
            export(
                mul_model,
                mul_model.get_random_inputs(),
            )
        ).to_executorch()
        # This triggers the actual emission of the graph.
        program_mul._emitter_output.program
        self.assertIsNotNone(program_mul.debug_handle_map)

    def test_final_graph_module_update_debug_handle(self) -> None:
        class SimpleAddMul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = x + 1
                return a * 2

        mul_model = SimpleAddMul()
        program_mul = to_edge(
            export(
                mul_model,
                (torch.ones(2, 2),),
            )
        ).to_executorch()
        # This triggers the actual emission of the graph.
        program = program_mul._emitter_output.program

        node = None
        # Find the multiplication node in the graph that was emitted.
        for node in program_mul.exported_program().graph.nodes:
            if node.target == torch.ops.aten.mul.out:
                break
        self.assertIsNotNone(node)

        idx = 0
        # Find the multiplication instruction in the program that was emitted.
        for idx in range(len(program.execution_plan[0].chains[0].instructions)):
            instruction = program.execution_plan[0].chains[0].instructions[idx]
            op_index = instruction.instr_args.op_index  # pyre-ignore[16]
            if "mul" in program.execution_plan[0].operators[op_index].name:
                break

        # The instruction id of the multiplication instruction and the debug handle of
        # the multiplication node in the graph module (which was updated in the emitter
        # to be the same as the instruction id) must be the same.
        self.assertEqual(
            idx,
            node.meta.get("debug_handle"),
        )

    def test_delegate_with_input_list(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=None,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, x):
                return torch.cat(x)

        inputs = ([torch.ones(2, 2), torch.ones(2, 2)],)
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs),
        ).to_executorch()
        # Accessing the buffer triggers serialization of the program.
        exec_prog.buffer

    def test_delegate_with_input_tuple(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=None,
                )

        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):  # input is a tuple of (a, x, b)
                y = torch.mm(input[0], input[1])
                z = torch.add(y, input[2])
                return z

        model_inputs = ((torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)),)
        model = AddMulModule()
        edgeir_m = to_edge(export(model, model_inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, model_inputs),
        ).to_executorch()
        # Accessing the buffer triggers serialization of the program.
        exec_prog.buffer

    def test_delegate_mapping(self) -> None:
        debug_handle_map = {1: [1, 2]}

        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=debug_handle_map,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, x, y):
                return torch.add(x, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x, y):
                return self.lowered_module(x, y)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs),
        ).to_executorch()
        # Reading the program triggers the call to emit_program underneath, which
        # needs to have happened for this test to succeed.
        exec_prog._emitter_output.program
        self.assertIsNotNone(exec_prog.delegate_map)
        self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
        self.assertIsNotNone(
            exec_prog.delegate_map.get("forward").get(0)  # pyre-ignore[16]
        )
        self.assertEqual(
            exec_prog.delegate_map.get("forward").get(0).get("name"),
            "BackendWithCompilerExample",
        )
        self.assertTrue(
            len(exec_prog.delegate_map.get("forward").get(0).get("delegate_map")) != 0
        )

    def test_emit_weight_view(self) -> None:
        class ModWithWeightViews(nn.Module):
            def __init__(self):
                super(ModWithWeightViews, self).__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.W1 = self.W[:1]
                self.W2 = self.W[1:]

            def forward(self, x):
                return self.W1 + self.W2 + x

        model = ModWithWeightViews()
        # each weight is a view of the same storage
        self.assertEqual(model.W1.nbytes, 4)
        self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
        self.assertEqual(model.W2.nbytes, 4)
        self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
        program = to_edge(
            export(
                model,
                (torch.ones(1),),
            )
        ).to_executorch()

        program = program._emitter_output.program
        # each emitted weight is not a view
        self.assertEqual(len(program.constant_buffer[1].storage), 4)
        self.assertEqual(len(program.constant_buffer[2].storage), 4)

    def test_non_persistent_buffer(self) -> None:
        class NonPersistentBuffer(nn.Module):
            def __init__(self):
                super(NonPersistentBuffer, self).__init__()
                self.register_buffer("buf", torch.tensor([1]), persistent=False)

            def forward(self, x):
                return x + self.buf

        model = NonPersistentBuffer()
        program = to_edge(
            export(
                model,
                (torch.ones(1),),
            )
        ).to_executorch()
        program = program._emitter_output.program
        # confirm that the buffer was emitted
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 8)

    def test_emit_lifted_tensor_constant(self) -> None:
        class LiftedConstants(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
                return x

        model = LiftedConstants()

        program = to_edge(
            export(
                model,
                (torch.ones(3, 2),),
            )
        ).to_executorch()

        program = program._emitter_output.program
        exec_plan = program.execution_plan[0]
        # There should only be 1 input to this model.
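        # The lifted tensor constant should be emitted into the constant buffer
        # (6 floats = 24 bytes) rather than becoming an extra runtime input.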
        self.assertEqual(len(exec_plan.inputs), 1)
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 24)

    def test_mutable_buffers(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        model = model.to_executorch()
        model.dump_executorch_program(True)
        self.assertTrue(
            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
            .values[0]
            .val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_mutable_buffers_without_memplanned_inputs(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        model = model.to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        model.dump_executorch_program(True)
        self.assertTrue(
            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
            .values[0]
            .val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_infinity_in_model(self) -> None:
        class InfinityMaskModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)

            def forward(self, x):
                masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
                return masked_weights

        model = to_edge(
            export(
                InfinityMaskModel(),
                (torch.randn(2, 2),),
            )
        )

        # Confirm that we can serialize the model with infinity in it.
        model = model.to_executorch()

        # Assert that the infinity is stored as a string "-inf".
        values = model.executorch_program.execution_plan[0].values
        self.assertEqual(values[5].val, Double(double_val=float("-inf")))

        # Confirm that we can also deserialize the model with infinity in it.
        pte_data = deserialize_pte_binary(model.buffer)
        self.assertEqual(
            pte_data.execution_plan, model.executorch_program.execution_plan
        )

    def test_mutate_input_tensor(self) -> None:
        class MutateInputTensorModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x.add_(1)

        model = to_edge(
            export(MutateInputTensorModule(), (torch.zeros(1),))
        ).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
            )
        )
        executorch_model = _load_for_executorch_from_buffer(model.buffer)
        input = torch.zeros(1)
        executorch_model(input)
        self.assertEqual(input, torch.ones(1))