# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch
import torchaudio

from ..model_base import EagerModelBase


FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)


__all__ = [
    "EmformerRnntTranscriberModel",
    "EmformerRnntPredictorModel",
    "EmformerRnntJoinerModel",
]


class EmformerRnntTranscriberExample(torch.nn.Module):
    """
    This is a wrapper for validating the transcriber of the Emformer RNN-T architecture.
    It does not reflect actual usage such as beam search; rather, it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, transcribe_inputs):
        return self.rnnt.transcribe(*transcribe_inputs)


class EmformerRnntTranscriberModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt transcriber")
        m = EmformerRnntTranscriberExample()
        logging.info("Loaded emformer rnnt transcriber")
        return m

    def get_example_inputs(self):
        transcribe_inputs = (
            # Mel-filterbank features: (batch, num_frames, feature_dim).
            torch.randn(1, 128, 80),
            # Number of valid frames per batch element.
            torch.tensor([128]),
        )
        return (transcribe_inputs,)

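# Hedged sketch, not part of the original module: a minimal eager smoke test for the
# transcriber wrapper above; `_smoke_test_transcriber` is a hypothetical helper added
# for illustration only.
def _smoke_test_transcriber():
    wrapper = EmformerRnntTranscriberModel()
    model = wrapper.get_eager_model()
    (transcribe_inputs,) = wrapper.get_example_inputs()
    # RNNT.transcribe returns the source encodings and their lengths.
    output, output_lengths = model(transcribe_inputs)
    return output, output_lengths
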

class EmformerRnntPredictorExample(torch.nn.Module):
    """
    This is a wrapper for validating the predictor of the Emformer RNN-T architecture.
    It does not reflect actual usage such as beam search; rather, it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, predict_inputs):
        return self.rnnt.predict(*predict_inputs)


class EmformerRnntPredictorModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt predictor")
        m = EmformerRnntPredictorExample()
        logging.info("Loaded emformer rnnt predictor")
        return m

    def get_example_inputs(self):
        predict_inputs = (
            # Token ids: (batch, num_tokens).
            torch.zeros([1, 128], dtype=int),
            # Number of valid tokens per batch element.
            torch.tensor([128], dtype=int),
            # Predictor state; None starts from a fresh state.
            None,
        )
        return (predict_inputs,)

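# Hedged sketch, not part of the original module: the predictor takes token ids, token
# lengths, and an optional state (None starts fresh) and returns an updated state that
# can be fed back in, as in incremental decoding; `_smoke_test_predictor` is a
# hypothetical helper added for illustration only.
def _smoke_test_predictor():
    wrapper = EmformerRnntPredictorModel()
    model = wrapper.get_eager_model()
    (predict_inputs,) = wrapper.get_example_inputs()
    tokens, lengths, _ = predict_inputs
    # First call with state=None initializes the predictor state.
    output, output_lengths, state = model((tokens, lengths, None))
    # A second call can reuse the returned state.
    output, output_lengths, state = model((tokens, lengths, state))
    return output, output_lengths, state
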

class EmformerRnntJoinerExample(torch.nn.Module):
    """
    This is a wrapper for validating the joiner of the Emformer RNN-T architecture.
    It does not reflect actual usage such as beam search; rather, it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, join_inputs):
        return self.rnnt.join(*join_inputs)


class EmformerRnntJoinerModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt joiner")
        m = EmformerRnntJoinerExample()
        logging.info("Loaded emformer rnnt joiner")
        return m

    def get_example_inputs(self):
        join_inputs = (
            # Source (transcriber) encodings: (batch, num_frames, encoding_dim).
            torch.rand([1, 128, 1024]),
            # Number of valid source frames per batch element.
            torch.tensor([128]),
            # Target (predictor) encodings: (batch, num_tokens, encoding_dim).
            torch.rand([1, 128, 1024]),
            # Number of valid target tokens per batch element.
            torch.tensor([128]),
        )
        return (join_inputs,)

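# Hedged sketch, not part of the original module: the docstrings above present these
# wrappers as examples for the export workflow, so this block shows one way they might
# be captured with torch.export. Lowering to ExecuTorch (e.g. via executorch.exir.to_edge)
# would follow from the captured programs and is omitted here.
if __name__ == "__main__":
    for model_cls in (
        EmformerRnntTranscriberModel,
        EmformerRnntPredictorModel,
        EmformerRnntJoinerModel,
    ):
        wrapper = model_cls()
        eager_model = wrapper.get_eager_model()
        example_inputs = wrapper.get_example_inputs()
        # Capture the wrapper's forward() with its example inputs.
        exported_program = torch.export.export(eager_model, example_inputs)
        logging.info("Exported %s", model_cls.__name__)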