xref: /aosp_15_r20/external/executorch/extension/training/pybindings/test/test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.exir import to_edge

from executorch.extension.training import (
    _load_for_executorch_for_training_from_buffer,
    get_sgd_optimizer,
)
from torch.export.experimental import _export_forward_backward

class TestTraining(unittest.TestCase):
    """End-to-end check of the training pybindings.

    Exports a tiny model with its backward graph, lowers it to an
    ExecuTorch buffer, runs one forward/backward pass, applies one SGD
    step, and verifies that a parameter actually changed.
    """

    class ModuleSimpleTrain(torch.nn.Module):
        """Minimal trainable module: one 3x3 linear layer whose forward
        returns a cross-entropy loss (so the exported graph has a scalar
        loss output to differentiate)."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            logits = self.linear(x)
            return self.loss(logits.softmax(dim=0), y)

        def get_random_inputs(self):
            # One random feature vector and a fixed one-hot target.
            return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))

    def test(self):
        model = self.ModuleSimpleTrain()

        # Export the eager module, then augment the graph with its backward pass.
        exported = torch.export.export(model, model.get_random_inputs())
        exported = _export_forward_backward(exported)

        # Lower to an ExecuTorch program and load it for training.
        program = to_edge(exported).to_executorch()
        training_module = _load_for_executorch_for_training_from_buffer(
            program.buffer
        )

        # One forward/backward pass populates the gradients.
        training_module.forward_backward("forward", model.get_random_inputs())

        # Snapshot the first parameter before stepping the optimizer.
        before = next(iter(training_module.named_parameters().values())).clone()

        # Plain SGD: lr=0.1, no momentum/dampening/weight-decay, no Nesterov.
        optimizer = get_sgd_optimizer(
            training_module.named_parameters(),
            0.1,
            0,
            0,
            0,
            False,
        )
        optimizer.step(training_module.named_gradients())

        # The SGD step must have moved the parameter.
        after = next(iter(training_module.named_parameters().values()))
        self.assertFalse(torch.allclose(before, after))