#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This script is run by CI after building the executorch wheel. Before running
this, the job will install the matching torch package as well as the newly-built
executorch package and its dependencies.
"""

# Import this first. If it can't find the torch.so libraries, the dynamic load
# will fail and the process will exit.
from executorch.extension.pybindings import portable_lib  # usort: skip

# Import custom ops. This requires portable_lib to be loaded first.
from executorch.extension.llm.custom_ops import (  # noqa: F401, F403
    sdpa_with_kv_cache,
)  # usort: skip

# Import quantized ops. This requires portable_lib to be loaded first.
from executorch.kernels import quantized  # usort: skip # noqa: F401, F403

# Import this after importing the ExecuTorch pybindings. If the pybindings
# link against a different torch.so than this uses, there will be a set of
# symbol conflicts; the process will either exit now, or there will be issues
# later in the smoke test.
import torch  # usort: skip

# Import everything else later to help isolate the critical imports above.
import os
import tempfile
from typing import Tuple

from executorch.exir import to_edge
from torch.export import export


class LinearModel(torch.nn.Module):
    """Tiny test model: a single ``torch.nn.Linear(4, 2)`` layer.

    Runs Linear on its input, which should have shape [4], producing an
    output of shape [2].
    """

    def __init__(self):
        super().__init__()
        # 4 input features -> 2 output features.
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the linear layer.

        Args:
            x: A single tensor of shape [4].

        Returns:
            The layer output, of shape [2] for a [4]-shaped input.
        """
        return self.linear(x)
def linear_model_inputs() -> Tuple[torch.Tensor]:
    """Returns some example inputs compatible with LinearModel."""
    # The model takes a single tensor of shape [4] as an input.
    return (torch.ones(4),)


def export_linear_model() -> bytes:
    """Exports LinearModel and returns the serialized .pte data.

    Runs the pipeline torch.export -> to_edge -> to_executorch, printing
    progress before each stage so a CI log shows which stage failed.

    Returns:
        The ExecuTorch program buffer (.pte file contents).
    """

    # This helps the exporter understand the shapes of tensors used in the model.
    # Since our model only takes one input, this is a one-tuple.
    example_inputs = linear_model_inputs()

    # Export the pytorch model and process for ExecuTorch.
    print("Exporting program...")
    exported_program = export(LinearModel(), example_inputs)
    print("Lowering to edge...")
    edge_program = to_edge(exported_program)
    print("Creating ExecuTorch program...")
    et_program = edge_program.to_executorch()

    return et_program.buffer


def main():
    """Tests the export and execution of a simple model.

    Checks that the pybindings expose operators (including the llama custom
    op and a quantized op), exports LinearModel, round-trips the .pte data
    through a temporary file, loads it with the pybindings and runs it.

    Raises:
        AssertionError: If any check fails. The checks use explicit ``raise``
            instead of ``assert`` so they still run when Python is started
            with ``-O`` (which strips ``assert`` statements).
    """

    # If the pybindings loaded correctly, we should be able to ask for the set
    # of operators.
    ops = portable_lib._get_operator_names()
    if len(ops) == 0:
        raise AssertionError("Empty operator list")
    print(f"Found {len(ops)} operators; first element '{ops[0]}'")

    # Make sure custom ops are registered.
    if "llama::sdpa_with_kv_cache.out" not in ops:
        raise AssertionError(
            f"llama::sdpa_with_kv_cache.out not registered, Got ops: {ops}"
        )

    # Make sure quantized ops are registered.
    if "quantized_decomposed::add.out" not in ops:
        raise AssertionError(
            f"quantized_decomposed::add.out not registered, Got ops: {ops}"
        )

    # Export LinearModel to .pte data.
    pte_data: bytes = export_linear_model()

    # Try saving to and loading from a file.
    with tempfile.TemporaryDirectory() as tempdir:
        pte_file = os.path.join(tempdir, "linear.pte")

        # Save the .pte data to a file.
        with open(pte_file, "wb") as file:
            file.write(pte_data)
        print(f"ExecuTorch program saved to {pte_file} ({len(pte_data)} bytes).")

        # Load the model from disk.
        m = portable_lib._load_for_executorch(pte_file)

        # Run the model.
        outputs = m.forward(linear_model_inputs())

        # Should see a single output with shape [2].
        if len(outputs) != 1:
            raise AssertionError(
                f"Unexpected output length {len(outputs)}: {outputs}"
            )
        if outputs[0].shape != (2,):
            raise AssertionError(f"Unexpected output size {outputs[0].shape}")

    print("PASS")


if __name__ == "__main__":
    main()