# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from pathlib import Path

import torch

from executorch.extension.pybindings.test.make_test import (
    create_program,
    ModuleAdd,
    ModuleMulti,
)
from executorch.runtime import Runtime, Verification


class RuntimeTest(unittest.TestCase):
    """Tests for the ``executorch.runtime`` Runtime / Program / Method API."""

    def test_smoke(self):
        """Load a program from an in-memory buffer and execute ``forward``."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()
        # Demonstrate that get() returns a singleton.
        runtime2 = Runtime.get()
        self.assertIs(runtime, runtime2)
        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

    def test_module_with_multiple_method_names(self):
        """Every method of a multi-method program is listed and executable."""
        ep, inputs = create_program(ModuleMulti())
        runtime = Runtime.get()

        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
        self.assertEqual(program.method_names, {"forward", "forward2"})

        method = program.load_method("forward")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        # forward2 is expected to compute the same sum plus one.
        method = program.load_method("forward2")
        outputs = method.execute(inputs)
        self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1] + 1))

    def test_print_operator_names(self):
        """The operator registry is non-empty and contains ``aten::add.out``."""
        # Despite the name, nothing is printed: the registry is inspected
        # directly from the runtime, so no program needs to be exported here.
        runtime = Runtime.get()

        operator_names = runtime.operator_registry.operator_names
        self.assertGreater(len(operator_names), 0)

        self.assertIn("aten::add.out", operator_names)

    def test_load_program_with_path(self):
        """``load_program`` accepts a str path, a ``pathlib.Path`` and raw bytes."""
        ep, inputs = create_program(ModuleAdd())
        runtime = Runtime.get()

        def check_add(program):
            # Run "forward" and verify it adds its two inputs.
            method = program.load_method("forward")
            outputs = method.execute(inputs)
            self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

        # Write the program into a temporary directory instead of holding a
        # NamedTemporaryFile open: a still-open NamedTemporaryFile cannot be
        # reopened by name on Windows, whereas a plain file in a temp dir can.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "program.pte"
            path.write_bytes(ep.buffer)
            # filename as str
            check_add(runtime.load_program(str(path)))
            # pathlib.Path
            check_add(runtime.load_program(path))
            # raw bytes read back from the file
            check_add(runtime.load_program(path.read_bytes()))