# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

import torch

from executorch.backends.cadence.aot.export_example import export_model
from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def main() -> None:
    """Build a wav2vec2 base model and export it via the Cadence flow."""

    class Wav2Vec2ModelWrapper(torch.nn.Module):
        # Wrapping avoids issues with the optional second argument of
        # Wav2Vec2Model's forward: only the feature tensor is returned.
        def __init__(self, model: Wav2Vec2Model):
            super().__init__()
            self.model = model

        def forward(self, x):
            features, _lengths = self.model(x)
            return features

    # Base configuration (~"wav2vec2 base" sized encoder, layer-norm extractor).
    base_model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        aux_num_out=None,
    )
    base_model.eval()

    wrapped = Wav2Vec2ModelWrapper(base_model)
    wrapped.eval()

    # A short random waveform serves as the tracing input for export.
    num_samples = 1680
    example_inputs = (torch.rand(1, num_samples),)

    export_model(wrapped, example_inputs)


if __name__ == "__main__":
    main()