# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.exir import to_edge

from executorch.extension.training import (
    _load_for_executorch_for_training_from_buffer,
    get_sgd_optimizer,
)
from torch.export.experimental import _export_forward_backward


class TestTraining(unittest.TestCase):
    """Export, lower, load and take one SGD step on a tiny training model."""

    class ModuleSimpleTrain(torch.nn.Module):
        """Single linear layer whose forward pass returns a scalar loss."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            # Softmax over dim 0 of the linear output, then cross-entropy
            # against the target ``y``.
            probs = self.linear(x).softmax(dim=0)
            return self.loss(probs, y)

        def get_random_inputs(self):
            # Random 3-element input paired with a fixed target tensor.
            features = torch.randn(3)
            target = torch.tensor([1.0, 0.0, 0.0])
            return (features, target)

    def test(self):
        # Build the model, export its joint forward/backward graph, lower it
        # to an ExecuTorch program, and load the serialized buffer back for
        # training.
        model = self.ModuleSimpleTrain()
        exported = torch.export.export(model, model.get_random_inputs())
        joint = _export_forward_backward(exported)
        edge = to_edge(joint)
        program = edge.to_executorch()
        training_module = _load_for_executorch_for_training_from_buffer(
            program.buffer
        )

        def first_parameter():
            # First entry of the loaded module's named parameters.
            return list(training_module.named_parameters().values())[0]

        # Run forward+backward, then snapshot the first parameter before the
        # optimizer mutates it.
        training_module.forward_backward("forward", model.get_random_inputs())
        param_before = first_parameter().clone()

        # Positional SGD hyperparameters; presumably
        # (lr, momentum, dampening, weight_decay, nesterov) — confirm against
        # the get_sgd_optimizer signature.
        optimizer = get_sgd_optimizer(
            training_module.named_parameters(),
            0.1,
            0,
            0,
            0,
            False,
        )
        optimizer.step(training_module.named_gradients())

        # After the step, the parameter must have moved away from its
        # snapshot.
        param_after = first_parameter()
        self.assertFalse(torch.allclose(param_before, param_after))