# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Tuple

import torch

from torch.nn.quantizable.modules import rnn

from ..model_base import EagerModelBase


class LSTMModel(EagerModelBase):
    """Example model wrapping a multi-layer quantizable LSTM.

    The eager model is ``rnn.LSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS)``.
    ``get_example_inputs`` builds random tensors whose shapes are derived from
    the same class constants, so the model and its example inputs stay
    consistent if any dimension is changed.
    """

    # LSTM constructor arguments: (input_size, hidden_size, num_layers).
    INPUT_SIZE = 10
    HIDDEN_SIZE = 20
    NUM_LAYERS = 2
    # Dimensions of the generated example input sequence.
    SEQ_LEN = 5
    BATCH_SIZE = 3

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return the eager quantizable LSTM module."""
        logging.info("Loading LSTM model")
        lstm = rnn.LSTM(self.INPUT_SIZE, self.HIDDEN_SIZE, self.NUM_LAYERS)
        logging.info("Loaded LSTM model")
        return lstm

    def get_example_inputs(
        self,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return random example inputs ``(input, (h0, c0))`` for the LSTM.

        Shapes follow the PyTorch LSTM convention with ``batch_first=False``
        (the constructor default, which ``get_eager_model`` does not override):

        * ``input``: ``(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)``
        * ``h0`` / ``c0``: ``(NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)``

        Values are drawn with ``torch.randn`` and are therefore not
        deterministic unless the caller seeds the RNG.
        """
        input_tensor = torch.randn(self.SEQ_LEN, self.BATCH_SIZE, self.INPUT_SIZE)
        # Initial hidden and cell states: one slice per stacked layer
        # (the LSTM is unidirectional, so no extra direction factor).
        h0 = torch.randn(self.NUM_LAYERS, self.BATCH_SIZE, self.HIDDEN_SIZE)
        c0 = torch.randn(self.NUM_LAYERS, self.BATCH_SIZE, self.HIDDEN_SIZE)
        return (input_tensor, (h0, c0))