xref: /aosp_15_r20/external/executorch/build/packaging/smoke_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker#
5*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Worker"""
9*523fa7a6SAndroid Build Coastguard WorkerThis script is run by CI after building the executorch wheel. Before running
10*523fa7a6SAndroid Build Coastguard Workerthis, the job will install the matching torch package as well as the newly-built
11*523fa7a6SAndroid Build Coastguard Workerexecutorch package and its dependencies.
12*523fa7a6SAndroid Build Coastguard Worker"""
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker# Import this first. If it can't find the torch.so libraries, the dynamic load
15*523fa7a6SAndroid Build Coastguard Worker# will fail and the process will exit.
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.pybindings import portable_lib  # usort: skip
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker# Import custom ops. This requires portable_lib to be loaded first.
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.llm.custom_ops import (  # noqa: F401, F403
20*523fa7a6SAndroid Build Coastguard Worker    sdpa_with_kv_cache,
21*523fa7a6SAndroid Build Coastguard Worker)  # usort: skip
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Worker# Import quantized ops. This requires portable_lib to be loaded first.
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.kernels import quantized  # usort: skip # noqa: F401, F403
25*523fa7a6SAndroid Build Coastguard Worker
# Import this after importing the ExecuTorch pybindings. If the pybindings
# link against a different torch.so than the one this import loads, there will
# be a set of symbol conflicts; the process will either exit now, or there will
# be issues later in the smoke test.
30*523fa7a6SAndroid Build Coastguard Workerimport torch  # usort: skip
31*523fa7a6SAndroid Build Coastguard Worker
32*523fa7a6SAndroid Build Coastguard Worker# Import everything else later to help isolate the critical imports above.
33*523fa7a6SAndroid Build Coastguard Workerimport os
34*523fa7a6SAndroid Build Coastguard Workerimport tempfile
35*523fa7a6SAndroid Build Coastguard Workerfrom typing import Tuple
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import to_edge
38*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
39*523fa7a6SAndroid Build Coastguard Worker
40*523fa7a6SAndroid Build Coastguard Worker
class LinearModel(torch.nn.Module):
    """A minimal model: one fully connected layer mapping 4 features to 2.

    Intended to be called with a single tensor of shape [4].
    """

    def __init__(self):
        super().__init__()
        # Single affine layer: 4 input features -> 2 output features.
        self.linear = torch.nn.Linear(in_features=4, out_features=2)

    def forward(self, x: torch.Tensor):
        """Apply the linear layer to ``x`` (expected shape [4])."""
        y = self.linear(x)
        return y
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker
def linear_model_inputs() -> Tuple[torch.Tensor]:
    """Build example inputs for ``LinearModel``.

    Returns:
        A one-tuple holding a ones-filled tensor of shape [4], matching the
        model's single ``forward`` argument.
    """
    sample = torch.ones(4)
    return (sample,)
57*523fa7a6SAndroid Build Coastguard Worker
58*523fa7a6SAndroid Build Coastguard Worker
def export_linear_model() -> bytes:
    """Export ``LinearModel`` through the ExecuTorch pipeline.

    The model is captured with ``torch.export.export``, lowered with
    ``to_edge``, and then converted to an ExecuTorch program.

    Returns:
        The .pte program data (the ExecuTorch program's ``buffer``).
    """
    # The exporter uses these sample inputs to learn the tensor shapes; the
    # model has a single argument, so this is a one-tuple.
    sample_inputs = linear_model_inputs()

    print("Exporting program...")
    exported = export(LinearModel(), sample_inputs)

    print("Lowering to edge...")
    edge = to_edge(exported)

    print("Creating ExecuTorch program...")
    return edge.to_executorch().buffer
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker
def _require(condition: bool, message: str) -> None:
    """Fail the smoke test with ``message`` unless ``condition`` is true.

    Raises:
        AssertionError: if ``condition`` is false. It is raised explicitly
            rather than through an ``assert`` statement so the check still
            runs when Python is started with ``-O`` (which strips asserts).
    """
    if not condition:
        raise AssertionError(message)


def main():
    """Tests the export and execution of a simple model.

    Checks that the pybindings report a non-empty operator list containing the
    expected custom and quantized ops, then exports ``LinearModel`` to .pte
    data. It writes that data to a temporary file, loads it back with
    ``portable_lib``, and runs it. The test expects one output of shape [2].

    Raises:
        AssertionError: if any check fails.
    """
    # If the pybindings loaded correctly, we should be able to ask for the set
    # of operators.
    ops = portable_lib._get_operator_names()
    _require(len(ops) > 0, "Empty operator list")
    print(f"Found {len(ops)} operators; first element '{ops[0]}'")

    # Make sure the custom ops and quantized ops are registered.
    required_ops = (
        "llama::sdpa_with_kv_cache.out",  # custom op
        "quantized_decomposed::add.out",  # quantized op
    )
    for op_name in required_ops:
        _require(op_name in ops, f"{op_name} not registered, Got ops: {ops}")

    # Export LinearModel to .pte data.
    pte_data: bytes = export_linear_model()

    # Try saving to and loading from a file.
    with tempfile.TemporaryDirectory() as tempdir:
        pte_file = os.path.join(tempdir, "linear.pte")

        # Save the .pte data to a file. Report only after the file is closed,
        # so its contents have been flushed.
        with open(pte_file, "wb") as file:
            file.write(pte_data)
        print(f"ExecuTorch program saved to {pte_file} ({len(pte_data)} bytes).")

        # Load the model from disk.
        m = portable_lib._load_for_executorch(pte_file)

        # Run the model.
        outputs = m.forward(linear_model_inputs())

        # Should see a single output with shape [2].
        _require(
            len(outputs) == 1, f"Unexpected output length {len(outputs)}: {outputs}"
        )
        _require(
            outputs[0].shape == (2,), f"Unexpected output size {outputs[0].shape}"
        )

    print("PASS")
118*523fa7a6SAndroid Build Coastguard Worker
119*523fa7a6SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()
122