xref: /aosp_15_r20/external/executorch/examples/cadence/models/rnnt_predictor.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
# Example script for exporting a simple RNN-T predictor model to flatbuffer
8
9import logging
10
11import torch
12
13from executorch.backends.cadence.aot.ops_registrations import *  # noqa
14
15from typing import Tuple
16
17from executorch.backends.cadence.aot.export_example import export_model
18
19
20FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
21logging.basicConfig(level=logging.INFO, format=FORMAT)
22
23
24if __name__ == "__main__":
25
26    class Predictor(torch.nn.Module):
27        def __init__(
28            self,
29            num_symbols: int,
30            symbol_embedding_dim: int,
31        ) -> None:
32            super().__init__()
33            self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
34            self.relu = torch.nn.ReLU()
35            self.linear = torch.nn.Linear(symbol_embedding_dim, symbol_embedding_dim)
36            self.layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
37
38        def forward(
39            self,
40            input: torch.Tensor,
41            lengths: torch.Tensor,
42        ) -> Tuple[torch.Tensor, torch.Tensor]:
43            input_tb = input.permute(1, 0)
44            embedding_out = self.embedding(input_tb)
45            relu_out = self.relu(embedding_out)
46            linear_out = self.linear(relu_out)
47            layer_norm_out = self.layer_norm(linear_out)
48            return layer_norm_out.permute(1, 0, 2), lengths
49
50    # Predictor
51    model = Predictor(128, 256)
52    model.eval()
53
54    # Batch size
55    batch_size = 1
56
57    num_symbols = 128
58    max_target_length = 10
59
60    # Dummy inputs
61    predictor_input = torch.randint(0, num_symbols, (batch_size, max_target_length))
62    predictor_lengths = torch.randint(1, max_target_length + 1, (batch_size,))
63
64    example_inputs = (
65        predictor_input,
66        predictor_lengths,
67    )
68
69    export_model(model, example_inputs)
70