xref: /aosp_15_r20/external/executorch/examples/models/lstm/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3# Copyright 2024 Arm Limited and/or its affiliates.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import logging
9
10import torch
11
12from torch.nn.quantizable.modules import rnn
13
14from ..model_base import EagerModelBase
15
16
class LSTMModel(EagerModelBase):
    """Example wrapper exposing a small quantizable LSTM plus matching inputs.

    The network sizes and the example-input shapes are derived from the same
    instance attributes, so they cannot drift apart. The defaults reproduce
    the original fixed configuration:
    - a 2-layer LSTM mapping 10 input features to 20 hidden features;
    - example inputs with sequence length 5 and batch size 3.
    """

    def __init__(
        self,
        *,
        input_size: int = 10,
        hidden_size: int = 20,
        num_layers: int = 2,
        seq_len: int = 5,
        batch_size: int = 3,
    ) -> None:
        """Store the sizes used to build the model and its example inputs.

        Args:
            input_size: Number of features per time step fed to the LSTM.
            hidden_size: Number of features in the hidden/cell state.
            num_layers: Number of stacked LSTM layers.
            seq_len: Sequence length of the example input.
            batch_size: Batch size of the example input and initial states.
        """
        # The parameters are keyword-only so the no-argument call
        # `LSTMModel()` keeps working.
        # NOTE(review): the base-class __init__ is intentionally not called,
        # matching the original; confirm against EagerModelBase.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.batch_size = batch_size

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return a new ``rnn.LSTM`` using the configured sizes."""
        logging.info("Loading LSTM model")
        lstm = rnn.LSTM(self.input_size, self.hidden_size, self.num_layers)
        logging.info("Loaded LSTM model")
        return lstm

    def get_example_inputs(self):
        """Return randomly-filled example inputs for ``get_eager_model()``.

        Returns:
            ``(input_tensor, (h0, c0))`` where:
            - ``input_tensor`` has shape ``(seq_len, batch_size, input_size)``;
            - ``h0`` and ``c0`` each have shape
              ``(num_layers, batch_size, hidden_size)``.

            The input is sequence-first because the LSTM is constructed
            without overriding ``batch_first``.
        """
        input_tensor = torch.randn(self.seq_len, self.batch_size, self.input_size)
        # Initial hidden and cell states: one slice per stacked layer.
        state_shape = (self.num_layers, self.batch_size, self.hidden_size)
        h0 = torch.randn(*state_shape)
        c0 = torch.randn(*state_shape)
        return (input_tensor, (h0, c0))
32