# ExecuTorch XNNPACK Delegate

This subtree contains the XNNPACK Delegate implementation for ExecuTorch.
XNNPACK is an optimized library of neural network inference operators for ARM
and x86 CPUs. It is an open source project used by PyTorch. The delegate is the
mechanism for leveraging the XNNPACK library to accelerate operators running on
CPU.

## Layout
- `cmake/`: CMake-related files
- `operators/`: Directory storing all of the op visitors
    - `node_visitor.py`: Implementation of serializing each lowerable operator
      node
    - ...
- `partition/`: Partitioner used to identify operators in a model's graph that
  are suitable for lowering to the XNNPACK delegate
    - `xnnpack_partitioner.py`: Contains the partitioner that tags graph patterns
      for XNNPACK lowering
    - `configs.py`: Contains lists of ops/modules for XNNPACK lowering
- `passes/`: Contains passes which are run before preprocessing to prepare the
  graph for XNNPACK lowering
- `runtime/`: Runtime logic used at inference. This contains all the C++ files
  used to build the runtime graph and execute the XNNPACK model
- `serialization/`: Contains files related to serializing the XNNPACK graph
  representation of the PyTorch model
    - `schema.fbs`: Flatbuffer schema of the serialization format
    - `xnnpack_graph_schema.py`: Python dataclasses mirroring the flatbuffer
      schema
    - `xnnpack_graph_serialize`: Implementation for serializing dataclasses
      from the graph schema to flatbuffer
- `test/`: Tests for the XNNPACK Delegate
- `third-party/`: Third-party libraries used by the XNNPACK Delegate
- `xnnpack_preprocess.py`: Contains the preprocess implementation, which is called
  by `to_backend` on the graph or subgraph of a model and returns a preprocessed
  blob responsible for executing the graph or subgraph at runtime

## End to End Example

To further understand the features of the XNNPACK Delegate and how to use it, consider the following end-to-end example with MobileNetV2.

### Lowering a model to XNNPACK
```python
import torch
import torchvision.models as models

from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from executorch.exir.backend.backend_api import to_backend


mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge(exported_program)

edge = edge.to_backend(XnnpackPartitioner())
```

We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model and converting it `to_edge`. We then call the `to_backend` API with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for the XNNPACK backend delegate to consume. Afterwards, the identified subgraphs are serialized with the XNNPACK Delegate flatbuffer schema, and each subgraph is replaced with a call to the XNNPACK Delegate.
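
Newer ExecuTorch releases also expose a combined entry point that performs the edge conversion and partitioning in a single call. The sketch below is equivalent to the two-step flow above, assuming `to_edge_transform_and_lower` is available in your installed version:

```python
# Hedged sketch: assumes executorch.exir.to_edge_transform_and_lower exists in
# your ExecuTorch version; it folds to_edge and to_backend into one call.
from executorch.exir import to_edge_transform_and_lower

edge = to_edge_transform_and_lower(
    export(mobilenet_v2, sample_inputs),
    partitioner=[XnnpackPartitioner()],
)
```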

```python
>>> print(edge.exported_program().graph_module)
GraphModule(
  (lowered_module_0): LoweredBackendModule()
  (lowered_module_1): LoweredBackendModule()
)

def forward(self, arg314_1):
    lowered_module_0 = self.lowered_module_0
    executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg314_1);  lowered_module_0 = arg314_1 = None
    getitem = executorch_call_delegate[0];  executorch_call_delegate = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]);  getitem = None
    aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default);  aten_view_copy_default = None
    lowered_module_1 = self.lowered_module_1
    executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_clone_default);  lowered_module_1 = aten_clone_default = None
    getitem_1 = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
    return (getitem_1,)
```

We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were delegated to XNNPACK. We can also see the operators which could not be lowered to the XNNPACK delegate, such as `clone` and `view_copy`.
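
Rather than reading the printed graph, you can also inspect the lowered program programmatically. The following sketch (an illustrative helper, not part of the delegate's API) walks the graph with standard `torch.fx` accessors and counts the delegated subgraphs and the edge operators that remain outside the delegate:

```python
from collections import Counter

graph_module = edge.exported_program().graph_module

delegate_calls = 0
remaining_ops = Counter()
for node in graph_module.graph.nodes:
    if node.op != "call_function":
        continue
    if node.target == torch.ops.higher_order.executorch_call_delegate:
        delegate_calls += 1  # one call per subgraph lowered to XNNPACK
    else:
        # Non-delegated calls, e.g. the view_copy and clone ops seen above,
        # plus trivial getitem accessors on the delegate outputs.
        remaining_ops[str(node.target)] += 1

print(f"delegated subgraphs: {delegate_calls}")
print(remaining_ops)
```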

```python
exec_prog = edge.to_executorch()

with open("xnnpack_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)
```
After lowering to XNNPACK, we can then prepare the program for the ExecuTorch runtime and save the model as a `.pte` file. `.pte` is a binary format that stores the serialized ExecuTorch graph.
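
Before moving on to a native build, you can optionally smoke-test the saved file from Python. This is a minimal sketch that assumes the optional ExecuTorch Python bindings (`executorch.extension.pybindings.portable_lib`) were built with the XNNPACK backend enabled; if they were not, skip ahead to the CMake runner below:

```python
# Hedged sketch: requires the ExecuTorch pybindings built with XNNPACK support.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_module = _load_for_executorch("xnnpack_mobilenetv2.pte")
outputs = et_module.forward(list(sample_inputs))  # sample_inputs from the export step above
print(outputs[0].shape)  # MobileNetV2 logits, expected shape [1, 1000]
```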


### Running the XNNPACK Model with CMake
After exporting the XNNPACK-delegated model, we can now try running it with example inputs using CMake. We can build and use the `xnn_executor_runner`, which is a sample wrapper for the ExecuTorch Runtime and XNNPACK Backend. We first begin by configuring the CMake build as follows:
```bash
# cd to the root of executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with:

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable built at `./cmake-out/backends/xnnpack/xnn_executor_runner`. You can run the executable with the model you generated earlier:
```bash
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./xnnpack_mobilenetv2.pte
```

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [github](https://www.github.com/pytorch/executorch/issues).


## See Also
For more information about the XNNPACK Delegate, please check out the following resources:
- [ExecuTorch XNNPACK Delegate](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html)
- [Building and Running ExecuTorch with XNNPACK Backend](https://pytorch.org/executorch/0.2/native-delegates-executorch-xnnpack-delegate.html)