# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

import torch

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

from typing import List, Optional, Tuple

from executorch.backends.cadence.aot.export_example import export_model
from torchaudio.prototype.models import ConvEmformer


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":

    # Stacks `stride` consecutive frames along the feature dimension,
    # reducing the sequence length by a factor of `stride`.
    class _TimeReduction(torch.nn.Module):
        def __init__(self, stride: int) -> None:
            super().__init__()
            self.stride = stride

        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            B, T, D = input.shape
            num_frames = T - (T % self.stride)
            input = input[:, :num_frames, :]
            lengths = lengths.div(self.stride, rounding_mode="trunc")
            T_max = num_frames // self.stride

            output = input.reshape(B, T_max, D * self.stride)
            output = output.contiguous()
            return output, lengths

    # Streaming encoder: time reduction -> input projection -> ConvEmformer
    # -> output projection -> layer norm.
    class ConvEmformerEncoder(torch.nn.Module):
        def __init__(
            self,
            *,
            input_dim: int,
            output_dim: int,
            segment_length: int,
            kernel_size: int,
            right_context_length: int,
            time_reduction_stride: int,
            transformer_input_dim: int,
            transformer_num_heads: int,
            transformer_ffn_dim: int,
            transformer_num_layers: int,
            transformer_left_context_length: int,
            transformer_dropout: float = 0.0,
            transformer_activation: str = "relu",
            transformer_max_memory_size: int = 0,
            transformer_weight_init_scale_strategy: str = "depthwise",
            transformer_tanh_on_mem: bool = False,
        ) -> None:
            super().__init__()
            self.time_reduction = _TimeReduction(time_reduction_stride)
            self.input_linear = torch.nn.Linear(
                input_dim * time_reduction_stride,
                transformer_input_dim,
                bias=False,
            )
            self.transformer = ConvEmformer(
                transformer_input_dim,
                transformer_num_heads,
                transformer_ffn_dim,
                transformer_num_layers,
                segment_length // time_reduction_stride,
                kernel_size=kernel_size,
                dropout=transformer_dropout,
                ffn_activation=transformer_activation,
                left_context_length=transformer_left_context_length,
                right_context_length=right_context_length // time_reduction_stride,
                max_memory_size=transformer_max_memory_size,
                weight_init_scale_strategy=transformer_weight_init_scale_strategy,
                tanh_on_mem=transformer_tanh_on_mem,
                conv_activation="silu",
            )
            self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
            self.layer_norm = torch.nn.LayerNorm(output_dim)

        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            transformer_out, transformer_lengths = self.transformer(
                input_linear_out, time_reduction_lengths
            )
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths

        # Streaming inference entry point that threads the transformer's
        # cached states through successive calls.
        @torch.jit.export
        def infer(
            self,
            input: torch.Tensor,
            lengths: torch.Tensor,
            states: Optional[List[List[torch.Tensor]]],
        ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            (
                transformer_out,
                transformer_lengths,
                transformer_states,
            ) = self.transformer.infer(input_linear_out, time_reduction_lengths, states)
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths, transformer_states

    # Instantiate model
    time_reduction_stride = 4
    encoder = ConvEmformerEncoder(
        input_dim=80,
        output_dim=256,
        segment_length=4 * time_reduction_stride,
        kernel_size=7,
        right_context_length=1 * time_reduction_stride,
        time_reduction_stride=time_reduction_stride,
        transformer_input_dim=128,
        transformer_num_heads=4,
        transformer_ffn_dim=512,
        transformer_num_layers=1,
        transformer_left_context_length=10,
        transformer_tanh_on_mem=True,
    )

    # Batch size
    batch_size = 1

    max_input_length = 100
    input_dim = 80
    right_context_length = 4

    # Dummy inputs
    transcriber_input = torch.rand(
        batch_size, max_input_length + right_context_length, input_dim
    )
    transcriber_lengths = torch.randint(1, max_input_length + 1, (batch_size,))

    example_inputs = (
        transcriber_input,
        transcriber_lengths,
    )

    export_model(encoder, example_inputs)