# ExecuTorch On-device Training

This subtree contains infrastructure to facilitate on-device training using ExecuTorch.
This feature is experimental and under heavy active development; all of the APIs are
subject to change, and many things may not work out of the box, or at all, in the
current state.

## Layout
- `examples/` : Example end-to-end flows from model definition to `optimizer.step()`.
- `module/` : Utility class to provide an improved UX when using ExecuTorch for training.
- `optimizer/` : C++ implementations of various optimizers; currently only SGD, though Adam is planned.
- `test/` : Tests that cover multiple subdirs.

## Technical Bird's-Eye View

At a high level, ExecuTorch training follows a similar flow to inference, with a few extra steps.

Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,
we capture the backward graph ahead of time. This lets us be a lot leaner on-device, as well as
giving backends more direct control over more of the model execution. Currently the optimizer is not
captured, though this may change over time.
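Since the optimizer runs outside the captured graph, a training step on-device conceptually reduces to: execute the joint graph, read the gradient outputs, and update the weights. The following is a minimal pure-Python sketch of that last step, for illustration only; the real `SGD` optimizer in `optimizer/` is C++ and operates on ExecuTorch tensors.

```python
# Illustrative only: gradients arrive as explicit graph outputs, so an
# uncaptured SGD step is just an elementwise update over (weight, gradient) pairs.
def sgd_step(weights, grads, lr=0.1):
    """Update each flat weight list in place: w[i] -= lr * g[i]."""
    for w, g in zip(weights, grads):
        for i in range(len(w)):
            w[i] -= lr * g[i]

weights = [[1.0, 2.0], [0.5]]  # made-up parameter values
grads = [[0.1, -0.2], [0.4]]   # made-up gradient outputs from the joint graph
sgd_step(weights, grads)       # weights updated in place
```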

Loss functions must be embedded inside the model definition (and the loss must be the first output); this is used during
capture to generate the backward graph.

Gradients become explicit graph outputs rather than hidden tensor state.

Since the weights now need to be mutable during execution, they are memory-planned ahead of time and copied
from the `.pte` into the `HierarchicalAllocator` arenas during `Method` init.

Integration with backends/delegates is still a work in progress.


## End to End Example

To further understand the features of ExecuTorch training and how to leverage them,
consider the following end-to-end example of a neural network learning the XOR function.

### Lowering a joint-graph model to ExecuTorch

After following the [setting up ExecuTorch] guide, you can run

```bash
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar
```
to generate the model file.
Below is a walkthrough of how that script works.

First let's define our model.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.export import export
from torch.export.experimental import _export_forward_backward


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))
```

The first big difference from the normal ExecuTorch flow is that for training we must embed
the loss function into the model and return the loss as our first output.

We don't want to modify the original model definition, so we will just wrap it.

```python
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
```

Now that we have our model we can lower it to ExecuTorch. To do that we just have to follow
a few simple steps.

```python
net = TrainingNet(Net())

# Create our inputs; only the shapes of these matter.
input = torch.randn(1, 2)
label = torch.ones(1, dtype=torch.int64)

# Captures the forward graph. The graph will look similar to the model definition now.
# Will move to export_for_training soon, which is the API planned to be supported in the long term.
ep = export(net, (input, label))
```

This is what the graph looks like after export:
```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=1] = placeholder[target=input]
    %label : [num_users=1] = placeholder[target=label]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})
    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})
    return (cross_entropy_loss, argmax)
```

It should look pretty similar to our model's forward function. Now we need to capture the backward graph.

```python
ep = _export_forward_backward(ep)
```

and now the graph is

```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=2] = placeholder[target=input]
    %label : [num_users=5] = placeholder[target=label]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})
    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})
    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})
    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})
    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})
    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})
    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})
    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})
    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})
    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})
    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})
    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})
    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})
    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})
    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})
    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})
    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})
    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})
    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})
    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})
    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})
    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})
    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})
    return (div, argmax, permute_8, view_1, permute_5, view)
```

It's a lot bigger! We call this the 'joint graph' or the 'forward-backward graph'. We have explicitly captured the backward graph
alongside the forward, and now our model returns [Loss, any other user outputs, Gradients].

From here we can lower the rest of the way to ExecuTorch:
```python
from executorch.exir import to_edge

ep = to_edge(ep)

# After calling to_executorch the weights themselves are also appended to the model outputs.
# This is to make some downstream passes like memory planning a little easier. A couple of hidden
# utility functions are also embedded in the model: __et_training_gradients_index_<method_name>,
# __et_training_parameters_index_<method_name>, and __et_training_fqn_<method_name>.
#
# These help us partition the huge list of model outputs into meaningful sections, as well as
# assign names to each weight/gradient.
ep = ep.to_executorch()

with open("xor.pte", "wb") as file:
    ep.write_to_file(file)
```

### Run the model train script with CMAKE
After exporting the model for training, we can now try learning using CMake. We can build and use `train_xor`, a sample wrapper around the ExecuTorch runtime, `TrainingModule`, and the SGD optimizer.
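Conceptually, the loop a wrapper like `train_xor` drives looks like the following pure-Python sketch. The names here are stand-ins for illustration only; the real binary uses the C++ `TrainingModule` and `SGD` classes, and the real joint graph is the one exported above.

```python
# Stand-in for executing the joint graph's method on a toy one-parameter
# model pred = w * x; returns (loss, user output, gradient).
def execute_joint_graph(w, x, y):
    pred = w * x
    err = pred - y
    return err * err, pred, 2.0 * err * x  # loss, prediction, dloss/dw

w, lr = 0.0, 0.1
for _ in range(100):
    loss, pred, grad = execute_joint_graph(w, x=1.0, y=2.0)
    w -= lr * grad  # the optimizer.step() happens outside the captured graph

# w converges toward the target value 2.0
```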
We first configure the CMake build as follows:
```bash
# cd to the root of the executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable built at `./cmake-out/extension/training/train_xor`. You can run it with the model you generated:
```bash
./cmake-out/extension/training/train_xor --model_path=./xor.pte
```

## What is missing? / What is next?
A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.

`_export_forward_backward` is not very stable yet and may fail on more complicated model architectures, though we have verified that it works for LoRA with LLMs.

The ExecuTorch portable operator lib does not yet have full coverage of the ops that might show up in backward graphs.

We don't have a way yet to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using `extension/aten_util` and then serialize them using ATen APIs).
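Once the trained weights are back in ATen form, standard PyTorch serialization applies. A minimal sketch, assuming you have already converted the tensors via `extension/aten_util`; the fqn-to-tensor mapping below is fabricated for illustration:

```python
import torch

# Hypothetical trained weights keyed by fully-qualified name, after
# conversion back to ATen tensors (e.g. via extension/aten_util).
trained = {
    "net.linear.weight": torch.zeros(10, 2),
    "net.linear.bias": torch.zeros(10),
}

torch.save(trained, "/tmp/xor_trained.pt")  # serialize with ATen-side APIs
restored = torch.load("/tmp/xor_trained.pt")
```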

We plan to add a way to update models in place on-device (this will be needed for fine-tuning).

We are looking to integrate with many of the existing delegates/backends on ExecuTorch, enabling accelerated training.

And so much more!

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [github](https://www.github.com/pytorch/executorch/issues).