xref: /aosp_15_r20/external/executorch/extension/training/README.md (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# ExecuTorch On-device Training

This subtree contains infrastructure to facilitate on-device training using ExecuTorch.
This feature is experimental and under heavy active development. All APIs are
subject to change, and many things may not work out of the box, or at all, in the
current state.

## Layout
- `examples/`: Example end-to-end flows from model definition to `optimizer.step()`.
- `module/`: Utility class that provides an improved UX when using ExecuTorch for training.
- `optimizer/`: C++ implementations of various optimizers; currently only SGD, though Adam is planned.
- `test/`: Tests that cover multiple subdirectories.

## Technical Bird's-Eye View

At a high level, ExecuTorch training follows a similar flow to inference, with a few extra steps.

Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,
we capture the backward graph ahead of time. This lets us be a lot leaner on-device, as well as
giving backends more direct control over more of the model execution. Currently the optimizer is not
captured, though this may change over time.

Loss functions must be embedded inside the model definition (and be the first output); this is used during
capture to generate the backward graph.

Gradients become explicit graph outputs rather than hidden tensor state.

Since the weights now need to be mutable during execution, they are memory-planned ahead of time and copied
from the .pte into the HierarchicalAllocator arenas during Method init.

Integration with backends/delegates is still a work in progress.

## End to End Example

To further understand the features of ExecuTorch training and how to leverage them,
consider the following end-to-end example of a neural network learning the XOR function.

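For reference, XOR outputs 1 exactly when its two binary inputs differ, so the entire "dataset" for this example is the four-row truth table. This small sketch (the tensor names are illustrative, not part of the export script) shows the data the model will learn:

```python
import torch

# The full XOR truth table: four input pairs and their class labels (0 or 1).
inputs = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
labels = torch.tensor([0, 1, 1, 0], dtype=torch.int64)

# XOR is 1 exactly when the two inputs differ.
for (a, b), y in zip(inputs.tolist(), labels.tolist()):
    assert (int(a) ^ int(b)) == int(y)
```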
### Lowering a joint-graph model to ExecuTorch

After following the [setting up ExecuTorch] guide, you can run

```bash
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar
```

to generate the model file. Below is a walkthrough of how that script works.

First, let's define our model.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.export import export
from torch.export.experimental import _export_forward_backward


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))
```

The first big difference from the normal ExecuTorch flow is that for training we must embed
the loss function into the model and return the loss as our first output.

We don't want to modify the original model definition, so we will just wrap it.

```python
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
```
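Before exporting, it can be useful to sanity-check the wrapper eagerly. This self-contained sketch (it repeats the definitions above so it runs on its own) just confirms that the wrapper returns a scalar loss first and the prediction second:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))


class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)


net = TrainingNet(Net())
loss, pred = net(torch.randn(1, 2), torch.ones(1, dtype=torch.int64))
assert loss.dim() == 0        # loss is a scalar tensor, and comes first
assert tuple(pred.shape) == (1,)  # one class index per example in the batch
```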

Now that we have our model, we can lower it to ExecuTorch. To do that we just have to follow
a few simple steps.

```python
net = TrainingNet(Net())

# Create our inputs; only the shapes of these matter.
input = torch.randn(1, 2)
label = torch.ones(1, dtype=torch.int64)

# Captures the forward graph. The graph will look similar to the model definition now.
# This will move to export_for_training soon, which is the API planned to be supported long term.
ep = export(net, (input, label))
```

This is what the graph looks like after export:
```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=1] = placeholder[target=input]
    %label : [num_users=1] = placeholder[target=label]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})
    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})
    return (cross_entropy_loss, argmax)
```

It should look pretty similar to our model's forward function. Now we need to capture the backward graph.

```python
ep = _export_forward_backward(ep)
```

and now the graph is:

```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=2] = placeholder[target=input]
    %label : [num_users=5] = placeholder[target=label]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})
    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})
    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})
    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})
    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})
    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})
    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})
    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})
    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})
    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})
    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})
    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})
    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})
    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})
    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})
    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})
    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})
    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})
    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})
    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})
    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})
    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})
    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})
    return (div, argmax, permute_8, view_1, permute_5, view)
```

It's a lot bigger! We call this the 'joint graph' or the 'forward-backward graph'. We have explicitly captured the backward graph
alongside the forward, and now our model returns [Loss, Any other user outputs, Gradients].

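Concretely, the flattened output list can be partitioned back into those three sections once you know how many user outputs and parameters the model has. This plain-Python sketch of that bookkeeping uses stand-in strings for tensors and counts that match the XOR example (one extra user output, four parameters); it is illustrative, not an ExecuTorch API:

```python
# Flattened joint-graph outputs: [loss, user outputs..., gradients...].
# Stand-in values; a real run would produce tensors in this order.
outputs = ["loss", "argmax", "grad_w2", "grad_b2", "grad_w1", "grad_b1"]

num_user_outputs = 1  # argmax (everything after the loss, before the grads)
num_params = 4        # two Linear layers -> two weights + two biases

loss = outputs[0]
user_outputs = outputs[1:1 + num_user_outputs]
gradients = outputs[1 + num_user_outputs:1 + num_user_outputs + num_params]

assert loss == "loss"
assert user_outputs == ["argmax"]
assert gradients == ["grad_w2", "grad_b2", "grad_w1", "grad_b1"]
```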
From here we can lower the rest of the way to ExecuTorch:
```python
from executorch.exir import to_edge

ep = to_edge(ep)

# After calling to_executorch, the weights themselves are also appended to the model outputs. This makes
# some downstream passes, like memory planning, a little easier. A couple of hidden utility functions are also
# embedded in the model: __et_training_gradients_index_<method_name>,
# __et_training_parameters_index_<method_name>, and __et_training_fqn_<method_name>.
#
# These help us partition the huge list of model outputs into meaningful sections, as well as assign names to each weight/gradient.
ep = ep.to_executorch()

with open("xor.pte", "wb") as file:
    ep.write_to_file(file)
```

### Run the model train script with CMake
After exporting the model for training, we can now try learning using CMake. We can build and use `train_xor`, which is a sample wrapper around the ExecuTorch runtime, `TrainingModule`, and the SGD optimizer. We first begin by configuring the CMake build as follows:
```bash
# cd to the root of the executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable built at `./cmake-out/extension/training/train_xor`. You can run it with the model you generated:
```bash
./cmake-out/extension/training/train_xor --model_path=./xor.pte
```
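For intuition, the loop that `train_xor` drives on-device corresponds roughly to this eager PyTorch sketch (the binary itself uses `TrainingModule` and the C++ SGD optimizer against the captured joint graph; this is only the conceptual equivalent, and the hyperparameters here are illustrative):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))


# The four XOR examples serve as the whole dataset.
inputs = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
labels = torch.tensor([0, 1, 1, 0])

net = Net()
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=0.5)

losses = []
for _ in range(500):
    opt.zero_grad()
    loss = loss_fn(net(inputs), labels)
    loss.backward()  # on-device, the captured backward graph produces these grads
    opt.step()       # on-device, the C++ SGD optimizer applies them
    losses.append(loss.item())

assert losses[-1] < losses[0]  # training reduces the loss
```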

## What is missing? / What is next?
A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.

The `_export_forward_backward` API is not very stable yet and may fail on more complicated model architectures, though we have verified that it works for LoRA with LLMs.

The ExecuTorch portable operator lib does not yet have full coverage of the ops that might show up in backward graphs.

We don't have a way yet to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using extension/aten_util and then serialize them using ATen APIs).

We plan to add a way to update models in place on-device (which will be needed for finetuning).

We are looking to integrate with many of the existing delegates/backends in ExecuTorch, enabling accelerated training.

And so much more!

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [GitHub](https://www.github.com/pytorch/executorch/issues).