# Intro to LLMs in ExecuTorch

Welcome to the LLM Manual! This manual provides a practical example of how to leverage
ExecuTorch when onboarding your own Large Language Models (LLMs). Our primary goal is to offer
a clear and concise guideline on how to integrate our system with your own LLMs.

Please note that this project is intended as a demonstration and not as a fully functional
example with optimal performance. As such, certain components such as the sampler, tokenizer,
and others are provided in their bare minimum versions solely for demonstration purposes.
Consequently, the results produced by the model may vary and might not always be optimal.

We encourage users to use this project as a starting point and adapt it to their specific needs,
including creating your own versions of the tokenizer, sampler, acceleration backends, and
other components. We hope this project serves as a useful guide in your journey with LLMs and ExecuTorch.

For deploying Llama with optimal performance, please see the [Llama guide](./llama.md).

### Table Of Contents

1. Prerequisites
2. Hello World Example
3. Quantization
4. Using Mobile Acceleration
5. Debugging and Profiling
6. How to use custom kernels
7. How to build mobile apps

## Prerequisites

To follow this guide, you'll need to clone the ExecuTorch repository and install dependencies.
ExecuTorch recommends Python 3.10 and the use of Conda to manage your environment. Conda is not
required, though be aware that you may need to replace the use of python/pip with python3/pip3
depending on your environment.

::::{tab-set}
:::{tab-item} conda
Instructions on installing miniconda can be [found here](https://docs.anaconda.com/free/miniconda).

```
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Create a conda environment and install requirements.
conda create -yn executorch python=3.10.0
conda activate executorch
./install_requirements.sh

cd ../..
```
:::
:::{tab-item} pyenv-virtualenv
Instructions on installing pyenv-virtualenv can be [found here](https://github.com/pyenv/pyenv-virtualenv?tab=readme-ov-file#installing-with-homebrew-for-macos-users).

Importantly, if installing pyenv through brew, it does not automatically enable pyenv in the terminal, leading to errors. Run the following commands to enable it.
See the pyenv-virtualenv installation guide above on how to add this to your .bashrc or .zshrc to avoid needing to run these commands manually.
```
eval "$(pyenv init -)"
eval "$(pyenv virtualenv-init -)"
```

```
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

pyenv install -s 3.10
pyenv virtualenv 3.10 executorch
pyenv activate executorch

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Install requirements.
PYTHON_EXECUTABLE=python ./install_requirements.sh

cd ../..
```
:::
::::

For more information, see [Setting Up ExecuTorch](../getting-started-setup.md).
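
Optionally, you can sanity-check the installation from Python before moving on. This short snippet is not part of the tutorial flow; it is a minimal sketch that only assumes the setup steps above completed successfully and that the `executorch.exir` export APIs (used later in this guide) are importable.

```python
# check_install.py -- optional sanity check for the environment set up above.
import torch
from executorch.exir import to_edge  # Used later in this guide for export.

print(f"PyTorch version: {torch.__version__}")
print("ExecuTorch export APIs imported successfully:", to_edge is not None)
```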

## Running a Large Language Model Locally

This example uses Karpathy’s [nanoGPT](https://github.com/karpathy/nanoGPT), which is a minimal implementation of
GPT-2 124M. This guide is applicable to other language models, as ExecuTorch is model-invariant.

There are two steps to running a model with ExecuTorch:

1. Export the model. This step preprocesses it into a format suitable for runtime execution.
2. At runtime, load the model file and run it with the ExecuTorch runtime.

<br />

The export step happens ahead of time, typically as part of the application build or when the model changes. The resultant
.pte file is distributed with the application. At runtime, the application loads the .pte file and passes it to the
ExecuTorch runtime.

### Step 1. Exporting to ExecuTorch

Exporting takes a PyTorch model and converts it into a format that can run efficiently on consumer devices.

For this example, you will need the nanoGPT model and the corresponding tokenizer vocabulary.

::::{tab-set}
:::{tab-item} curl
```
curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O
```
:::
:::{tab-item} wget
```
wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json
```
:::
::::

To convert the model into a format optimized for standalone execution, there are two steps. First, use the PyTorch
`export` function to convert the PyTorch model into an intermediate, platform-independent representation. Then
use the ExecuTorch `to_edge` and `to_executorch` methods to prepare the model for on-device execution. This creates a .pte
file which can be loaded by a desktop or mobile application at runtime.

Create a file called export_nanogpt.py with the following contents:

```python
# export_nanogpt.py

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export, export_for_training

from model import GPT

# Load the model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of the 0th model input's 1st dimension to
# [0, model.config.block_size].
# See https://pytorch.org/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model, compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)
```

To export, run the script with `python export_nanogpt.py` (or python3, as appropriate for your environment). It will generate a `nanogpt.pte` file in the current directory.

For more information, see [Exporting to ExecuTorch](../tutorials/export-to-executorch-tutorial) and
[torch.export](https://pytorch.org/docs/stable/export.html).

### Step 2. Invoking the Runtime

ExecuTorch provides a set of runtime APIs and types to load and run models.

Create a file called main.cpp with the following contents:

```cpp
// main.cpp

#include <cstdint>

#include "basic_sampler.h"
#include "basic_tokenizer.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::runtime::EValue;
using executorch::runtime::Result;
```

The model inputs and outputs take the form of tensors. A tensor can be thought of as a multi-dimensional array.
The ExecuTorch `EValue` class provides a wrapper around tensors and other ExecuTorch data types.

Since the LLM generates one token at a time, the driver code needs to repeatedly invoke the model, building the
output token by token. Each generated token is passed as input for the next run.

```cpp
// main.cpp

// The value of the gpt2 `<|endoftext|>` token.
#define ENDOFTEXT_TOKEN 50256

std::string generate(
    Module& llm_model,
    std::string& prompt,
    BasicTokenizer& tokenizer,
    BasicSampler& sampler,
    size_t max_input_length,
    size_t max_output_length) {
  // Convert the input text into a list of integers (tokens) that represents it,
  // using the string-to-token mapping that the model was trained on. Each token
  // is an integer that represents a word or part of a word.
  std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
  std::vector<int64_t> output_tokens;

  for (auto i = 0u; i < max_output_length; i++) {
    // Convert the input_tokens from a vector of int64_t to EValue. EValue is a
    // unified data type in the ExecuTorch runtime.
    auto inputs = from_blob(
        input_tokens.data(),
        {1, static_cast<int>(input_tokens.size())},
        ScalarType::Long);

    // Run the model. It will return a tensor of logits (log-probabilities).
    auto logits_evalue = llm_model.forward(inputs);

    // Convert the output logits from EValue to std::vector, which is what the
    // sampler expects.
    Tensor logits_tensor = logits_evalue.get()[0].toTensor();
    std::vector<float> logits(
        logits_tensor.data_ptr<float>(),
        logits_tensor.data_ptr<float>() + logits_tensor.numel());

    // Sample the next token from the logits.
    int64_t next_token = sampler.sample(logits);

    // Break if we reached the end of the text.
    if (next_token == ENDOFTEXT_TOKEN) {
      break;
    }

    // Add the next token to the output.
    output_tokens.push_back(next_token);

    std::cout << tokenizer.decode({next_token});
    std::cout.flush();

    // Update next input.
    input_tokens.push_back(next_token);
    if (input_tokens.size() > max_input_length) {
      input_tokens.erase(input_tokens.begin());
    }
  }

  std::cout << std::endl;

  // Convert the output tokens into a human-readable string.
  std::string output_string = tokenizer.decode(output_tokens);
  return output_string;
}
```

The `Module` class handles loading the .pte file and preparing for execution.

The tokenizer is responsible for converting from a human-readable string representation of the prompt to the
numerical form expected by the model. To do this, the tokenizer associates short substrings with a given token ID.
The tokens can be thought of as representing words or parts of words, though, in practice, they may be arbitrary
sequences of characters.

The tokenizer loads the vocabulary from a file, which contains the mapping between each token ID and the text it
represents. Call `tokenizer.encode()` and `tokenizer.decode()` to convert between string and token representations.

The sampler is responsible for selecting the next token, based on the logits, or log-probabilities, output by the
model. The LLM returns a logit value for each possible next token. The sampler chooses which token to use based
on some strategy. The simplest approach, used here, is to take the token with the highest logit value.

Samplers may provide configurable options, such as a configurable amount of randomness in output selection,
penalties for repeated tokens, and biases to prioritize or de-prioritize specific tokens.


```cpp
// main.cpp

int main() {
  // Set up the prompt. This provides the seed text for the model to elaborate.
  std::cout << "Enter model prompt: ";
  std::string prompt;
  std::getline(std::cin, prompt);

  // The tokenizer is used to convert between tokens (used by the model) and
  // human-readable strings.
  BasicTokenizer tokenizer("vocab.json");

  // The sampler is used to sample the next token from the logits.
  BasicSampler sampler = BasicSampler();

  // Load the exported nanoGPT program, which was generated via the previous
  // steps.
  Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);

  const auto max_input_tokens = 1024;
  const auto max_output_tokens = 30;
  std::cout << prompt;
  generate(
      model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
}
```

Finally, download the following files into the same directory as main.cpp:

```
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_tokenizer.h
```

To learn more, see the [Runtime APIs Tutorial](../extension-module.md).
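
For intuition, the greedy strategy described above simply picks the index of the largest logit. The following is a minimal Python sketch of that idea, for illustration only; the actual runner uses the C++ `BasicSampler` downloaded above.

```python
import torch

def greedy_sample(logits: torch.Tensor) -> int:
    # logits: a 1-D tensor with one score per vocabulary entry.
    # Greedy sampling returns the token ID with the highest score.
    return int(torch.argmax(logits))

# Example with a toy vocabulary of 5 tokens: index 1 has the highest logit.
print(greedy_sample(torch.tensor([0.1, 2.3, -1.0, 0.7, 1.9])))  # -> 1
```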

### Building and Running

ExecuTorch uses the CMake build system. To compile and link against the ExecuTorch runtime,
include the ExecuTorch project via `add_subdirectory` and link against `executorch` and additional
dependencies.

Create a file named CMakeLists.txt with the following content:

```
# CMakeLists.txt

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
)
```

At this point, the working directory should contain the following files:

- CMakeLists.txt
- main.cpp
- basic_tokenizer.h
- basic_sampler.h
- export_nanogpt.py
- model.py
- vocab.json
- nanogpt.pte

If all of these are present, you can now build and run:
```bash
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
```

You should see the message:

```
Enter model prompt:
```

Type some seed text for the model and press enter. Here we use "Hello world!" as
an example prompt:

```
Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in
```

At this point, it is likely to run very slowly. This is because ExecuTorch hasn't been told to optimize for
specific hardware (delegation), and because it is doing all of the calculations in 32-bit floating point (no quantization).

## Delegation

While ExecuTorch provides a portable, cross-platform implementation for all
operators, it also provides specialized backends for a number of different
targets. These include, but are not limited to, x86 and ARM CPU acceleration via
the XNNPACK backend, Apple acceleration via the Core ML backend and Metal
Performance Shader (MPS) backend, and GPU acceleration via the Vulkan backend.

Because optimizations are specific to a given backend, each .pte file is specific
to the backend(s) targeted at export. To support multiple devices, such as
XNNPACK acceleration for Android and Core ML for iOS, export a separate .pte file
for each backend.

To delegate to a backend at export time, ExecuTorch provides the `to_backend()`
function on the `EdgeProgramManager` object, which takes a backend-specific
partitioner object. The partitioner is responsible for finding the parts of the
computation graph that can be accelerated by the target backend, and the
`to_backend()` function delegates each matched part to that backend for
acceleration and optimization. Any portions of the computation graph not
delegated will be executed by the ExecuTorch operator implementations.
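
As a rough sketch of that flow (the complete, working XNNPACK example follows below), lowering always follows the same pattern. Here `SomeBackendPartitioner` is a placeholder for whichever backend partitioner you target, such as the `XnnpackPartitioner` used in the next listing.

```python
# Sketch only: SomeBackendPartitioner is a placeholder for a real partitioner
# such as XnnpackPartitioner; see the complete XNNPACK example below.
edge_manager = to_edge(traced_model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(SomeBackendPartitioner())
et_program = edge_manager.to_executorch()

# Each targeted backend gets its own .pte file.
with open("nanogpt_backend.pte", "wb") as file:
    file.write(et_program.buffer)
```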

To delegate the exported model to a specific backend, first import its
partitioner and backend-specific edge compile config from the ExecuTorch codebase, then
call `to_backend` with an instance of the partitioner on the `EdgeProgramManager`
object created by the `to_edge` function.

Here's an example of how to delegate nanoGPT to XNNPACK (if you're deploying to an Android phone, for instance):

```python
# export_nanogpt.py

# Load the partitioner for the XNNPACK backend.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# A model delegated to a specific backend should use that backend's edge compile config.
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge

import torch
from torch.export import export, export_for_training
from torch.nn.attention import sdpa_kernel, SDPBackend

from model import GPT

# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
    torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
)

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of the 0th model input's 1st dimension to
# [0, model.config.block_size - 1].
# See https://pytorch.org/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be further lowered to the XNNPACK backend, `traced_model` needs the
# XNNPACK-specific edge compile config.
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate the exported model to the XNNPACK backend by invoking `to_backend`
# with the XNNPACK partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Save the XNNPACK-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)
```

Additionally, update CMakeLists.txt to build and link the XNNPACK backend to the
ExecuTorch runner.

```
cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with the XNNPACK backend

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
          xnnpack_backend # Provides the XNNPACK CPU acceleration backend
)
```

Keep the rest of the code the same. For more details, refer to [Exporting
to ExecuTorch](#step-1-exporting-to-executorch) and [Invoking the
Runtime](#step-2-invoking-the-runtime).

At this point, the working directory should contain the following files:

- CMakeLists.txt
- main.cpp
- basic_tokenizer.h
- basic_sampler.h
- export_nanogpt.py
- model.py
- vocab.json

If all of these are present, you can now export the XNNPACK-delegated .pte model:
```bash
python export_nanogpt.py
```

It will generate `nanogpt.pte` in the same working directory.

Then we can build and run the model with:
```bash
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
```

You should see the message:

```
Enter model prompt:
```

Type some seed text for the model and press enter. Here we use "Hello world!" as
an example prompt:

```
Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in
```

The delegated model should be noticeably faster compared to the non-delegated model.

For more information regarding backend delegation, see the ExecuTorch guides
for the [XNNPACK Backend](../tutorial-xnnpack-delegate-lowering.md), [Core ML
Backend](../build-run-coreml.md) and [Qualcomm AI Engine Direct Backend](build-run-llama3-qualcomm-ai-engine-direct-backend.md).

## Quantization

Quantization refers to a set of techniques for running calculations and storing tensors using lower precision types.
Compared to 32-bit floating point, using 8-bit integers can provide both a significant speedup and a reduction in
memory usage. There are many approaches to quantizing a model, varying in the amount of pre-processing required, data
types used, and impact on model accuracy and performance.

Because compute and memory are highly constrained on mobile devices, some form of quantization is necessary to ship
large models on consumer electronics. In particular, large language models, such as Llama2, may require quantizing
model weights to 4 bits or less.

Leveraging quantization requires transforming the model before export. PyTorch provides the pt2e (PyTorch 2 Export)
API for this purpose. This example targets CPU acceleration using the XNNPACK delegate. As such, it needs to use the
XNNPACK-specific quantizer. Targeting a different backend will require use of the corresponding quantizer.
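
For intuition on what 8-bit quantization does, the sketch below (illustrative only, not part of the export flow) maps a float tensor onto int8 with a scale and zero point and then maps it back, showing the small rounding error that is traded for reduced memory and faster integer arithmetic.

```python
import torch

# A toy float32 tensor to quantize.
x = torch.randn(4)

# Affine mapping from the observed float range onto int8 ([-128, 127]).
scale = (x.max() - x.min()) / 255.0
zero_point = torch.round(-x.min() / scale) - 128

# Quantize: float -> int8, then dequantize: int8 -> float.
x_q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)
x_dq = (x_q.to(torch.float32) - zero_point) * scale

print("original:     ", x)
print("dequantized:  ", x_dq)
print("max abs error:", (x - x_dq).abs().max().item())
```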

To use 8-bit integer dynamic quantization with the XNNPACK delegate, call `prepare_pt2e`, calibrate the model by
running it with representative inputs, and then call `convert_pt2e`. This updates the computational graph to use
quantized operators where available.

```python
# export_nanogpt.py

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
```

```python
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = export_for_training(model, example_inputs).module()

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)

# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)

traced_model = export(m, example_inputs)
```

Additionally, add or update the `to_backend()` call to use `XnnpackPartitioner`. This instructs ExecuTorch to
optimize the model for CPU execution via the XNNPACK backend.

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
```

```python
edge_manager = to_edge(traced_model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(XnnpackPartitioner())  # Lower to XNNPACK.
et_program = edge_manager.to_executorch()
```

Finally, ensure that the runner links against the `xnnpack_backend` target in CMakeLists.txt.

```
add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE
  executorch
  extension_module_static # Provides the Module class
  optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
  xnnpack_backend) # Provides the XNNPACK CPU acceleration backend
```

For more information, see [Quantization in ExecuTorch](../quantization-overview.md).

## Profiling and Debugging

After lowering a model by calling `to_backend()`, you may want to see what got delegated and what didn’t. ExecuTorch
provides utility methods to give insight into the delegation. You can use this information to gain visibility into
the underlying computation and diagnose potential performance issues. Model authors can use this information to
structure the model in a way that is compatible with the target backend.

### Visualizing the Delegation

The `get_delegation_info()` method provides a summary of what happened to the model after the `to_backend()` call:

```python
from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

# ... After call to to_backend(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
```

For nanoGPT targeting the XNNPACK backend, you might see the following (note that the numbers below are for illustration purposes only and actual values may vary):
```
Total delegated subgraphs: 145
Number of delegated nodes: 350
Number of non-delegated nodes: 760
```

|    | op_type                | # in_delegated_graphs | # in_non_delegated_graphs |
|----|------------------------|-----------------------|---------------------------|
| 0  | aten__softmax_default  | 12                    | 0                         |
| 1  | aten_add_tensor        | 37                    | 0                         |
| 2  | aten_addmm_default     | 48                    | 0                         |
| 3  | aten_any_dim           | 0                     | 12                        |
|    | ...                    |                       |                           |
| 25 | aten_view_copy_default | 96                    | 122                       |
|    | ...                    |                       |                           |
| 30 | Total                  | 350                   | 760                       |

From the table, the operator `aten_view_copy_default` appears 96 times in delegated graphs and 122 times in non-delegated graphs.
To see a more detailed view, use the `format_delegated_graph()` method to get a formatted string printout of the whole graph, or use `print_delegated_graph()` to print it directly:

```python
from executorch.exir.backend.utils import format_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(format_delegated_graph(graph_module))
```
This may generate a large amount of output for large models. Consider using "Control+F" or "Command+F" to locate the operator you’re interested in
(e.g. “aten_view_copy_default”). Observe which instances are not under lowered graphs.

In the fragment of the output for nanoGPT below, observe that a transformer module has been delegated to XNNPACK while the where operator is not.

```
%aten_where_self_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default_33, %scalar_tensor_23, %scalar_tensor_22), kwargs = {})
%lowered_module_144 : [num_users=1] = get_attr[target=lowered_module_144]
backend_id: XnnpackBackend
lowered graph():
    %p_transformer_h_0_attn_c_attn_weight : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_weight]
    %p_transformer_h_0_attn_c_attn_bias : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_bias]
    %getitem : [num_users=1] = placeholder[target=getitem]
    %sym_size : [num_users=2] = placeholder[target=sym_size]
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [%sym_size, 768]), kwargs = {})
    %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_transformer_h_0_attn_c_attn_weight, [1, 0]), kwargs = {})
    %aten_addmm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.addmm.default](args = (%p_transformer_h_0_attn_c_attn_bias, %aten_view_copy_default, %aten_permute_copy_default), kwargs = {})
    %aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_addmm_default, [1, %sym_size, 2304]), kwargs = {})
    return [aten_view_copy_default_1]
```

### Performance Analysis

Through the ExecuTorch Developer Tools, users are able to profile model execution, giving timing information for each operator in the model.

#### Prerequisites

##### ETRecord generation (Optional)

An ETRecord is an artifact generated at the time of export that contains model graphs and source-level metadata linking the ExecuTorch program to the original PyTorch model. You can view all profiling events without an ETRecord, though with an ETRecord, you will also be able to link each event to the types of operators being executed, module hierarchy, and stack traces of the original PyTorch source code. For more information, see [the ETRecord docs](../etrecord.md).

In your export script, after calling `to_edge()` and `to_executorch()`, call `generate_etrecord()` with the `EdgeProgramManager` from `to_edge()` and the `ExecuTorchProgramManager` from `to_executorch()`. Make sure to copy the `EdgeProgramManager`, as the call to `to_backend()` mutates the graph in-place.

```
# export_nanogpt.py

import copy
from executorch.devtools import generate_etrecord

# Make the deep copy immediately after the call to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)

# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)
```

Run the export script and the ETRecord will be generated as `etrecord.bin`.

##### ETDump generation

An ETDump is an artifact generated at runtime containing a trace of the model execution. For more information, see [the ETDump docs](../etdump.md).

Include the ETDump header in your code.
```cpp
// main.cpp

#include <executorch/devtools/etdump/etdump_flatcc.h>
```

Create an instance of the ETDumpGen class and pass it to the Module constructor.
```cpp
std::unique_ptr<ETDumpGen> etdump_gen_ = std::make_unique<ETDumpGen>();
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));
```

After calling `generate()`, save the ETDump to a file. You can capture multiple
model runs in a single trace, if desired.
```cpp
ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());

ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
etdump_result result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
  // On a device with a file system, users can just write it to a file.
  FILE* f = fopen("etdump.etdp", "w+");
  fwrite((uint8_t*)result.buf, 1, result.size, f);
  fclose(f);
  free(result.buf);
}
```

Additionally, update CMakeLists.txt to build with Developer Tools and enable events to be traced and logged into ETDump:

```
option(EXECUTORCH_ENABLE_EVENT_TRACER "" ON)
option(EXECUTORCH_BUILD_DEVTOOLS "" ON)

# ...

target_link_libraries(
    # ... omit existing ones
    etdump) # Provides event tracing and logging

target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)
```
Build and run the runner, and you will see a file named “etdump.etdp” generated. (Note that this time we build in release mode to get around a flatccrt build limitation.)
```bash
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake -DCMAKE_BUILD_TYPE=Release ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
```

#### Analyze with Inspector APIs

Once you’ve collected the ETDump debug artifact (and optionally an ETRecord), you can use the Inspector API to view performance information.

```python
from executorch.devtools import Inspector

inspector = Inspector(etdump_path="etdump.etdp")
# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`

with open("inspector_out.txt", "w") as file:
    inspector.print_data_tabular(file)
```
This prints the performance data in a tabular format in “inspector_out.txt”, with each row being a profiling event. The top rows look like this:

<a href="../_static/img/llm_manual_print_data_tabular.png" target="_blank">View in full size</a>

To learn more about the Inspector and the rich functionality it provides, see the [Inspector API Reference](../model-inspector.md).

## Custom Kernels

With the ExecuTorch custom operator APIs, custom operator and kernel authors can easily bring their kernels into PyTorch/ExecuTorch.

There are three steps to use custom kernels in ExecuTorch:

1. Write the custom kernel using ExecuTorch types.
2. Compile and link the custom kernel to both the ahead-of-time (AOT) Python environment and the runtime binary.
3. Apply a source-to-source transformation to swap an operator with the custom op.

### Writing a Custom Kernel

Define your custom operator schema for both the functional variant (used in AOT compilation) and the out variant (used in the ExecuTorch runtime). The schema needs to follow PyTorch ATen convention (see [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)).

```
custom_linear(Tensor weight, Tensor input, Tensor? bias) -> Tensor

custom_linear.out(Tensor weight, Tensor input, Tensor? bias, *, Tensor(a!) out) -> Tensor(a!)
```

Write your custom kernel according to the schema defined above. Use the `EXECUTORCH_LIBRARY` macro to make the kernel available to the ExecuTorch runtime.

```cpp
// custom_linear.h / custom_linear.cpp
#include <executorch/runtime/kernel/kernel_includes.h>

Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) {
  // calculation
  return out;
}

// Register as myop::custom_linear.out
EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out);
```

To make this operator available in PyTorch, you can define a wrapper around the ExecuTorch custom kernel. Note that the ExecuTorch
implementation uses ExecuTorch tensor types, while the PyTorch wrapper uses ATen tensors.

```cpp
// custom_linear_pytorch.cpp

#include "custom_linear.h"
#include <torch/library.h>

at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) {
  // initialize out
  at::Tensor out = at::empty({weight.size(1), input.size(1)});

  // wrap the kernel in custom_linear.cpp into an ATen kernel
  WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out);

  return out;
}

// Register the operator with PyTorch.
TORCH_LIBRARY(myop, m) {
  m.def("custom_linear(Tensor weight, Tensor input, Tensor? bias) -> Tensor", custom_linear);
  m.def("custom_linear.out(Tensor weight, Tensor input, Tensor? bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3));
}
```

### Compile and Link the Custom Kernel

To make it available to the ExecuTorch runtime, compile custom_linear.h/cpp into the binary target. You can also build the kernel as a dynamically loaded library (.so or .dylib) and link it as well.

To make it available to PyTorch, package custom_linear.h, custom_linear.cpp and custom_linear_pytorch.cpp into a dynamically loaded library (.so or .dylib) and load it into the Python environment.
This is needed to make PyTorch aware of the custom operator at the time of export.

```python
import torch
torch.ops.load_library("libcustom_linear.so")
```

Once loaded, you can use the custom operator in PyTorch code.

For more information, see [PyTorch Custom Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html)
and [ExecuTorch Kernel Registration](../kernel-library-custom-aten-kernel.md).

### Using a Custom Operator in a Model

The custom operator can be used explicitly in the PyTorch model, or you can write a transformation to replace instances of a core operator with the custom variant. For this example, you could find
all instances of `torch.nn.Linear` and replace them with `CustomLinear`.

```python
def replace_linear_with_custom_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                CustomLinear(child.in_features, child.out_features, child.bias),
            )
        else:
            replace_linear_with_custom_linear(child)
```

The remaining steps are the same as the normal flow. Now you can run this module in eager mode as well as export it to ExecuTorch.

## How to Build Mobile Apps

See the instructions for building and running LLMs using ExecuTorch on iOS and Android.

* **[iOS ExecuTorch LLaMA Demo App](llama-demo-ios.md)**
* **[Android ExecuTorch LLaMA Demo App](llama-demo-android.md)**