# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
from typing import Any, List, Optional, Tuple, Union

import executorch.exir as exir

import torch  # noqa: F401
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.lowered_backend_module import LoweredBackendModule
from torch import Tensor
from torch.export import export

# TODO: add one more test for data dependent op plus repeat


class TensorItem(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        h = arg1.item()
        w = arg2.item()
        torch._check(h >= 2)
        torch._check(h <= 100)
        torch._check(w >= 2)
        torch._check(w <= 100)
        return torch.ones(int(h), int(w))

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.tensor(10), torch.tensor(20))
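

# Illustrative sketch (not exercised by the tests): `TensorItem` has a
# data-dependent output shape, so the `torch._check` bounds above are what let
# `torch.export` bound the symbolic sizes `h` and `w`.
def _example_export_tensor_item() -> None:
    m = TensorItem()
    exported = export(m, m.get_random_inputs())
    to_edge(exported)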


class Repeat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, arg1: torch.Tensor, arg2: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = arg2.repeat(arg1.size(0), 1)
        return x * x, arg2 + arg2

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))

    def get_dynamic_shape(self) -> Any:  # pyre-ignore[3]
        dim = torch.export.Dim("dim", max=10)
        dim2 = torch.export.Dim("dim2", max=10)
        return ({0: dim}, {0: dim2})


class ModelWithUnusedArg(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        return torch.sin(arg1)

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))
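

# Illustrative sketch (not exercised by the tests): exporting `Repeat` with
# the dynamic shapes declared above keeps dim 0 of both inputs symbolic in
# the exported program.
def _example_export_repeat() -> None:
    m = Repeat()
    exported = export(
        m, m.get_random_inputs(), dynamic_shapes=m.get_dynamic_shape()
    )
    to_edge(exported)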


class MLP(nn.Module):
    def __init__(self, n_layer: int = 1, output_size: int = 1) -> None:
        super().__init__()
        self.n_layer = n_layer
        self.output_size = output_size
        # Input shape: [batch_size, n_layer + output_size]; each linear layer
        # reduces the size of activation dim 1 by 1.
        self.mlp = torch.nn.Sequential(
            *itertools.chain(
                *(
                    [nn.Linear(i + output_size, i - 1 + output_size)]
                    + ([nn.ReLU()] if i != 1 else [])
                    for i in range(n_layer, 0, -1)
                )
            )
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.mlp(inputs)

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.rand(2, self.n_layer + self.output_size),)


class Identity(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return torch.clone(input)


class Reshape(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, x: Tensor, *new_shape: Union[torch.Size, Tuple[int, ...], List[int]]
    ) -> Tensor:
        if len(new_shape) == 1 and (
            isinstance(new_shape[0], tuple) or isinstance(new_shape[0], list)
        ):
            return x.reshape(new_shape[0])
        # Otherwise `new_shape` is the varargs tuple of ints itself. Note that
        # `isinstance` cannot be used with subscripted generics like
        # `Union[torch.Size, Tuple[int, ...], List[int]]`, so check the
        # runtime types directly.
        assert isinstance(new_shape, (torch.Size, tuple, list))
        return x.reshape(new_shape)


class Transpose(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        return x.transpose(dim0, dim1)


class Mul(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        # Equivalent to torch.mul(input, other).
        return input * other

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(3, 2), torch.randn(3, 2))
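

# Illustrative sketch: `Reshape` accepts either a single tuple/list or the
# unpacked integers, mirroring `Tensor.reshape`.
def _example_reshape_conventions() -> None:
    r = Reshape()
    x = torch.randn(2, 6)
    assert r(x, (3, 4)).shape == r(x, 3, 4).shape == (3, 4)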


class ElementwiseAdd(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3))


class BasicSinMax(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sin(x)

    def get_random_inputs(self) -> Tuple[Tensor]:
        return (torch.randn(100),)


class CompositeDelegateModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        class DelegateAdd(nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: Tensor, y: Tensor) -> List[Tensor]:
                return [x + y]

            def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
                return (torch.randn(1, 3), torch.randn(1, 3))

        delegated_m = DelegateAdd()
        edge_ir_m = to_edge(
            export(
                delegated_m,
                delegated_m.get_random_inputs(),
            )
        )
        lowered_module = LoweredBackendModule(
            edge_program=edge_ir_m.exported_program(),
            backend_id="backend_demo",
            processed_bytes=bytes("basic_module_add", encoding="utf8"),
            compile_specs=[],
        )
        self.lowered_module: LoweredBackendModule = lowered_module

    def forward(self, a: exir.Value, b: exir.Value, s: Tensor) -> Tensor:
        res = self.lowered_module(a, b)
        res = res[0] * s
        return res

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3))


class BatchMatrixMultiplication(nn.Module):
    def __init__(self, transposed: bool = False) -> None:
        super().__init__()

        # Whether the last two dims (-1, -2) of the input have already been
        # transposed. If so, transpose them back before feeding to torch.bmm.
        self.transposed: bool = transposed

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if self.transposed:
            return torch.bmm(x, y.transpose(-1, -2))
        else:
            return torch.bmm(x, y)

    def extra_repr(self) -> str:
        return f"transposed={self.transposed}"


class TensorSplit(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.tensor_split(input, sections, dim)
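

# Illustrative sketch: `TensorSplit` cuts a dim into `sections` near-equal
# chunks, e.g. dim 0 of a (6, 2) tensor into three (2, 2) tensors.
def _example_tensor_split() -> None:
    outs = TensorSplit()(torch.randn(6, 2), 3)
    assert len(outs) == 3 and all(o.shape == (2, 2) for o in outs)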


class TensorSplitWithSizes(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.split(input, split_size, dim)


class Cat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    # All positional arguments are the tensors to concatenate; `dim` is
    # keyword-only, so `args` contains only the tensors.
    def forward(self, *args: Tensor, dim: int) -> Tensor:
        return torch.cat(args, dim)
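

# Illustrative sketch: `Cat` takes the tensors positionally and the
# concatenation dim as a keyword-only argument.
def _example_cat() -> None:
    out = Cat()(torch.randn(2, 3), torch.randn(2, 3), dim=0)
    assert out.shape == (4, 3)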


class FeedForwardBlock(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.layer_norm = nn.LayerNorm(input_dim)

        self.relu = nn.ReLU()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()

        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
        y = self.layer_norm(x)
        y = self.linear1(y)
        y = self.dropout1(y)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout2(y)
        return y


class NoOp(nn.Module):
    """
    NoOp simply passes the input as the output.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class MultiLayerPerceptron(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim1: int,
        hidden_dim2: int,
        hidden_dim3: int,
        output_dim: int,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.hidden_dim3 = hidden_dim3
        self.output_dim = output_dim
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.ReLU(),
            nn.Linear(hidden_dim3, output_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


class ScaledDotProductAttentionModularized(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout_p: float = 0.5,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=dropout_p)

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

        self.mul = Mul()
        self.reshape = Reshape()
        self.transpose = Transpose()
        self.bmm = BatchMatrixMultiplication(transposed=False)
        self.bmm_t = BatchMatrixMultiplication(transposed=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
    ) -> Tensor:
        # q: (L, B, D)  k: (S, B, D)  v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        # FIXME(poweic): scaling layer!?
        # This breaks the modular assumption: it makes the following
        # self.reshape think it is seeing some new floating input q, because
        # id(q) is no longer the same afterwards.
        # Equivalent to `q = q * self.scaling`.
        q = self.mul(q, self.scaling)

        # Reshape & transpose q from (L, B, D) to (B*H, L, D/H)
        q = self.reshape(q, (L, B * self.num_heads, self.head_dim))
        q = self.transpose(q, 0, 1)

        # Reshape & transpose k from (S, B, D) to (B*H, S, D/H)
        k = self.reshape(k, (S, B * self.num_heads, self.head_dim))
        k = self.transpose(k, 0, 1)

        # Reshape & transpose v from (S, B, D) to (B*H, S, D/H)
        v = self.reshape(v, (S, B * self.num_heads, self.head_dim))
        v = self.transpose(v, 0, 1)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        # Equivalent to `qk = torch.bmm(q, k.transpose(-1, -2))`.
        qk = self.bmm_t(q, k)
        # assert qk.shape == (B * self.num_heads, L, S)

        softmax_qk = self.softmax(qk)

        softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        # Equivalent to `attention = torch.bmm(softmax_qk, v)`.
        attention = self.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # Transpose & reshape attention: (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = self.transpose(attention, 0, 1)
        attention = self.reshape(attention, (L, B, self.embed_dim))

        return attention
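

# Illustrative sketch: run the modularized attention in eval mode (so the
# dropout is a no-op) on (L, B, D)-shaped inputs.
def _example_modularized_attention() -> None:
    attn = ScaledDotProductAttentionModularized(embed_dim=16, num_heads=4).eval()
    q, k, v = torch.randn(7, 2, 16), torch.randn(5, 2, 16), torch.randn(5, 2, 16)
    assert attn(q, k, v).shape == (7, 2, 16)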


# ------------------------------------------------------------------------------
# Scaled Dot-Product Attention
# ------------------------------------------------------------------------------
class ScaledDotProductAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: Optional[float] = None,
    ) -> None:
        if embed_dim % num_heads:
            raise ValueError(
                "embed_dim ({}) must be divisible by num_heads ({})".format(
                    embed_dim, num_heads
                )
            )

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if dropout is not None and dropout > 0.0:
            self.dropout: nn.Module = nn.Dropout(p=dropout)
        else:
            self.dropout = NoOp()

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # q: (L, B, D)  k: (S, B, D)  v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        q = q * self.scaling
        q = q.reshape(L, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, L, D/H)

        k = k.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        v = v.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        qk = torch.bmm(q, k.transpose(1, 2))
        # assert qk.shape == (B * self.num_heads, L, S)

        # TODO(cfyeh): figure out if we really need input to be float.
        softmax_qk = nn.functional.softmax(qk.float(), dim=-1)

        # softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        attention = torch.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = attention.transpose(0, 1).reshape(L, B, self.embed_dim)

        return attention
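

# Illustrative sketch: the fused attention keeps the (L, B, D) layout end to
# end; note that `padding_mask` and `attention_mask` are accepted but
# currently unused.
def _example_scaled_dot_product_attention() -> None:
    attn = ScaledDotProductAttention(embed_dim=16, num_heads=4)
    q, k, v = torch.randn(7, 2, 16), torch.randn(5, 2, 16), torch.randn(5, 2, 16)
    assert attn(q, k, v).shape == (7, 2, 16)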


class Emformer(nn.Module):
    def __init__(
        self,
        l_dim: int = 32,
        m_dim: int = 8,
        c_dim: int = 8,
        r_dim: int = 8,
        input_dim: int = 256,
        ffn_hidden_dim: int = 512,
    ) -> None:
        super().__init__()

        self.l_dim = l_dim
        self.m_dim = m_dim
        self.c_dim = c_dim
        self.r_dim = r_dim

        self.input_dim = input_dim
        self.ffn_hidden_dim = ffn_hidden_dim

        self.split = TensorSplit()
        self.elem_add = ElementwiseAdd()

        self.attn = ScaledDotProductAttention(
            embed_dim=input_dim,
            num_heads=8,
        )

        self.ffn = FeedForwardBlock(input_dim, ffn_hidden_dim)

        self.layer_norm = nn.LayerNorm(input_dim)

        self.linear_k = nn.Linear(self.input_dim, self.input_dim)
        self.linear_v = nn.Linear(self.input_dim, self.input_dim)
        self.linear_q = nn.Linear(self.input_dim, self.input_dim)

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        inputs = (
            torch.randn(self.m_dim, 1, self.input_dim),
            torch.randn(self.c_dim, 1, self.input_dim),
            torch.randn(self.r_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
        )
        return inputs

    def forward(
        self, M: Tensor, C: Tensor, R: Tensor, K_L: Tensor, V_L: Tensor
    ) -> Tensor:
        """
        The Emformer block takes [M_i^n, C_i^n, R_i^n] and [K_{L,i}^n, V_{L,i}^n]
        as inputs and outputs [C_i^{n+1}, R_i^{n+1}].
        See Fig. 1(b) Emformer and equations 6, 7, 8 - 13 in the original paper
        https://arxiv.org/pdf/2010.10759.pdf

        Ex (with self.input_dim = 512):
        - L.shape = 30 x 1 x 512
        - M.shape = 2 x 1 x 512
        - C.shape = 5 x 1 x 512
        - R.shape = 1 x 1 x 512
        """
        # Equation 8
        CR = torch.cat([C, R], 0)
        CR_normed = self.layer_norm(CR)
        # C_normed = self.layer_norm(C)
        # R_normed = self.layer_norm(R)

        # Equations 9 and 10
        if True:
            # Project [M, C, R] jointly, then split the result.
            MCR = torch.cat([M, C, R], 0)
            K_MCR = self.linear_k(MCR)
            V_MCR = self.linear_v(MCR)

            K_M, K_C, K_R = self.split(K_MCR, 3)
            V_M, V_C, V_R = self.split(V_MCR, 3)
        else:
            K_M, K_C, K_R = self.linear_k(M), self.linear_k(C), self.linear_k(R)
            V_M, V_C, V_R = self.linear_v(M), self.linear_v(C), self.linear_v(R)

        K = torch.cat([K_M, K_L, K_C, K_R], 0)
        V = torch.cat([V_M, V_L, V_C, V_R], 0)

        # Equations 11 and 12
        Q_CR = self.linear_q(CR_normed)
        Z_CR = self.attn(Q_CR, K, V)
        Z_CR = self.elem_add(Z_CR, CR)
        # Q_C = self.linear_q(C_normed)
        # Q_R = self.linear_q(R_normed)
        # Z_C = self.attn(Q_C, K, V)
        # Z_R = self.attn(Q_R, K, V)
        # Z_C = self.elem_add(Z_C, C)
        # Z_R = self.elem_add(Z_R, R)

        # Equation 6
        Z_CR_normed = self.layer_norm(Z_CR)
        ffn_out = self.ffn(Z_CR_normed)

        # Equation 7
        output = self.layer_norm(self.elem_add(ffn_out, Z_CR))

        return output


# List of models that we want to export.
# TODO(angelayi): enable ControlFlowWhile test once we enable functionalization
MODELS = [
    ["basic_sin_max", BasicSinMax()],
    ["composite_delegate", CompositeDelegateModule()],
]
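

# Illustrative sketch (not exercised by the tests): each entry in MODELS is
# meant to be exported and lowered to Edge IR roughly like this; whether a
# given entry lowers cleanly depends on delegate support in the current build.
def _example_export_models() -> None:
    for _name, model in MODELS:
        to_edge(export(model, model.get_random_inputs()))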