xref: /aosp_15_r20/external/executorch/backends/qualcomm/tests/models.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport torch
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker
# Modules containing only the related operator
class Add(torch.nn.Module):
    """Element-wise sum of two tensors computed with ``torch.add``."""

    def forward(self, x, y):
        return torch.add(x, y)
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
class AddConstantFloat(torch.nn.Module):
    """Add the Python float scalar ``10.0`` to the input (scalar on the left)."""

    def forward(self, x):
        return 10.0 + x
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker
class AddConstantLong(torch.nn.Module):
    """Add the Python int scalar ``10`` to the input (scalar on the left)."""

    def forward(self, x):
        return 10 + x
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker
class Arange(torch.nn.Module):
    """Add a float32 ``torch.arange`` ramp to the input.

    Args:
        x: end value handed to ``torch.arange``; kept as ``self.x``.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, y):
        ramp = torch.arange(self.x, dtype=torch.float32)
        return ramp + y
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker
class AvgPoolModule(torch.nn.Module):
    """2x2 average pooling with stride 1 and padding 1.

    ``count_include_pad=False`` excludes the zero padding from each average.
    """

    def __init__(self):
        super().__init__()
        # Attribute name ``avgPool`` is kept: it is part of the state-dict key space.
        self.avgPool = torch.nn.AvgPool2d(
            kernel_size=(2, 2),
            stride=(1, 1),
            padding=(1, 1),
            count_include_pad=False,
        )

    def forward(self, x):
        pooled = self.avgPool(x)
        return pooled
56*523fa7a6SAndroid Build Coastguard Worker
57*523fa7a6SAndroid Build Coastguard Worker
class BatchNorm(torch.nn.Module):
    """``BatchNorm2d`` over ``n_features`` channels, switched to eval mode on creation.

    Args:
        n_features: number of channels normalised by the BatchNorm2d layer.
    """

    def __init__(self, n_features):
        super().__init__()
        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
        # Eval mode: the forward pass uses running statistics, not batch statistics.
        self.eval()

    def forward(self, x):
        normalized = self.native_batchnorm(x)
        return normalized
66*523fa7a6SAndroid Build Coastguard Worker
67*523fa7a6SAndroid Build Coastguard Worker
class Bmm(torch.nn.Module):
    """Matrix product of two tensors via ``torch.matmul`` (batched for 3-D inputs)."""

    def forward(self, x, y):
        return torch.matmul(x, y)
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker
class Cast(torch.nn.Module):
    """Convert the input to ``torch.IntTensor`` (int32) with the legacy ``Tensor.type`` API."""

    def forward(self, x):
        return x.type(torch.IntTensor)
82*523fa7a6SAndroid Build Coastguard Worker
83*523fa7a6SAndroid Build Coastguard Worker
class Cat2(torch.nn.Module):
    """Concatenate two tensors along dimension 2 with ``torch.cat``."""

    def forward(self, x, y):
        # ``dim`` is torch.cat's canonical keyword (``axis`` is only a NumPy-compat alias).
        return torch.cat((x, y), dim=2)
90*523fa7a6SAndroid Build Coastguard Worker
91*523fa7a6SAndroid Build Coastguard Worker
class Cat3(torch.nn.Module):
    """Concatenate ``(y, y, x)`` along dimension 2 with ``torch.concat``."""

    def forward(self, x, y):
        # ``torch.concat`` (alias of ``torch.cat``) is kept as written; only the
        # NumPy-compat ``axis`` keyword is replaced by the canonical ``dim``.
        return torch.concat((y, y, x), dim=2)
98*523fa7a6SAndroid Build Coastguard Worker
99*523fa7a6SAndroid Build Coastguard Worker
class Cat4(torch.nn.Module):
    """Concatenate ``(y, y, x, x)`` along dimension 2 with ``torch.cat``."""

    def forward(self, x, y):
        # ``dim`` is torch.cat's canonical keyword (``axis`` is only a NumPy-compat alias).
        return torch.cat((y, y, x, x), dim=2)
106*523fa7a6SAndroid Build Coastguard Worker
107*523fa7a6SAndroid Build Coastguard Worker
class Ceil(torch.nn.Module):
    """Element-wise ceiling via ``torch.ceil``."""

    def forward(self, x):
        return torch.ceil(x)
114*523fa7a6SAndroid Build Coastguard Worker
115*523fa7a6SAndroid Build Coastguard Worker
class Chunk(torch.nn.Module):
    """Split the input into two chunks along the last dimension; returns the tuple."""

    def forward(self, x):
        return torch.chunk(x, chunks=2, dim=-1)
122*523fa7a6SAndroid Build Coastguard Worker
123*523fa7a6SAndroid Build Coastguard Worker
class ChunkAdd(torch.nn.Module):
    """Split the last dimension into two chunks and return their element-wise sum."""

    def forward(self, x):
        left, right = torch.chunk(x, chunks=2, dim=-1)
        return torch.add(left, right)
131*523fa7a6SAndroid Build Coastguard Worker
132*523fa7a6SAndroid Build Coastguard Worker
class Clamp(torch.nn.Module):
    """Clamp the input from above at 0 (no lower bound)."""

    def forward(self, x):
        return torch.clamp(x, max=0)
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker
class CompositeDelegateModule(torch.nn.Module):
    """Chain of four separately lowered sub-programs: two conv stacks, an add and a relu.

    Each sub-module is optionally quantized, then captured and lowered on its own
    with a freshly built partitioner. ``forward`` wires the resulting lowered
    modules together as ``relu(add(conv_a(x), conv_b(y)))``.

    Args:
        compiler_specs: passed to ``partitioner_type`` to build each partitioner.
        partitioner_type: callable building a partitioner from ``compiler_specs``.
        capture_method: ``(module, sample_input) -> program``. The result must
            expose a settable ``exported_program`` attribute.
        lowered_method: ``(exported_program, partitioner) -> exported_program``.
        quantize_method: optional ``(module, sample_input) -> module``, applied
            before capture when truthy.
    """

    def __init__(
        self,
        compiler_specs,
        partitioner_type,
        capture_method,
        lowered_method,
        quantize_method=None,
    ) -> None:
        super().__init__()
        # NOTE(review): nn.Module does not register plain Python lists as
        # submodules. Assigning ``self.modules`` also shadows the
        # ``nn.Module.modules()`` method on this instance. Kept as-is because the
        # attribute is part of the visible interface (see get_reference_module).
        self.modules = [Conv2dSequential(), Conv2dSequential(), Add(), Relu()]
        # One example-input tuple per entry of ``self.modules``, in the same order.
        self.sample_inputs = [
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])),
            (torch.randn([1, 2, 3, 3]),),
        ]
        self.lowered_modules = []
        for submodule, example in zip(self.modules, self.sample_inputs):
            # A fresh partitioner for every sub-program.
            partitioner = partitioner_type(compiler_specs)
            if quantize_method:
                submodule = quantize_method(submodule, example)
            program = capture_method(submodule, example)
            program.exported_program = lowered_method(
                program.exported_program, partitioner
            )
            graph_module = program.exported_program.graph_module
            # Assumes the whole sub-program was delegated into one child named
            # "lowered_module_0"; ``.get`` yields None if it is absent.
            self.lowered_modules.append(graph_module._modules.get("lowered_module_0"))

    def forward(self, x, y):
        # Each lowered module's result is indexed with [0] to take its first output.
        stage = self.lowered_modules
        left = stage[0](x)[0]
        right = stage[1](y)[0]
        summed = stage[2](left, right)[0]
        return stage[3](summed)[0]

    def get_random_input(self):
        """Return a fresh random ``(x, y)`` pair, each of shape [1, 1, 3, 3]."""
        shape = [1, 1, 3, 3]
        return (torch.randn(shape), torch.randn(shape))

    def get_reference_module(self):
        """Return an eager module computing the same chain with the modules in ``self.modules``."""

        class CompositeReferenceModule(torch.nn.Module):
            def __init__(self, modules):
                super().__init__()
                self.modules = modules

            def forward(self, x, y):
                conv_x = self.modules[0](x)
                conv_y = self.modules[1](y)
                summed = self.modules[2](conv_x, conv_y)
                return self.modules[3](summed)

        return CompositeReferenceModule(self.modules)
200*523fa7a6SAndroid Build Coastguard Worker
201*523fa7a6SAndroid Build Coastguard Worker
class ContextBinaryExample(torch.nn.Module):
    """Apply ReLU to each of two inputs independently and return both results."""

    def forward(self, x, y):
        relu = torch.nn.functional.relu
        return relu(x), relu(y)

    def example_inputs(self):
        """Return keyword inputs for ``forward``: x of shape (1, 3, 3, 3), y of shape (2, 1, 5, 5)."""
        x = torch.randn((1, 3, 3, 3))
        y = torch.randn((2, 1, 5, 5))
        return {"x": x, "y": y}
213*523fa7a6SAndroid Build Coastguard Worker
214*523fa7a6SAndroid Build Coastguard Worker
class Conv1dSequential(torch.nn.Module):
    """Two stacked Conv1d layers (1 -> 3 -> 2 channels), kernel 3, padding 1.

    With the default stride of 1, kernel 3 and padding 1 keep the sequence length.

    Args:
        bias: whether both convolutions have a learnable bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        # Plain int kernel size: parentheses around a single int do not form a tuple.
        self.first = torch.nn.Conv1d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        self.second = torch.nn.Conv1d(
            in_channels=3,
            out_channels=2,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))
236*523fa7a6SAndroid Build Coastguard Worker
237*523fa7a6SAndroid Build Coastguard Worker
238*523fa7a6SAndroid Build Coastguard Worker# small models
class Conv1dReluLogSoftmax(torch.nn.Module):
    """Conv1d (2 -> 2 channels, kernel 1, padding 1), then ReLU, then LogSoftmax over dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=2,
            out_channels=2,
            kernel_size=1,
            stride=1,
            padding=1,
        )
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        activated = torch.nn.functional.relu(self.conv(x))
        return self.logsoftmax(activated)
251*523fa7a6SAndroid Build Coastguard Worker
252*523fa7a6SAndroid Build Coastguard Worker
class Conv2dAvgPool2d(torch.nn.Module):
    """Strided 7x7 Conv2d (3 -> 16 channels) followed by a 3x3 stride-2 average pool."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=7,
            stride=2,
            padding=3,
            dilation=1,
            bias=True,
        )
        self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features)
263*523fa7a6SAndroid Build Coastguard Worker
264*523fa7a6SAndroid Build Coastguard Worker
class Conv2dBnHardtanhMean(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> Hardtanh(0, 6) -> mean over the channel dimension.

    A single-channel 3x3 convolution (stride 2, padding 1) whose weight is
    re-drawn from a standard normal distribution. The module is switched to eval
    mode at construction.
    """

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 1
        out_channels = 1

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        # Replace the default weight initialisation with N(0, 1) samples.
        self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size()))
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.hardtanh(x2)
        # Reduce over dim 1 (channels of the NCHW conv output), keeping it as size 1.
        x4 = torch.mean(x3, dim=1, keepdim=True)
        return x4
296*523fa7a6SAndroid Build Coastguard Worker
297*523fa7a6SAndroid Build Coastguard Worker
class Conv2dCat(torch.nn.Module):
    """Convolve two inputs with separate 3x3 Conv2d layers and concatenate on dim 1."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x, y):
        return torch.cat([self.conv1(x), self.conv2(y)], dim=1)
309*523fa7a6SAndroid Build Coastguard Worker
310*523fa7a6SAndroid Build Coastguard Worker
class Conv2dMaxPool2d(torch.nn.Module):
    """1x1 Conv2d (2 -> 2 channels, padding 1) followed by a 1x1, stride-1 max pool."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2, out_channels=2, kernel_size=(1, 1), padding=1, bias=True
        )
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features)
325*523fa7a6SAndroid Build Coastguard Worker
326*523fa7a6SAndroid Build Coastguard Worker
class Conv2dSequential(torch.nn.Module):
    """Two stacked 3x3 Conv2d layers (1 -> 3 -> 2 channels) with padding 1.

    Args:
        bias: whether both convolutions have a learnable bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        shared = dict(kernel_size=(3, 3), padding=1, bias=bias)
        self.first = torch.nn.Conv2d(in_channels=1, out_channels=3, **shared)
        self.second = torch.nn.Conv2d(in_channels=3, out_channels=2, **shared)

    def forward(self, x):
        hidden = self.first(x)
        return self.second(hidden)
347*523fa7a6SAndroid Build Coastguard Worker
348*523fa7a6SAndroid Build Coastguard Worker
class Conv2dSingle(torch.nn.Module):
    """A single 3x3 Conv2d (1 -> 3 channels) with padding 1.

    Args:
        bias: whether the convolution has a learnable bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=bias
        )

    def forward(self, x):
        out = self.conv(x)
        return out
362*523fa7a6SAndroid Build Coastguard Worker
363*523fa7a6SAndroid Build Coastguard Worker
class ConvTranspose2dSingle(torch.nn.Module):
    """A single ConvTranspose2d (1 -> 3 channels, kernel 3, stride 2, padding 1).

    Args:
        bias: whether the transposed convolution has a learnable bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias
        )

    def forward(self, x):
        upsampled = self.conv_transpose(x)
        return upsampled
378*523fa7a6SAndroid Build Coastguard Worker
379*523fa7a6SAndroid Build Coastguard Worker
class Conv2dDownUpSample(torch.nn.Module):
    """Stride-2 Conv2d followed by a stride-2 ConvTranspose2d, both 16 -> 16 channels.

    Args:
        bias: whether both layers have a learnable bias.
    """

    def __init__(self, bias=True):
        super().__init__()
        shared = dict(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.conv = torch.nn.Conv2d(**shared)
        self.conv_transpose = torch.nn.ConvTranspose2d(**shared)

    def forward(self, x):
        downsampled = self.conv(x)
        return self.conv_transpose(downsampled)
402*523fa7a6SAndroid Build Coastguard Worker
403*523fa7a6SAndroid Build Coastguard Worker
class Conv2dSumReduceDim(torch.nn.Module):
    """3x3 Conv2d (1 -> 3 channels, padding 1), then sum over dims 2 and 3."""

    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=True
        )

    def forward(self, x):
        features = self.first(x)
        return features.sum(dim=(2, 3), keepdim=False)
417*523fa7a6SAndroid Build Coastguard Worker
418*523fa7a6SAndroid Build Coastguard Worker
class Conv2dTopK(torch.nn.Module):
    """3x3 Conv2d (3 -> 16 channels), then the 5 largest values along dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        x = self.conv(x)
        # Only the values are returned; the indices half of topk is discarded.
        topk_values, _ = torch.topk(x, 5, dim=1)
        return topk_values
428*523fa7a6SAndroid Build Coastguard Worker
429*523fa7a6SAndroid Build Coastguard Worker
class Div(torch.nn.Module):
    """Element-wise division ``x / y`` via ``torch.divide``."""

    def forward(self, x, y):
        return torch.divide(x, y)
436*523fa7a6SAndroid Build Coastguard Worker
437*523fa7a6SAndroid Build Coastguard Worker
class DivConstantFloat(torch.nn.Module):
    """Divide the input by the Python float scalar ``10.0``."""

    def forward(self, x):
        return x / 10.0
444*523fa7a6SAndroid Build Coastguard Worker
445*523fa7a6SAndroid Build Coastguard Worker
class DivConstantLong(torch.nn.Module):
    """Divide the input by the Python int scalar ``10`` (true division)."""

    def forward(self, x):
        return x / 10
452*523fa7a6SAndroid Build Coastguard Worker
453*523fa7a6SAndroid Build Coastguard Worker
class EinsumBilinear(torch.nn.Module):
    """Bilinear form ``out[b, a] = sum_{n,m} bn[b, n] * anm[a, n, m] * bm[b, m]`` via einsum."""

    def forward(self, bn, anm, bm):
        equation = "bn,anm,bm->ba"
        return torch.einsum(equation, bn, anm, bm)
460*523fa7a6SAndroid Build Coastguard Worker
461*523fa7a6SAndroid Build Coastguard Worker
class EinsumOuterProduct(torch.nn.Module):
    """Outer product of two 1-D tensors expressed as ``einsum("i,j->ij")``."""

    def forward(self, i, j):
        equation = "i,j->ij"
        return torch.einsum(equation, i, j)
468*523fa7a6SAndroid Build Coastguard Worker
469*523fa7a6SAndroid Build Coastguard Worker
class EinsumOuterProductRelu(torch.nn.Module):
    """Outer product of two 1-D tensors via ``einsum``, followed by ReLU."""

    def __init__(self):
        super().__init__()

    def forward(self, i, j):
        outer = torch.einsum("i,j->ij", i, j)
        return torch.relu(outer)
476*523fa7a6SAndroid Build Coastguard Worker
477*523fa7a6SAndroid Build Coastguard Worker
class Embedding(torch.nn.Module):
    """Look up rows of a 10-entry, 3-wide embedding table."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=3)

    def forward(self, x):
        # x holds integer indices in [0, 10).
        vectors = self.embedding(x)
        return vectors
485*523fa7a6SAndroid Build Coastguard Worker
486*523fa7a6SAndroid Build Coastguard Worker
class ExpandCopy(torch.nn.Module):
    """Broadcast the input to shape ``(3, 4)`` with ``Tensor.expand``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x must be broadcast-compatible with (3, 4), e.g. (1, 4) or (3, 1).
        target_shape = (3, 4)
        return x.expand(*target_shape)
493*523fa7a6SAndroid Build Coastguard Worker
494*523fa7a6SAndroid Build Coastguard Worker
class Gelu(torch.nn.Module):
    """Apply the exact (non-approximated) GELU activation."""

    def __init__(self):
        super().__init__()
        # "none" is GELU's default approximation mode, spelled out explicitly.
        self.gelu = torch.nn.GELU(approximate="none")

    def forward(self, x):
        activated = self.gelu(x)
        return activated
502*523fa7a6SAndroid Build Coastguard Worker
503*523fa7a6SAndroid Build Coastguard Worker
class GroupNorm(torch.nn.Module):
    """Conv2d (32 -> 256 channels) followed by GroupNorm over 32 groups.

    Args:
        bias: whether the convolution has a bias term.

    Returns:
        A tuple ``(conv_output, normalized_conv_output)``.
    """

    def __init__(self, bias=True):
        super().__init__()
        # A 3x3 kernel with stride 1 and padding 1 keeps the spatial size unchanged.
        self.conv = torch.nn.Conv2d(
            in_channels=32,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        # The 256 channels are split into 32 groups of 8.
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=256)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.norm(features)
        return features, normalized
520*523fa7a6SAndroid Build Coastguard Worker
521*523fa7a6SAndroid Build Coastguard Worker
class HardSigmoid(torch.nn.Module):
    """Apply the piecewise-linear Hardsigmoid activation."""

    def __init__(self):
        super().__init__()
        self.hardsigmoid = torch.nn.Hardsigmoid()

    def forward(self, x):
        activated = self.hardsigmoid(x)
        return activated
529*523fa7a6SAndroid Build Coastguard Worker
530*523fa7a6SAndroid Build Coastguard Worker
class HardSwish(torch.nn.Module):
    """Apply the Hardswish activation, ``x * relu6(x + 3) / 6``."""

    def __init__(self):
        super().__init__()
        self.hardswish = torch.nn.Hardswish()

    def forward(self, x):
        activated = self.hardswish(x)
        return activated
538*523fa7a6SAndroid Build Coastguard Worker
539*523fa7a6SAndroid Build Coastguard Worker
class HardTanh(torch.nn.Module):
    """Clamp the input to ``[0, 6]`` with Hardtanh (a ReLU6 equivalent)."""

    def __init__(self):
        super().__init__()
        # Positional arguments are (min_val, max_val).
        self.hardtanh = torch.nn.Hardtanh(0, 6)

    def forward(self, x):
        clamped = self.hardtanh(x)
        return clamped
547*523fa7a6SAndroid Build Coastguard Worker
548*523fa7a6SAndroid Build Coastguard Worker
class Index(torch.nn.Module):
    """Gather along dim 0 with two fixed int32 index tensors and sum the results.

    The largest index is 6, so ``x`` needs at least 7 entries along dim 0.
    """

    def __init__(self):
        super().__init__()
        # Plain tensor attributes (not registered buffers), both int32.
        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)

    def forward(self, x):
        first = x[self.idx0]
        second = x[self.idx1]
        return first + second
557*523fa7a6SAndroid Build Coastguard Worker
558*523fa7a6SAndroid Build Coastguard Worker
class IndexPut(torch.nn.Module):
    """Write ``k_val`` in place into a zero-initialized buffer along dim 1.

    ``k_cache`` is a float32 buffer of shape ``(1, 1024, 12, 64)``. The index
    list ``[None, input_pos]`` leaves dim 0 whole and selects positions
    ``input_pos`` along dim 1. The in-place ``index_put_`` mutates the buffer
    and its result is returned.
    """

    def __init__(self):
        super().__init__()
        cache = torch.zeros((1, 1024, 12, 64), dtype=torch.float32)
        self.register_buffer("k_cache", cache)

    def forward(self, input_pos, k_val):
        return torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
570*523fa7a6SAndroid Build Coastguard Worker
571*523fa7a6SAndroid Build Coastguard Worker
class LayerNorm(torch.nn.Module):
    """LayerNorm over a trailing 768-wide dimension, then a 768 -> 196 Linear."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=[768], eps=1e-6)
        self.linear = torch.nn.Linear(in_features=768, out_features=196)

    def forward(self, x):
        normed = self.layer_norm(x)
        return self.linear(normed)
580*523fa7a6SAndroid Build Coastguard Worker
581*523fa7a6SAndroid Build Coastguard Worker
class LeakyReLUDefault(torch.nn.Module):
    """LeakyReLU with PyTorch's default negative slope (0.01)."""

    def __init__(self):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        activated = self.leaky_relu(x)
        return activated
589*523fa7a6SAndroid Build Coastguard Worker
590*523fa7a6SAndroid Build Coastguard Worker
class LeakyReLUCustom(torch.nn.Module):
    """LeakyReLU whose negative slope is supplied by the caller.

    Args:
        coeff: slope applied to negative inputs.
    """

    def __init__(self, coeff):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(negative_slope=coeff)

    def forward(self, x):
        activated = self.leaky_relu(x)
        return activated
598*523fa7a6SAndroid Build Coastguard Worker
599*523fa7a6SAndroid Build Coastguard Worker
class Linear(torch.nn.Module):
    """A 4 -> 5 fully connected layer, with the inner layer put in eval mode.

    Args:
        use_bias: whether the layer has a bias term.
    """

    def __init__(self, use_bias: bool = True):
        super().__init__()
        layer = torch.nn.Linear(in_features=4, out_features=5, bias=use_bias)
        layer.eval()
        self.linear = layer

    def forward(self, x):
        return self.linear(x)
607*523fa7a6SAndroid Build Coastguard Worker
608*523fa7a6SAndroid Build Coastguard Worker
class LogSoftmax(torch.nn.Module):
    """Log-softmax over the last dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        last_dim = -1
        return torch.nn.functional.log_softmax(x, dim=last_dim)
615*523fa7a6SAndroid Build Coastguard Worker
616*523fa7a6SAndroid Build Coastguard Worker
class MaxPool2d(torch.nn.Module):
    """3x3 max pooling with stride 1, padding 1, dilation 1 and ``ceil_mode``."""

    def __init__(self):
        super().__init__()
        # Positional order: kernel_size, stride, padding, dilation.
        self.max_pool2d = torch.nn.MaxPool2d(3, 1, 1, 1, ceil_mode=True)

    def forward(self, x):
        pooled = self.max_pool2d(x)
        return pooled
630*523fa7a6SAndroid Build Coastguard Worker
631*523fa7a6SAndroid Build Coastguard Worker
class MeanWKeppDim(torch.nn.Module):
    """Mean over the last two dims, keeping them as size-1 dims.

    The "Kepp" spelling in the class name is kept as-is so references by name
    keep working.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=(-1, -2), keepdim=True)
638*523fa7a6SAndroid Build Coastguard Worker
639*523fa7a6SAndroid Build Coastguard Worker
class MeanWOKeppDim(torch.nn.Module):
    """Mean over the last two dims, which are removed from the result.

    The "Kepp" spelling in the class name is kept as-is so references by name
    keep working.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=(-1, -2))
646*523fa7a6SAndroid Build Coastguard Worker
647*523fa7a6SAndroid Build Coastguard Worker
class Mul(torch.nn.Module):
    """Elementwise (broadcasting) product of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        product = torch.mul(x, y)
        return product
654*523fa7a6SAndroid Build Coastguard Worker
655*523fa7a6SAndroid Build Coastguard Worker
class MulConstantFloat(torch.nn.Module):
    """Multiply the input by the float literal ``10.0`` (constant on the left)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        scaled = 10.0 * x
        return scaled
662*523fa7a6SAndroid Build Coastguard Worker
663*523fa7a6SAndroid Build Coastguard Worker
class MulConstantLong(torch.nn.Module):
    """Multiply the input by the integer literal ``10`` (constant on the left)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        scaled = 10 * x
        return scaled
670*523fa7a6SAndroid Build Coastguard Worker
671*523fa7a6SAndroid Build Coastguard Worker
class MulScalar(torch.nn.Module):
    """Multiply the input by ``3.14`` through the explicit ``aten.mul.Scalar`` overload."""

    def __init__(self):
        super().__init__()
        self._scalar = 3.14

    def forward(self, x):
        return torch.ops.aten.mul.Scalar(x, self._scalar)
680*523fa7a6SAndroid Build Coastguard Worker
681*523fa7a6SAndroid Build Coastguard Worker
class MultiheadAttention(torch.nn.Module):
    """Self-attention with 12 heads over a 96-wide embedding (batch-first).

    Only the attention output is returned; attention weights are not requested.
    """

    def __init__(self):
        super().__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(
            embed_dim=96,
            num_heads=12,
            dropout=0.0,
            batch_first=True,
        )

    def forward(self, x):
        # Self-attention: the same tensor serves as query, key and value.
        outputs = self.multi_head_attention(x, x, x, need_weights=False)
        return outputs[0]
692*523fa7a6SAndroid Build Coastguard Worker
693*523fa7a6SAndroid Build Coastguard Worker
class Pad(torch.nn.Module):
    """Drop the first entry of dim 1, then append one zero row along dim -2."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        trimmed = x[:, 1:]
        # F.pad pairs run from the last dim backwards:
        # (last: 0, 0), (second-to-last: 0 before, 1 after), (third-to-last: 0, 0).
        padding = [0, 0, 0, 1, 0, 0]
        return torch.nn.functional.pad(trimmed, padding, mode="constant", value=0.0)
702*523fa7a6SAndroid Build Coastguard Worker
703*523fa7a6SAndroid Build Coastguard Worker
class PixelShuffle(torch.nn.Module):
    """Rearrange channels into space with ``torch.nn.PixelShuffle``.

    Args:
        scale: the upscale factor.
    """

    def __init__(self, scale):
        super().__init__()
        self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor=scale)

    def forward(self, x):
        shuffled = self.pixel_shuffle(x)
        return shuffled
711*523fa7a6SAndroid Build Coastguard Worker
712*523fa7a6SAndroid Build Coastguard Worker
class PixelUnshuffle(torch.nn.Module):
    """Rearrange space into channels with ``torch.nn.PixelUnshuffle``.

    Args:
        scale: the downscale factor.
    """

    def __init__(self, scale):
        super().__init__()
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(downscale_factor=scale)

    def forward(self, x):
        unshuffled = self.pixel_unshuffle(x)
        return unshuffled
720*523fa7a6SAndroid Build Coastguard Worker
721*523fa7a6SAndroid Build Coastguard Worker
class PixelUnshuffleMathEquivalent(torch.nn.Module):
    """Pixel-unshuffle written as explicit ``view`` / ``permute`` / ``reshape``.

    Expects a 4-D input whose last two dims are divisible by ``scale``.

    Args:
        scale: the downscale factor.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        r = self.scale
        batch, channels, height_in, width_in = x.size()
        height, width = height_in // r, width_in // r
        # Split each spatial dim into (outer, r), move both r factors next to
        # the channel dim, then fold them into the channels.
        blocks = x.view(batch, channels, height, r, width, r)
        moved = blocks.permute(0, 1, 3, 5, 2, 4)
        return moved.reshape(batch, channels * (r**2), height, width)
734*523fa7a6SAndroid Build Coastguard Worker
735*523fa7a6SAndroid Build Coastguard Worker
class PowTensorScalar(torch.nn.Module):
    """Raise the input to the integer power ``2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        exponent = 2
        return torch.pow(x, exponent)
742*523fa7a6SAndroid Build Coastguard Worker
743*523fa7a6SAndroid Build Coastguard Worker
class PReLUDefault(torch.nn.Module):
    """PReLU with a single learnable slope initialized to 0.25 (PyTorch defaults)."""

    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU(num_parameters=1, init=0.25)

    def forward(self, x):
        activated = self.prelu(x)
        return activated
751*523fa7a6SAndroid Build Coastguard Worker
752*523fa7a6SAndroid Build Coastguard Worker
class PReLUPerChannel(torch.nn.Module):
    """PReLU with one learnable slope per channel.

    Args:
        channels: number of slope parameters (one per channel).
    """

    def __init__(self, channels):
        super().__init__()
        self.prelu = torch.nn.PReLU(num_parameters=channels)

    def forward(self, x):
        activated = self.prelu(x)
        return activated
760*523fa7a6SAndroid Build Coastguard Worker
761*523fa7a6SAndroid Build Coastguard Worker
class Relu(torch.nn.Module):
    """Apply ReLU through a ``torch.nn.ReLU`` submodule."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        activated = self.relu(x)
        return activated
769*523fa7a6SAndroid Build Coastguard Worker
770*523fa7a6SAndroid Build Coastguard Worker
class Reshape(torch.nn.Module):
    """Reshape a 12-element input to shape ``(1, 12)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        target_shape = (1, 12)
        return x.reshape(target_shape)
777*523fa7a6SAndroid Build Coastguard Worker
778*523fa7a6SAndroid Build Coastguard Worker
class ResidualBlockModule(torch.nn.Module):
    """Two Conv2d -> BatchNorm2d stages with shared weights, Hardtanh, residual add.

    The same ``conv`` and ``native_batchnorm`` submodules are applied in both
    stages, so the stages share parameters. ``conv`` maps 32 -> 32 channels,
    which lets its output be fed back into it. The module is put in eval mode
    at construction.
    """

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [1, 1]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 32
        out_channels = 32

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0)
        # Eval mode: BatchNorm uses (and does not update) its running statistics.
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        # The second stage reuses the same conv / batchnorm modules.
        x3 = self.conv(x2)
        x4 = self.native_batchnorm(x3)
        x5 = self.hardtanh(x4)
        # Residual connection from the first stage's output.
        x6 = torch.add(x5, x2)
        return x6
811*523fa7a6SAndroid Build Coastguard Worker
812*523fa7a6SAndroid Build Coastguard Worker
class ResizeBilinear2D(torch.nn.Module):
    """Upsample the last two dims by 2x with bilinear interpolation.

    Uses ``align_corners=False``. The target size is computed from the input
    shape alone. No temporary tensor is created, so ``forward`` neither
    allocates extra memory nor advances the global RNG.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="bilinear",
            align_corners=False,
        )
825*523fa7a6SAndroid Build Coastguard Worker
826*523fa7a6SAndroid Build Coastguard Worker
class ResizeNearest2D(torch.nn.Module):
    """Upsample the last two dims by 2x with nearest-neighbour interpolation.

    The target size is computed from the input shape alone. No temporary
    tensor is created, so ``forward`` neither allocates extra memory nor
    advances the global RNG.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="nearest",
        )
838*523fa7a6SAndroid Build Coastguard Worker
839*523fa7a6SAndroid Build Coastguard Worker
840*523fa7a6SAndroid Build Coastguard Workerclass RmsNorm(torch.nn.Module):
841*523fa7a6SAndroid Build Coastguard Worker    def __init__(self):
842*523fa7a6SAndroid Build Coastguard Worker        super().__init__()
843*523fa7a6SAndroid Build Coastguard Worker        self.eps = 1e-5
844*523fa7a6SAndroid Build Coastguard Worker        self.rms = torch.nn.RMSNorm([4], 1e-5)
845*523fa7a6SAndroid Build Coastguard Worker
846*523fa7a6SAndroid Build Coastguard Worker    def forward(self, x):
847*523fa7a6SAndroid Build Coastguard Worker        return self.rms(x)
848*523fa7a6SAndroid Build Coastguard Worker
849*523fa7a6SAndroid Build Coastguard Worker
class Rsqrt(torch.nn.Module):
    """Elementwise reciprocal square root, ``1 / sqrt(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        inv_root = torch.rsqrt(x)
        return inv_root
856*523fa7a6SAndroid Build Coastguard Worker
857*523fa7a6SAndroid Build Coastguard Worker
class ScaledDotProductAttention(torch.nn.Module):
    """Thin wrapper around ``F.scaled_dot_product_attention`` with an explicit mask."""

    def __init__(self):
        super().__init__()

    def forward(self, query_layer, key_layer, value_layer, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attn_mask,
        )
867*523fa7a6SAndroid Build Coastguard Worker
868*523fa7a6SAndroid Build Coastguard Worker
class SelectCopy(torch.nn.Module):
    """Conv2d (3 -> 2 channels), then select batch 0, channel 1, row slice 1:2."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 2, kernel_size=(3, 3), padding=1, bias=True)

    def forward(self, x):
        features = self.conv(x)
        # Integer indices drop dims 0 and 1; the slice keeps a size-1 row dim.
        return features[0, 1, 1:2]
882*523fa7a6SAndroid Build Coastguard Worker
883*523fa7a6SAndroid Build Coastguard Worker
class Sigmoid(torch.nn.Module):
    """Elementwise logistic sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        probs = torch.sigmoid(x)
        return probs
890*523fa7a6SAndroid Build Coastguard Worker
891*523fa7a6SAndroid Build Coastguard Worker
class SimpleModel(torch.nn.Module):
    """Two conv/BN/ReLU branches merged by add, then permute, mean, reshape, Linear.

    Both branches share ``batch_norm``. Inputs ``x`` and ``y`` need 32
    channels and matching spatial size, because the padded 3x3 convs keep the
    spatial size and the branches are added. After the mean over dims [1, 2]
    the tensor has N * 32 elements. Reshaping to ``(8, -1)`` and feeding
    ``Linear(4, 10)`` therefore only works for batch size N == 1, giving an
    ``(8, 10)`` output clamped to [0, 6]. The module is put in eval mode at
    construction.
    """

    def __init__(self):
        super().__init__()
        kernel_sz = 32
        self.conv1 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True)
        self.conv2 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True)
        self.conv3 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False)
        self.conv4 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.relu = torch.nn.ReLU()
        self.batch_norm = torch.nn.BatchNorm2d(kernel_sz)
        self.add = torch.add
        self.mean = torch.mean
        self.reshape = torch.reshape
        self.linear = torch.nn.Linear(4, 10)
        self.permute = torch.permute
        self.eval()

    def forward(self, x, y):
        x1 = self.conv1(x)
        x2 = self.batch_norm(x1)
        x3 = self.relu(x2)
        x4 = self.conv2(x3)
        x5 = self.relu(x4)
        y1 = self.conv3(y)
        y2 = self.batch_norm(y1)
        y3 = self.relu(y2)
        y4 = self.conv4(y3)
        y5 = self.relu(y4)
        z = self.add(x5, y5)
        z1 = self.permute(z, (0, 3, 2, 1))
        # Use the stored callable, consistent with add/permute/reshape above.
        z2 = self.mean(z1, [1, 2], True)
        z3 = self.reshape(z2, (8, -1))
        z4 = self.linear(z3)
        z5 = self.hardtanh(z4)
        return z5
928*523fa7a6SAndroid Build Coastguard Worker
929*523fa7a6SAndroid Build Coastguard Worker
class SliceCopy(torch.nn.Module):
    """Slice ``x`` and a fixed random ``(1, 512)`` tensor to ``y``'s dim-1 length, then add.

    ``position_ids`` is a plain tensor attribute (not a registered buffer),
    drawn with ``torch.randn`` at construction.
    """

    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])

    def forward(self, x, y):
        # Only y's dim-1 length is used, not its values.
        length = y.size()[1]
        head = x[:, :length]
        return head + self.position_ids[:, :length]
938*523fa7a6SAndroid Build Coastguard Worker
939*523fa7a6SAndroid Build Coastguard Worker
class SliceCopyWithStep(torch.nn.Module):
    """Strided variant of ``SliceCopy``.

    Takes every ``step``-th element of dim 1, up to ``y.size(1)``, from both
    ``x`` and the fixed random ``position_ids`` table, and adds the results.
    """

    def __init__(self):
        super().__init__()
        # Plain tensor attribute, drawn once at construction time.
        self.position_ids = torch.randn([1, 512])
        self.step = 2

    def forward(self, x, y):
        # One slice object shared by both tensors: [:, :len(y[dim 1]):step]
        window = slice(None, y.size(1), self.step)
        return x[:, window] + self.position_ids[:, window]
952*523fa7a6SAndroid Build Coastguard Worker
953*523fa7a6SAndroid Build Coastguard Worker
class Softmax(torch.nn.Module):
    """Normalize the input with softmax along its last dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.softmax(dim=-1)
960*523fa7a6SAndroid Build Coastguard Worker
961*523fa7a6SAndroid Build Coastguard Worker
class Sqrt(torch.nn.Module):
    """Element-wise square root of the input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.sqrt()
968*523fa7a6SAndroid Build Coastguard Worker
969*523fa7a6SAndroid Build Coastguard Worker
class SqrtConstant(torch.nn.Module):
    """Divide the input by the square root of a constant tensor ``[64.0]``.

    The square root is computed inside ``forward`` on every call, so a traced
    graph contains the sqrt applied to a constant.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        divisor = torch.tensor([64.0]).sqrt()
        return x / divisor
976*523fa7a6SAndroid Build Coastguard Worker
977*523fa7a6SAndroid Build Coastguard Worker
class Squeeze(torch.nn.Module):
    """Drop every size-1 dimension of the input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.squeeze(x)
984*523fa7a6SAndroid Build Coastguard Worker
985*523fa7a6SAndroid Build Coastguard Worker
class Stack(torch.nn.Module):
    """Stack the two inputs along a new leading dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        pair = [x, y]
        return torch.stack(pair, dim=0)
992*523fa7a6SAndroid Build Coastguard Worker
993*523fa7a6SAndroid Build Coastguard Worker
class Sub(torch.nn.Module):
    """Element-wise difference ``x - y`` via ``torch.sub``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        difference = torch.sub(x, y)
        return difference
1000*523fa7a6SAndroid Build Coastguard Worker
1001*523fa7a6SAndroid Build Coastguard Worker
class SubConstantFloat(torch.nn.Module):
    """Subtract the input from the float scalar ``10.0`` (reflected subtraction)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        minuend = 10.0
        return minuend - x
1008*523fa7a6SAndroid Build Coastguard Worker
1009*523fa7a6SAndroid Build Coastguard Worker
class SubConstantLong(torch.nn.Module):
    """Subtract the input from the integer scalar ``10`` (reflected subtraction)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        minuend = 10
        return minuend - x
1016*523fa7a6SAndroid Build Coastguard Worker
1017*523fa7a6SAndroid Build Coastguard Worker
class SumIntList(torch.nn.Module):
    """Sum over dims 2 and 3, keeping them as size-1 dimensions."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        reduce_dims = (2, 3)
        return x.sum(dim=reduce_dims, keepdim=True)
1024*523fa7a6SAndroid Build Coastguard Worker
1025*523fa7a6SAndroid Build Coastguard Worker
class Tanh(torch.nn.Module):
    """Element-wise hyperbolic tangent of the input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.tanh()
1032*523fa7a6SAndroid Build Coastguard Worker
1033*523fa7a6SAndroid Build Coastguard Worker
class TopKandIndex(torch.nn.Module):
    """Take the top-3 of the input and add rows of a random table gathered by index.

    ``idx_source`` has 10 rows, so the top-k indices (positions along ``x``'s
    last dimension) must be below 10. In practice ``x.size(-1)`` should be at
    most 10.
    """

    def __init__(self):
        super().__init__()
        # Plain tensor attribute, drawn once at construction time.
        self.idx_source = torch.rand(10, 3)

    def forward(self, x):
        values, indices = x.topk(3)
        gathered_rows = self.idx_source[indices]
        return values + gathered_rows
1042*523fa7a6SAndroid Build Coastguard Worker
1043*523fa7a6SAndroid Build Coastguard Worker
class Unbind(torch.nn.Module):
    """Split the input along dim 0 into a tuple of tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.unbind(0)
1050*523fa7a6SAndroid Build Coastguard Worker
1051*523fa7a6SAndroid Build Coastguard Worker
class Unsqueeze(torch.nn.Module):
    """Insert a new size-1 dimension at position 0."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.unsqueeze(x, dim=0)
1058*523fa7a6SAndroid Build Coastguard Worker
1059*523fa7a6SAndroid Build Coastguard Worker
class View(torch.nn.Module):
    """Split the last dimension of ``x`` into ``(first_size, second_size)``.

    With the defaults set below, ``x``'s last dimension must hold
    ``2 * 256 = 512`` elements for the view to succeed. ``y`` is accepted but
    not used.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        target = (*x.size()[:-1], self.first_size, self.second_size)
        return x.view(target)
1069*523fa7a6SAndroid Build Coastguard Worker
1070*523fa7a6SAndroid Build Coastguard Worker
class ViewPermuteMatMul(torch.nn.Module):
    """Reshape ``x`` into two groups, swap axes 1 and 2, and multiply by ``y``ᵀ.

    ``x``'s last dimension is split into ``(first_size, second_size)``, i.e.
    ``(2, 256)`` with the defaults below. The 4-axis permute therefore
    requires ``x`` to be 3-D. ``y`` is transposed over its last two
    dimensions before the matmul.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        target = (*x.size()[:-1], self.first_size, self.second_size)
        grouped = x.view(target).permute(0, 2, 1, 3)
        y_t = y.transpose(-1, -2)
        return torch.matmul(grouped, y_t)
1082