xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/lstm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
11
12from executorch.backends.xnnpack.test.tester import Tester
13from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
14
15
class TestLSTM(unittest.TestCase):
    """Lowering tests for a small LSTM + Linear model on the XNNPACK backend.

    Each test exports the model, lowers it through the XNNPACK partitioner,
    inspects the lowered graph, then serializes the program and compares the
    runtime outputs against the eager model.
    """

    class LSTMLinear(torch.nn.Module):
        """Single-layer LSTM followed by two linear layers and a log-softmax.

        The LSTM is built with ``batch_first=True``, so inputs are expected as
        ``(batch, seq_len, input_size)``.
        """

        def __init__(self, input_size, hidden_size, out_size):
            super().__init__()
            self.lstm = torch.nn.LSTM(
                input_size=input_size, hidden_size=hidden_size, batch_first=True
            )
            self.linear = torch.nn.Linear(hidden_size, hidden_size)
            self.linear2 = torch.nn.Linear(hidden_size, out_size)

        def forward(self, x):
            """Return log-probabilities of shape ``(batch, out_size)``.

            Only the LSTM output at the last time step feeds the linear head.
            The final hidden/cell states returned by the LSTM are not used.
            """
            x, _ = self.lstm(x)
            x = self.linear(x[:, -1, :])
            x = self.linear2(x)
            return torch.nn.functional.log_softmax(x, dim=1)

    def _make_tester(self):
        """Build a Tester for a 32 -> 32 -> 10 LSTMLinear and one random input.

        The example input has shape (1, 32, 32): batch 1, 32 time steps and
        32 features, matching ``input_size=32``. Shared by every test so the
        model and input stay identical across them.
        """
        return Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))

    def test_fp32_lstm(self):
        """Lower with the stage's default partitioner; LSTM params are delegated."""
        (
            self._make_tester()
            .export()
            .to_edge_transform_and_lower()
            # No un-delegated addmm (linear) op should remain in the graph.
            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
            # LSTM weights and biases should be consumed by the delegate, so
            # they must not appear in the lowered graph.
            .check_not(["p_lstm_weight", "p_lstm_bias"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_lstm_force_dynamic_linear(self):
        """Lower with ``force_fp32_dynamic_linear=True`` on the partitioner."""
        (
            self._make_tester()
            .export()
            .to_edge_transform_and_lower(
                ToEdgeTransformAndLower(
                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
                )
            )
            # No un-delegated addmm (linear) op should remain in the graph.
            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
            # Weights are supplied as inputs to the linears, so they remain
            # visible in the lowered graph.
            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
            # Biases are still owned by the delegates.
            .check_not(["p_lstm_bias"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
64