# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower


class TestLSTM(unittest.TestCase):
    """Export-and-lower tests for an LSTM followed by linear layers on XNNPACK."""

    # Model and input dimensions shared by every test in this class.
    _INPUT_SIZE = 32
    _HIDDEN_SIZE = 32
    _OUT_SIZE = 10
    _BATCH = 1
    _SEQ_LEN = 32

    class LSTMLinear(torch.nn.Module):
        """An LSTM whose last time step feeds two linears and a log-softmax."""

        def __init__(self, input_size, hidden_size, out_size):
            super().__init__()
            self.lstm = torch.nn.LSTM(
                input_size=input_size, hidden_size=hidden_size, batch_first=True
            )
            self.linear = torch.nn.Linear(hidden_size, hidden_size)
            self.linear2 = torch.nn.Linear(hidden_size, out_size)

        def forward(self, x):
            # Only the per-step outputs are used; the final state tuple is
            # discarded.
            x, _ = self.lstm(x)
            # The LSTM is built with batch_first=True, so dim 1 indexes the
            # sequence; keep the last time step only.
            x = self.linear(x[:, -1, :])
            x = self.linear2(x)
            return torch.nn.functional.log_softmax(x, dim=1)

    def setUp(self):
        # Seed the RNG so the randomly initialised weights and the torch.rand
        # example input are identical on every run, keeping any failure
        # reproducible.
        torch.manual_seed(0)

    def _make_tester(self):
        """Return a ``Tester`` wrapping a fresh ``LSTMLinear`` and a random input."""
        # Build the model before the input, matching the original order in
        # which the two consume random numbers.
        model = self.LSTMLinear(self._INPUT_SIZE, self._HIDDEN_SIZE, self._OUT_SIZE)
        example_inputs = (torch.rand(self._BATCH, self._SEQ_LEN, self._INPUT_SIZE),)
        return Tester(model, example_inputs)

    def test_fp32_lstm(self):
        (
            self._make_tester()
            .export()
            .to_edge_transform_and_lower()
            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
            .check_not(
                ["p_lstm_weight", "p_lstm_bias"]
            )  # These should be consumed by the delegate
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_lstm_force_dynamic_linear(self):
        (
            self._make_tester()
            .export()
            .to_edge_transform_and_lower(
                ToEdgeTransformAndLower(
                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
                )
            )
            .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
            # Weights are supplied as input to linears
            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
            # Biases are owned by delegates
            .check_not(["p_lstm_bias"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )