# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch
import torchaudio

from ..model_base import EagerModelBase


FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT)


__all__ = [
    "EmformerRnntTranscriberModel",
    "EmformerRnntPredictorModel",
    "EmformerRnntJoinerModel",
]


class EmformerRnntTranscriberExample(torch.nn.Module):
    """
    This is a wrapper for validating the transcriber of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, transcribe_inputs):
        return self.rnnt.transcribe(*transcribe_inputs)


class EmformerRnntTranscriberModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt transcriber")
        m = EmformerRnntTranscriberExample()
        logging.info("Loaded emformer rnnt transcriber")
        return m

    def get_example_inputs(self):
        # Acoustic features of shape (batch, num_frames, feature_dim) plus frame lengths.
        transcribe_inputs = (
            torch.randn(1, 128, 80),
            torch.tensor([128]),
        )
        return (transcribe_inputs,)


class EmformerRnntPredictorExample(torch.nn.Module):
    """
    This is a wrapper for validating the predictor of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, predict_inputs):
        return self.rnnt.predict(*predict_inputs)


class EmformerRnntPredictorModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt predictor")
        m = EmformerRnntPredictorExample()
        logging.info("Loaded emformer rnnt predictor")
        return m

    def get_example_inputs(self):
        # Target token IDs, token lengths, and no previous predictor state.
        predict_inputs = (
            torch.zeros([1, 128], dtype=int),
            torch.tensor([128], dtype=int),
            None,
        )
        return (predict_inputs,)


class EmformerRnntJoinerExample(torch.nn.Module):
    """
    This is a wrapper for validating the joiner of the Emformer RNN-T architecture.
    It does not reflect actual usage (such as beam search); it is an example for the export workflow.
    """

    def __init__(self) -> None:
        super().__init__()
        bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        decoder = bundle.get_decoder()
        m = decoder.model
        self.rnnt = m

    def forward(self, join_inputs):
        return self.rnnt.join(*join_inputs)


class EmformerRnntJoinerModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading emformer rnnt joiner")
        m = EmformerRnntJoinerExample()
        logging.info("Loaded emformer rnnt joiner")
        return m

    def get_example_inputs(self):
        # Source (transcriber) encodings and lengths, target (predictor) encodings and lengths.
        join_inputs = (
            torch.rand([1, 128, 1024]),
            torch.tensor([128]),
            torch.rand([1, 128, 1024]),
            torch.tensor([128]),
        )
        return (join_inputs,)