# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torchaudio import models

from ..model_base import EagerModelBase


class Wav2LetterModel(EagerModelBase):
    """Wrap torchaudio's ``Wav2Letter`` model together with matching example inputs.

    The constructor defaults reproduce the previously hard-coded configuration
    (batch 10, 700 input frames, 4096 output classes). Callers may override them
    to build the model and its example inputs at other sizes.
    """

    def __init__(
        self,
        batch_size: int = 10,
        input_frames: int = 700,
        vocab_size: int = 4096,
    ) -> None:
        """Store the sizes used by ``get_eager_model`` and ``get_example_inputs``.

        Args:
            batch_size: Leading dimension of the example input tensor.
            input_frames: Length of the last dimension of the example input tensor.
            vocab_size: Passed to ``Wav2Letter`` as ``num_classes``.
        """
        # NOTE(review): EagerModelBase.__init__ is not invoked, matching the
        # original code; confirm the base class needs no initialization.
        self.batch_size = batch_size
        self.input_frames = input_frames
        self.vocab_size = vocab_size

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return a new ``Wav2Letter`` module.

        No pretrained weights are loaded here; the module is built only from
        ``num_classes=self.vocab_size``.
        """
        logging.info("Loading wav2letter model")
        wav2letter = models.Wav2Letter(num_classes=self.vocab_size)
        logging.info("Loaded wav2letter model")
        return wav2letter

    def get_example_inputs(self):
        """Return a one-element tuple holding a random example input tensor.

        The tensor has shape ``(batch_size, 1, input_frames)``. It is drawn from
        ``torch.randn``, so its values differ between calls.
        """
        # The size-1 middle dimension presumably matches Wav2Letter's default
        # num_features=1 (raw waveform input); confirm against torchaudio docs.
        input_shape = (self.batch_size, 1, self.input_frames)
        return (torch.randn(input_shape),)