# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
from typing import Any, List, Optional, Tuple, Union

import executorch.exir as exir

import torch
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.lowered_backend_module import LoweredBackendModule
from torch import Tensor
from torch.export import export

# TODO: add one more test for data dependent op plus repeat


class TensorItem(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        h = arg1.item()
        w = arg2.item()
        torch._check(h >= 2)
        torch._check(h <= 100)
        torch._check(w >= 2)
        torch._check(w <= 100)
        return torch.ones(int(h), int(w))

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.tensor(10), torch.tensor(20))

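
# Illustrative sketch (assumed usage, not part of the test flow): TensorItem
# has a data-dependent output shape, so the torch._check bounds above are what
# let torch.export reason about the size of torch.ones(h, w).
def _example_export_tensor_item() -> None:
    model = TensorItem()
    ep = export(model, model.get_random_inputs())
    out = ep.module()(torch.tensor(10), torch.tensor(20))
    assert out.shape == (10, 20)
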

class Repeat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, arg1: torch.Tensor, arg2: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = arg2.repeat(arg1.size(0), 1)
        return x * x, arg2 + arg2

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))

    def get_dynamic_shape(self) -> Any:  # pyre-ignore[3]
        dim = torch.export.Dim("dim", max=10)
        dim2 = torch.export.Dim("dim2", max=10)
        return ({0: dim}, {0: dim2})

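
# Illustrative sketch (assumed usage): get_dynamic_shape() marks dim 0 of each
# input as symbolic, so the exported program accepts inputs whose lengths
# differ from the example inputs.
def _example_export_repeat_dynamic() -> None:
    model = Repeat()
    ep = export(
        model,
        model.get_random_inputs(),
        dynamic_shapes=model.get_dynamic_shape(),
    )
    # Re-run with different dim-0 sizes than the example inputs.
    x, y = ep.module()(torch.rand(3), torch.rand(7))
    assert x.shape == (3, 7) and y.shape == (7,)
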

class ModelWithUnusedArg(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        return torch.sin(arg1)

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))


class MLP(nn.Module):
    def __init__(self, n_layer: int = 1, output_size: int = 1) -> None:
        super().__init__()
        self.n_layer = n_layer
        self.output_size = output_size
        # Input shape: [batch_size, n_layer + output_size].
        # Each linear layer reduces the size of activation dim 1 by one.
        self.mlp = torch.nn.Sequential(
            *itertools.chain(
                *(
                    [nn.Linear(i + output_size, i - 1 + output_size)]
                    + ([nn.ReLU()] if i != 1 else [])
                    for i in range(n_layer, 0, -1)
                )
            )
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.mlp(inputs)

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.rand(2, self.n_layer + self.output_size),)

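
# Illustrative sketch of the layer sizing above (assumed example): with
# n_layer=3 and output_size=1 the stack is Linear(4->3), ReLU, Linear(3->2),
# ReLU, Linear(2->1), shrinking dim 1 by one per linear layer.
def _example_mlp_shapes() -> None:
    model = MLP(n_layer=3, output_size=1)
    (x,) = model.get_random_inputs()  # shape (2, 4)
    assert model(x).shape == (2, 1)
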

class Identity(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return torch.clone(input)


class Reshape(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, x: Tensor, *new_shape: Union[torch.Size, Tuple[int, ...], List[int]]
    ) -> Tensor:
        if len(new_shape) == 1 and (
            isinstance(new_shape[0], tuple) or isinstance(new_shape[0], list)
        ):
            return x.reshape(new_shape[0])
        # isinstance() cannot take a parameterized generic like Tuple[int, ...],
        # so check the concrete runtime contents: a varargs tuple of ints.
        assert all(isinstance(d, int) for d in new_shape)
        return x.reshape(new_shape)

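
# Illustrative sketch (assumed usage): Reshape accepts either a single shape
# tuple/list or the shape unpacked as individual ints.
def _example_reshape_call_styles() -> None:
    x = torch.randn(2, 6)
    assert Reshape()(x, (3, 4)).shape == (3, 4)
    assert Reshape()(x, 3, 4).shape == (3, 4)
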

class Transpose(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        return x.transpose(dim0, dim1)


class Mul(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        # Equivalent to torch.mul(input, other).
        return input * other

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(3, 2), torch.randn(3, 2))


class ElementwiseAdd(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3))


class BasicSinMax(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sin(x)

    def get_random_inputs(self) -> Tuple[Tensor]:
        return (torch.randn(100),)


class CompositeDelegateModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        class DelegateAdd(nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: Tensor, y: Tensor) -> List[Tensor]:
                return [x + y]

            def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
                return (torch.randn(1, 3), torch.randn(1, 3))

        delegated_m = DelegateAdd()
        edge_ir_m = to_edge(
            export(
                delegated_m,
                delegated_m.get_random_inputs(),
            )
        )
        lowered_module = LoweredBackendModule(
            edge_program=edge_ir_m.exported_program(),
            backend_id="backend_demo",
            processed_bytes=bytes("basic_module_add", encoding="utf8"),
            compile_specs=[],
        )
        self.lowered_module: LoweredBackendModule = lowered_module

    def forward(self, a: exir.Value, b: exir.Value, s: Tensor) -> Tensor:
        res = self.lowered_module(a, b)
        res = res[0] * s
        return res

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3))

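
# Illustrative sketch: the lowered module built above carries its delegate
# metadata (backend_id, processed_bytes are assumed to be the public
# accessors of LoweredBackendModule).
def _example_inspect_lowered_module() -> None:
    m = CompositeDelegateModule()
    assert m.lowered_module.backend_id == "backend_demo"
    assert m.lowered_module.processed_bytes == b"basic_module_add"
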

class BatchMatrixMultiplication(nn.Module):
    def __init__(self, transposed: bool = False) -> None:
        super().__init__()

        # Whether the last two dims (-1, -2) of the second input have already
        # been transposed. If so, transpose them back before calling torch.bmm.
        self.transposed: bool = transposed

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if self.transposed:
            return torch.bmm(x, y.transpose(-1, -2))
        else:
            return torch.bmm(x, y)

    def extra_repr(self) -> str:
        return f"transposed={self.transposed}"

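
# Illustrative sketch (assumed example): with transposed=True, passing y in
# (B, S, D) layout matches a plain bmm against its pre-transposed (B, D, S)
# form.
def _example_bmm_transposed() -> None:
    x = torch.randn(2, 3, 4)  # (B, L, D)
    y = torch.randn(2, 5, 4)  # (B, S, D), last two dims pre-transposed
    out = BatchMatrixMultiplication(transposed=True)(x, y)
    ref = BatchMatrixMultiplication(transposed=False)(x, y.transpose(-1, -2))
    assert torch.allclose(out, ref)
    assert out.shape == (2, 3, 5)
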

class TensorSplit(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.tensor_split(input, sections, dim)


class TensorSplitWithSizes(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.split(input, split_size, dim)

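
# Illustrative sketch of the difference between the two modules above (assumed
# example): tensor_split takes a number of sections, split takes a chunk size.
def _example_split_semantics() -> None:
    x = torch.arange(10)
    assert [t.numel() for t in TensorSplit()(x, 3)] == [4, 3, 3]
    assert [t.numel() for t in TensorSplitWithSizes()(x, 3)] == [3, 3, 3, 1]
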

class Cat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, *args: Tensor, dim: int) -> Tensor:
        # `dim` is keyword-only, so every positional argument is a tensor to
        # concatenate.
        return torch.cat(args, dim)


class FeedForwardBlock(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.layer_norm = nn.LayerNorm(input_dim)

        self.relu = nn.ReLU()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()

        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
        y = self.layer_norm(x)
        y = self.linear1(y)
        y = self.dropout1(y)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout2(y)
        return y


class NoOp(nn.Module):
    """
    NoOp simply passes the input as the output.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class MultiLayerPerceptron(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim1: int,
        hidden_dim2: int,
        hidden_dim3: int,
        output_dim: int,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.hidden_dim3 = hidden_dim3
        self.output_dim = output_dim
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.ReLU(),
            nn.Linear(hidden_dim3, output_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


class ScaledDotProductAttentionModularized(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout_p: float = 0.5,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=dropout_p)

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

        self.mul = Mul()
        self.reshape = Reshape()
        self.transpose = Transpose()
        self.bmm = BatchMatrixMultiplication(transposed=False)
        self.bmm_t = BatchMatrixMultiplication(transposed=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
    ) -> Tensor:
        # q: (L, B, D) k: (S, B, D) v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        # FIXME(poweic): scaling layer!?
        # This breaks the modularity assumption: after the multiply below, the
        # q passed to self.reshape is a fresh tensor, so id(q) no longer
        # matches the id of the original input q.
        # This is equivalent to `q = q * self.scaling`.
        q = self.mul(q, self.scaling)

        # Reshape & transpose q from (L, B, D) to (B*H, L, D/H)
        q = self.reshape(q, (L, B * self.num_heads, self.head_dim))
        q = self.transpose(q, 0, 1)

        # Reshape & transpose k from (S, B, D) to (B*H, S, D/H)
        k = self.reshape(k, (S, B * self.num_heads, self.head_dim))
        k = self.transpose(k, 0, 1)

        # Reshape & transpose v from (S, B, D) to (B*H, S, D/H)
        v = self.reshape(v, (S, B * self.num_heads, self.head_dim))
        v = self.transpose(v, 0, 1)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        # This is equivalent to `qk = torch.bmm(q, k.transpose(-1, -2))`.
        qk = self.bmm_t(q, k)
        # assert qk.shape == (B * self.num_heads, L, S)

        softmax_qk = self.softmax(qk)

        softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        # This is equivalent to `attention = torch.bmm(softmax_qk, v)`.
        attention = self.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # Transpose & reshape attention: (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = self.transpose(attention, 0, 1)
        attention = self.reshape(attention, (L, B, self.embed_dim))

        return attention

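
# Illustrative sketch (assumed shapes): run the modularized attention in eval
# mode so the active Dropout becomes a no-op, and check the (L, B, D) layout.
def _example_sdpa_modularized() -> None:
    attn = ScaledDotProductAttentionModularized(embed_dim=16, num_heads=4).eval()
    q = torch.randn(7, 2, 16)  # (L, B, D)
    k = torch.randn(5, 2, 16)  # (S, B, D)
    v = torch.randn(5, 2, 16)  # (S, B, D)
    assert attn(q, k, v).shape == (7, 2, 16)
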

# ------------------------------------------------------------------------------
#   Scaled Dot-Product Attention
# ------------------------------------------------------------------------------
class ScaledDotProductAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: Optional[float] = None,
    ) -> None:
        if embed_dim % num_heads:
            raise ValueError(
                "embed_dim ({}) must be divisible by num_heads ({})".format(
                    embed_dim, num_heads
                )
            )

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if dropout is not None and dropout > 0.0:
            self.dropout: nn.Module = nn.Dropout(p=dropout)
        else:
            self.dropout = NoOp()

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # q: (L, B, D) k: (S, B, D) v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        q = q * self.scaling
        q = q.reshape(L, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, L, D/H)

        k = k.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        v = v.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        qk = torch.bmm(q, k.transpose(1, 2))
        # assert qk.shape == (B * self.num_heads, L, S)

        # TODO(cfyeh): figure out if we really need input to be float.
        softmax_qk = nn.functional.softmax(qk.float(), dim=-1)

        # softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        attention = torch.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = attention.transpose(0, 1).reshape(L, B, self.embed_dim)

        return attention

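
# Illustrative sketch (assumed example): in eval mode the modularized and
# fused implementations compute the same function on float32 inputs.
def _example_attention_equivalence() -> None:
    modular = ScaledDotProductAttentionModularized(embed_dim=16, num_heads=4).eval()
    fused = ScaledDotProductAttention(embed_dim=16, num_heads=4)
    q, k, v = torch.randn(7, 2, 16), torch.randn(5, 2, 16), torch.randn(5, 2, 16)
    assert torch.allclose(modular(q, k, v), fused(q, k, v), atol=1e-6)
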

class Emformer(nn.Module):
    def __init__(
        self,
        l_dim: int = 32,
        m_dim: int = 8,
        c_dim: int = 8,
        r_dim: int = 8,
        input_dim: int = 256,
        ffn_hidden_dim: int = 512,
    ) -> None:
        super().__init__()

        self.l_dim = l_dim
        self.m_dim = m_dim
        self.c_dim = c_dim
        self.r_dim = r_dim

        self.input_dim = input_dim
        self.ffn_hidden_dim = ffn_hidden_dim

        self.split = TensorSplit()
        self.elem_add = ElementwiseAdd()

        self.attn = ScaledDotProductAttention(
            embed_dim=input_dim,
            num_heads=8,
        )

        self.ffn = FeedForwardBlock(input_dim, ffn_hidden_dim)

        self.layer_norm = nn.LayerNorm(input_dim)

        self.linear_k = nn.Linear(self.input_dim, self.input_dim)
        self.linear_v = nn.Linear(self.input_dim, self.input_dim)
        self.linear_q = nn.Linear(self.input_dim, self.input_dim)

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        inputs = (
            torch.randn(self.m_dim, 1, self.input_dim),
            torch.randn(self.c_dim, 1, self.input_dim),
            torch.randn(self.r_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
        )
        return inputs

    def forward(
        self, M: Tensor, C: Tensor, R: Tensor, K_L: Tensor, V_L: Tensor
    ) -> Tensor:
        """
        The Emformer block takes [M_i^n, C_i^n, R_i^n] and [K_{L,i}^n, V_{L,i}^n]
        as inputs and outputs [C_i^{n+1}, R_i^{n+1}].
        See Fig. 1(b) Emformer and equations 6, 7, and 8-13 in the original
        paper https://arxiv.org/pdf/2010.10759.pdf

        Example shapes (with self.input_dim = 512):
         - L.shape = 30 x 1 x 512
         - M.shape =  2 x 1 x 512
         - C.shape =  5 x 1 x 512
         - R.shape =  1 x 1 x 512
        """
        # Equation 8
        CR = torch.cat([C, R], 0)
        CR_normed = self.layer_norm(CR)
        # C_normed = self.layer_norm(C)
        # R_normed = self.layer_norm(R)

        # Equations 9 and 10
        if True:
            MCR = torch.cat([M, C, R], 0)
            K_MCR = self.linear_k(MCR)
            V_MCR = self.linear_v(MCR)

            K_M, K_C, K_R = self.split(K_MCR, 3)
            V_M, V_C, V_R = self.split(V_MCR, 3)
        else:
            K_M, K_C, K_R = self.linear_k(M), self.linear_k(C), self.linear_k(R)
            V_M, V_C, V_R = self.linear_v(M), self.linear_v(C), self.linear_v(R)

        K = torch.cat([K_M, K_L, K_C, K_R], 0)
        V = torch.cat([V_M, V_L, V_C, V_R], 0)

        # Equations 11 and 12
        Q_CR = self.linear_q(CR_normed)
        Z_CR = self.attn(Q_CR, K, V)
        Z_CR = self.elem_add(Z_CR, CR)
        # Q_C = self.linear_q(C_normed)
        # Q_R = self.linear_q(R_normed)
        # Z_C = self.attn(Q_C, K, V)
        # Z_R = self.attn(Q_R, K, V)
        # Z_C = self.elem_add(Z_C, C)
        # Z_R = self.elem_add(Z_R, R)

        # Equation 6
        Z_CR_normed = self.layer_norm(Z_CR)
        ffn_out = self.ffn(Z_CR_normed)

        # Equation 7
        output = self.layer_norm(self.elem_add(ffn_out, Z_CR))

        return output

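
# Illustrative sketch (assumed example): the block maps the stacked (C, R)
# rows to an output of the same layout, so the result has c_dim + r_dim rows.
def _example_emformer_shapes() -> None:
    model = Emformer().eval()
    out = model(*model.get_random_inputs())
    assert out.shape == (model.c_dim + model.r_dim, 1, model.input_dim)
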

# List of models that we want to export
# TODO(angelayi): enable ControlFlowWhile test once we enable functionalization
MODELS = [
    ["basic_sin_max", BasicSinMax()],
    ["composite_delegate", CompositeDelegateModule()],
]
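

# Illustrative sketch of how a harness might drive MODELS (an assumed flow,
# not the test suite's own; the composite_delegate entry additionally relies
# on the delegate plumbing exercised by the backend tests).
def _example_export_all_models() -> None:
    for name, model in MODELS:
        edge = to_edge(export(model, model.get_random_inputs()))
        print(name, edge.exported_program().graph_signature)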