# ExecuTorch On-device Training

This subtree contains infrastructure to facilitate on-device training using ExecuTorch.
This feature is experimental and under heavy active development. All of the APIs are
subject to change, and many things may not work out of the box, or at all, in the
current state.

## Layout
- `examples/`: Example end-to-end flows, from model definition to `optimizer.step()`.
- `module/`: Utility class to provide an improved UX when using ExecuTorch for training.
- `optimizer/`: C++ implementations of various optimizers; currently only SGD, though Adam is planned.
- `test/`: Tests that cover multiple subdirectories.

## Technical Bird's-Eye View

At a high level, ExecuTorch training follows a similar flow to inference, with a few extra steps.

Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,
we capture the backward graph ahead of time. This lets us be much leaner on-device and
gives backends more direct control over more of the model execution. Currently the optimizer is not
captured, though this may change over time.
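As a rough, minimal sketch of the difference (the full XOR walkthrough below shows the real flow; `TinyModelWithLoss` here is a hypothetical stand-in):

```python
import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward


class TinyModelWithLoss(nn.Module):
    """Hypothetical minimal model that returns its loss as the first output."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, label):
        return self.loss(self.linear(x), label)


model = TinyModelWithLoss()
x, label = torch.randn(1, 4), torch.zeros(1, dtype=torch.int64)

# Eager PyTorch builds and walks the backward graph dynamically at runtime:
model(x, label).backward()

# ExecuTorch training instead captures forward and backward once, ahead of time:
joint_ep = _export_forward_backward(export(model, (x, label)))
```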

The loss function must be embedded inside the model definition (and be the first output); it is used during
capture to generate the backward graph.

Gradients become explicit graph outputs rather than hidden tensor state.

Since the weights now need to be mutable during execution, they are memory-planned ahead of time and copied
from the .pte into the HierarchicalAllocator arenas during Method init.

Integration with backends/delegates is still a work in progress.

## End-to-End Example

To further understand the features of ExecuTorch training and how to leverage them,
consider the following end-to-end example of a neural network learning the XOR function.

### Lowering a joint-graph model to ExecuTorch

After following the [setting up ExecuTorch] guide, you can run

```bash
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar
```
to generate the model file. Below is a walkthrough of how that script works.

First, let's define our model.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.export import export
from torch.export.experimental import _export_forward_backward

from executorch.exir import to_edge


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))
```
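
For reference, the XOR truth table this network is meant to learn can be written as a tiny dataset (illustrative only; the actual `export_model.py` script generates its own training data):

```python
# XOR truth table as tensors (illustrative; not taken from export_model.py).
X = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
Y = torch.tensor([0, 1, 1, 0])  # label is the XOR of the two input bits
```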

The first big difference from the normal ExecuTorch flow is that for training we must embed
the loss function into the model and return the loss as our first output.

We don't want to modify the original model definition, so we will just wrap it.

```python
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
```
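
As a quick, purely illustrative eager-mode check (not part of the export script), running the wrapped model shows that it now returns the loss first, followed by the detached prediction:

```python
# Illustrative eager-mode check: the wrapper returns (loss, predicted class).
net = TrainingNet(Net())
x = torch.tensor([[1.0, 0.0]])
y = torch.tensor([1])
loss, pred = net(x, y)
print(loss.item(), pred)  # a scalar loss and a tensor of predicted class indices
```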

Now that we have our model, we can lower it to ExecuTorch. To do that we just have to follow
a few simple steps.

```python
net = TrainingNet(Net())

# Create our inputs; only the shapes of these matter.
input = torch.randn(1, 2)
label = torch.ones(1, dtype=torch.int64)

# Captures the forward graph. The graph will look similar to the model definition now.
# This will move to export_for_training soon, which is the API planned to be supported long term.
ep = export(net, (input, label))
```
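
At this point the exported program can still be run eagerly through its `.module()` wrapper, which is an optional way to sanity-check that capture preserved behavior (a small sketch, not part of the export script):

```python
# Optional sanity check: execute the captured graph eagerly and compare
# against the original module's output.
loss_exported, pred_exported = ep.module()(input, label)
loss_eager, pred_eager = net(input, label)
assert torch.allclose(loss_exported, loss_eager)
```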

This is what the graph looks like after export:
```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=1] = placeholder[target=input]
    %label : [num_users=1] = placeholder[target=label]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})
    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})
    return (cross_entropy_loss, argmax)
```

It should look pretty similar to our model's forward function. Now we need to capture the backward graph.

```python
ep = _export_forward_backward(ep)
```

and now the graph is:

```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=2] = placeholder[target=input]
    %label : [num_users=5] = placeholder[target=label]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})
    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})
    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})
    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})
    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})
    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})
    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})
    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})
    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})
    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})
    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})
    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})
    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})
    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})
    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})
    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})
    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})
    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})
    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})
    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})
    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})
    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})
    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})
    return (div, argmax, permute_8, view_1, permute_5, view)
```

It's a lot bigger! We call this the 'joint graph' or the 'forward-backward graph'. We have explicitly captured the backward graph
alongside the forward one, and now our model returns [Loss, Any other user outputs, Gradients].
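
If you want to see how these outputs are categorized programmatically, one way (a hedged sketch; the attribute names come from `torch.export`'s graph signature and may change across versions) is to inspect the exported program's output specs:

```python
# Each output spec records whether an output is the loss, a user output,
# or a gradient for a particular parameter.
for spec in ep.graph_signature.output_specs:
    print(spec.kind, spec.arg)
```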

From here we can lower the rest of the way to ExecuTorch:
```python
ep = to_edge(ep)

# After calling to_executorch the weights themselves are also appended to the model outputs. This is to make
# some downstream passes like memory planning a little easier. A few hidden utility functions are also
# embedded in the model: __et_training_gradients_index_<method_name>,
# __et_training_parameters_index_<method_name>, and __et_training_fqn_<method_name>.
#
# These help us partition the huge list of model outputs into meaningful sections as well as assign names to each weight/gradient.
ep = ep.to_executorch()

with open("xor.pte", "wb") as file:
    ep.write_to_file(file)
```
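
Before moving on to the C++ runner below, you can optionally smoke-test the generated `.pte` from Python. This is a hedged sketch: it uses the generic `portable_lib` pybindings rather than the training-specific `pybindings/` in this subtree, and assumes the operators needed by the backward graph are present in your build.

```python
# Illustrative smoke test of the exported joint graph (not from export_model.py).
from executorch.extension.pybindings.portable_lib import _load_for_executorch

m = _load_for_executorch("xor.pte")
# Outputs are ordered as [loss, other user outputs, gradients..., weights...].
outputs = m.forward([torch.randn(1, 2), torch.ones(1, dtype=torch.int64)])
print(outputs[0])  # the loss
```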

### Run the model training script with CMake
After exporting the model for training, we can now try training using CMake. We can build and use `train_xor`, which is a sample wrapper around the ExecuTorch runtime, TrainingModule, and SGD optimizer. We first begin by configuring the CMake build like so:
```bash
# cd to the root of executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable built at `./cmake-out/extension/training/train_xor`. You can run it with the model you generated like so:
```bash
./cmake-out/extension/training/train_xor --model_path=./xor.pte
```

## What is missing? What is next?
A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.

The `_export_forward_backward` pass is not very stable yet and may fail on more complicated model architectures, though we have verified it works for LoRA with LLMs.

The ExecuTorch portable operator library does not yet have full coverage of the ops that might show up in backward graphs.

We don't yet have a way to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using `extension/aten_util` and then serialize them using ATen APIs).

We plan to add a way to update models in place on-device (this will be needed for fine-tuning).

We are looking to integrate with many of the existing delegates/backends in ExecuTorch to enable accelerated training.

And so much more!

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [GitHub](https://www.github.com/pytorch/executorch/issues).