xref: /aosp_15_r20/external/executorch/runtime/test/test_runtime.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport tempfile
8*523fa7a6SAndroid Build Coastguard Workerimport unittest
9*523fa7a6SAndroid Build Coastguard Workerfrom pathlib import Path
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.pybindings.test.make_test import (
14*523fa7a6SAndroid Build Coastguard Worker    create_program,
15*523fa7a6SAndroid Build Coastguard Worker    ModuleAdd,
16*523fa7a6SAndroid Build Coastguard Worker    ModuleMulti,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.runtime import Runtime, Verification
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker
class RuntimeTest(unittest.TestCase):
    """Tests for the Python ``executorch.runtime`` bindings.

    Each test builds a program with ``create_program``, loads it through
    ``Runtime.get()``, and checks the loaded program or the runtime's
    operator registry.
    """

    def test_smoke(self):
        """Load ``ModuleAdd`` from a buffer and check ``forward`` adds inputs."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()
        # Demonstrate that get() returns a singleton. assertIs reports both
        # objects on failure, unlike assertTrue(a is b).
        runtime2 = Runtime.get()
        self.assertIs(runtime, runtime2)
        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

    def test_module_with_multiple_method_names(self):
        """Check both methods of ``ModuleMulti`` are listed and executable."""
        ep, inputs = create_program(ModuleMulti())
        runtime = Runtime.get()

        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        self.assertEqual(program.method_names, {"forward", "forward2"})

        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        # "forward2" is expected to return the same sum offset by one.
        method = program.load_method("forward2")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1] + 1))

    def test_print_operator_names(self):
        """Check the operator registry is non-empty and lists ``aten::add.out``."""
        # The result is unused here; the call is kept so the test performs the
        # same work as before.
        # NOTE(review): confirm whether this call is actually required.
        create_program(ModuleAdd())
        runtime = Runtime.get()

        operator_names = runtime.operator_registry.operator_names
        self.assertGreater(len(operator_names), 0)

        self.assertIn("aten::add.out", operator_names)

    def test_load_program_with_path(self):
        """Load one program via a ``str`` path, a ``Path``, and raw bytes."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()

        def check_add(program):
            # Shared assertion: "forward" must return the sum of both inputs.
            method = program.load_method("forward")
            outputs = method.execute(inputs)
            self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        with tempfile.NamedTemporaryFile() as f:
            f.write(ep.buffer)
            # Flush so the data is on disk before the file is reopened by name.
            f.flush()

            # Filename given as a str.
            program = runtime.load_program(f.name)
            check_add(program)

            # Filename given as a pathlib.Path.
            path = Path(f.name)
            program = runtime.load_program(path)
            check_add(program)

            # Raw bytes read back from the file. A separate name is used for
            # the reader so the temporary file handle ``f`` is not shadowed.
            with open(f.name, "rb") as reader:
                program = runtime.load_program(reader.read())
                check_add(program)
79