# ExecuTorch On-device Training

This subtree contains infrastructure to facilitate on-device training using ExecuTorch.
This feature is experimental and under heavy active development. All the APIs are
subject to change, and many things may not work out of the box or at all in the
current state.

## Layout
- `examples/`: Example end-to-end flows from model definition to `optimizer.step()`.
- `module/`: Utility class to provide an improved UX when using ExecuTorch for training.
- `optimizer/`: C++ implementations of various optimizers. Currently only SGD, though Adam is planned.
- `test/`: Tests that cover multiple subdirs.

## Technical Bird's-Eye View

At a high level, ExecuTorch training follows a similar flow to inference, with a few extra steps.

Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,
we capture the backward graph ahead of time. This lets us be a lot leaner on-device and
gives backends more direct control over more of the model execution. Currently the optimizer is not
captured, though this may change over time.

Loss functions must be embedded inside the model definition (and be the first output). The loss is used during
capture to generate the backward graph.

Gradients become explicit graph outputs rather than hidden tensor state.

Since the weights now need to be mutable during execution, they are memory planned ahead of time and copied
from the .pte into the HierarchicalAllocator arenas during Method init.

Integration with backends/delegates is still a work in progress.


## End to End Example

To further understand the features of ExecuTorch Training and how to leverage them,
consider the following end-to-end example with a neural network learning the XOR function.

### Lowering a joint-graph model to ExecuTorch

After following the [setting up ExecuTorch] guide, you can run

```bash
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar
```
to generate the model file. Below is a walkthrough of how that script works.

First, let's define our model.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.export import export
from torch.export.experimental import _export_forward_backward


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))
```

The first big difference from the normal ExecuTorch flow is that for training we must embed
the loss function into the model and return the loss as our first output.

We don't want to modify the original model definition, so we will just wrap it.

```python
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
```

Now that we have our model, we can lower it to ExecuTorch. To do that we just have to follow
a few simple steps.

```python
net = TrainingNet(Net())

# Create our inputs, only the shapes of these matter.
input = torch.randn(1, 2)
label = torch.ones(1, dtype=torch.int64)

# Captures the forward graph. The graph will look similar to the model definition now.
# Will move to export_for_training soon, which is the API planned to be supported in the long term.
ep = export(net, (input, label))
```

This is what the graph looks like after export:
```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=1] = placeholder[target=input]
    %label : [num_users=1] = placeholder[target=label]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})
    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})
    return (cross_entropy_loss, argmax)
```

It should look pretty similar to our model's forward function. Now we need to capture the backwards graph.

```python
ep = _export_forward_backward(ep)
```

and now the graph is

```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=2] = placeholder[target=input]
    %label : [num_users=5] = placeholder[target=label]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})
    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})
    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})
    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})
    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})
    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})
    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})
    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})
    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})
    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})
    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})
    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})
    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})
    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})
    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})
    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})
    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})
    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})
    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})
    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})
    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})
    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})
    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})
    return (div, argmax, permute_8, view_1, permute_5, view)
```

It's a lot bigger! We call this the 'joint graph' or the 'forwards backwards graph'. We have explicitly captured the backwards graph
alongside the forward, and now our model returns [Loss, Any other user outputs, Gradients].

From here we can lower the rest of the way to ExecuTorch.
```python
from executorch.exir import to_edge

ep = to_edge(ep)

# After calling to_executorch the weights themselves are also appended to the model outputs. This is to make
# some downstream passes like memory planning a little easier. A couple of hidden utility functions are also
# embedded in the model: __et_training_gradients_index_<method_name>,
# __et_training_parameters_index_<method_name>, __et_training_fqn_<method_name>.
#
# These help us partition the huge list of model outputs into meaningful sections as well as assign names to each weight/gradient.
ep = ep.to_executorch()

with open("xor.pte", "wb") as file:
    ep.write_to_file(file)
```

### Run the model train script with CMake
After exporting the model for training, we can now try training it with a binary built using CMake. We can build and use `train_xor`, which is a sample wrapper for the ExecuTorch Runtime, TrainingModule, and SGD optimizer. First, configure the CMake build:
```bash
# cd to the root of executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable at `./cmake-out/extension/training/train_xor`. You can run it with the model you generated:
```bash
./cmake-out/extension/training/train_xor --model_path=./xor.pte
```

## What is missing? / What is next?
A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.

`_export_forward_backward` is not very stable yet and may fail on more complicated model architectures, though we have verified it works for LoRA with LLMs.

The ExecuTorch portable operator lib does not yet have full coverage of ops that might show up in the backwards graphs.
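
One way to see which operators a joint graph actually needs is to tally the `call_function` targets in its printed form. The following is only a minimal sketch under stated assumptions: `count_ops` and the regular expression are hypothetical helpers written for this illustration, not part of ExecuTorch or PyTorch, and the sketch parses the printed text rather than calling any library API. The sample lines are copied from the joint graph printed earlier in this README.

```python
import re
from collections import Counter

# Hypothetical helper pattern: matches the operator name inside
# `call_function[target=torch.ops....]` in a printed graph.
_TARGET_RE = re.compile(r"call_function\[target=(torch\.ops\.[\w.]+)\]")


def count_ops(graph_text: str) -> Counter:
    """Count how often each operator target appears in printed graph text."""
    return Counter(_TARGET_RE.findall(graph_text))


# Three node lines copied from the joint graph shown in this README.
sample = """
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
"""

ops = count_ops(sample)
for name, n in sorted(ops.items()):
    print(f"{name}: {n}")
```

To inspect a whole model, the same helper could be fed `str(ep.graph_module.graph)`. How to compare the resulting names against the operators the portable lib supports is not covered here, since this README does not describe a way to query that list.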

We don't yet have a way to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using `extension/aten_util` and then serialize them using ATen APIs).

We plan to add a way to update models in place on-device (this will be needed for finetuning).

We are looking to integrate with many of the existing delegates/backends on ET, enabling accelerated training.

And so much more!

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [GitHub](https://www.github.com/pytorch/executorch/issues).