xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/nn_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5
6# https://pytorch.org/docs/stable/nn.html
class NNConvolutionModule(torch.nn.Module):
    """Run convolution-style ``torch.nn`` layers on fixed-shape random inputs.

    One random tensor is created per spatial rank (1-d, 2-d, 3-d) at
    construction time; ``forward`` feeds each tensor through every layer of
    the matching ``ModuleList``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Sample inputs are drawn first, then the layer groups are built in
        # 1-d, 2-d, 3-d order.
        self.input1d = torch.randn(1, 4, 36)
        self.input2d = torch.randn(1, 4, 30, 10)
        self.input3d = torch.randn(1, 4, 10, 4, 4)

        layers_1d = [
            nn.Conv1d(in_channels=4, out_channels=33, kernel_size=3),
            nn.ConvTranspose1d(in_channels=4, out_channels=33, kernel_size=3),
            nn.Fold(output_size=(5, 10), kernel_size=(2, 2)),
        ]
        self.module1d = nn.ModuleList(layers_1d)

        layers_2d = [
            nn.Conv2d(in_channels=4, out_channels=33, kernel_size=3),
            nn.ConvTranspose2d(in_channels=4, out_channels=33, kernel_size=3),
            nn.Unfold(kernel_size=3),
        ]
        self.module2d = nn.ModuleList(layers_2d)

        layers_3d = [
            nn.Conv3d(in_channels=4, out_channels=33, kernel_size=2),
            nn.ConvTranspose3d(in_channels=4, out_channels=33, kernel_size=3),
        ]
        self.module3d = nn.ModuleList(layers_3d)

    def forward(self):
        """Apply every layer to its input.

        Returns:
            The number of per-rank result lists that were built (always 3).
        """
        out_1d = [layer(self.input1d) for _, layer in enumerate(self.module1d)]
        out_2d = [layer(self.input2d) for _, layer in enumerate(self.module2d)]
        out_3d = [layer(self.input3d) for _, layer in enumerate(self.module3d)]
        grouped = (out_1d, out_2d, out_3d)
        return len(grouped)
42
43
class NNPoolingModule(torch.nn.Module):
    """Run 1-d, 2-d and 3-d pooling layers of ``torch.nn`` on random inputs.

    Each spatial rank has its own random input tensor and its own
    ``ModuleList`` of pooling layers.
    """

    def __init__(self) -> None:
        super().__init__()

        # --- 1-d pooling ---
        self.input1d = torch.randn(1, 16, 50)
        pools_1d = [
            nn.MaxPool1d(kernel_size=3, stride=2),
            nn.AvgPool1d(kernel_size=3, stride=2),
            nn.LPPool1d(norm_type=2, kernel_size=3, stride=2),
            nn.AdaptiveMaxPool1d(output_size=3),
            nn.AdaptiveAvgPool1d(output_size=3),
        ]
        self.module1d = nn.ModuleList(pools_1d)

        # --- 2-d pooling ---
        self.input2d = torch.randn(1, 16, 30, 10)
        pools_2d = [
            nn.MaxPool2d(kernel_size=(3, 2), stride=(2, 1)),
            nn.AvgPool2d(kernel_size=(3, 2), stride=(2, 1)),
            nn.FractionalMaxPool2d(kernel_size=3, output_ratio=(0.5, 0.5)),
            nn.LPPool2d(norm_type=2, kernel_size=3, stride=(2, 1)),
            nn.AdaptiveMaxPool2d(output_size=(5, 7)),
            nn.AdaptiveAvgPool2d(output_size=7),
        ]
        self.module2d = nn.ModuleList(pools_2d)

        # --- 3-d pooling ---
        self.input3d = torch.randn(1, 16, 20, 4, 4)
        pools_3d = [
            nn.MaxPool3d(kernel_size=2),
            nn.AvgPool3d(kernel_size=2),
            nn.FractionalMaxPool3d(kernel_size=2, output_ratio=(0.5, 0.5, 0.5)),
            nn.AdaptiveMaxPool3d(output_size=(5, 7, 9)),
            nn.AdaptiveAvgPool3d(output_size=(5, 7, 9)),
        ]
        self.module3d = nn.ModuleList(pools_3d)
        # TODO max_unpool

    def forward(self):
        """Apply every pooling layer to its input.

        Returns:
            The number of per-rank result lists that were built (always 3).
        """
        pooled_1d = [pool(self.input1d) for _, pool in enumerate(self.module1d)]
        pooled_2d = [pool(self.input2d) for _, pool in enumerate(self.module2d)]
        pooled_3d = [pool(self.input3d) for _, pool in enumerate(self.module3d)]
        grouped = (pooled_1d, pooled_2d, pooled_3d)
        return len(grouped)
90
91
class NNPaddingModule(torch.nn.Module):
    """Run reflection, replication, zero and constant padding layers on random inputs.

    Each spatial rank (1-d, 2-d, 3-d) has its own random input tensor and its
    own ``ModuleList`` of padding layers.
    """

    def __init__(self) -> None:
        super().__init__()

        # --- 1-d padding ---
        self.input1d = torch.randn(1, 4, 50)
        pads_1d = [
            nn.ReflectionPad1d(padding=2),
            nn.ReplicationPad1d(padding=2),
            nn.ConstantPad1d(padding=2, value=3.5),
        ]
        self.module1d = nn.ModuleList(pads_1d)

        # --- 2-d padding ---
        self.input2d = torch.randn(1, 4, 30, 10)
        pads_2d = [
            nn.ReflectionPad2d(padding=2),
            nn.ReplicationPad2d(padding=2),
            nn.ZeroPad2d(padding=2),
            nn.ConstantPad2d(padding=2, value=3.5),
        ]
        self.module2d = nn.ModuleList(pads_2d)

        # --- 3-d padding ---
        self.input3d = torch.randn(1, 4, 10, 4, 4)
        pads_3d = [
            nn.ReflectionPad3d(padding=1),
            nn.ReplicationPad3d(padding=3),
            nn.ConstantPad3d(padding=3, value=3.5),
        ]
        self.module3d = nn.ModuleList(pads_3d)

    def forward(self):
        """Apply every padding layer to its input.

        Returns:
            The number of per-rank result lists that were built (always 3).
        """
        padded = (
            [pad(self.input1d) for _, pad in enumerate(self.module1d)],
            [pad(self.input2d) for _, pad in enumerate(self.module2d)],
            [pad(self.input3d) for _, pad in enumerate(self.module3d)],
        )
        return len(padded)
131
132
class NNNormalizationModule(torch.nn.Module):
    """Run normalization layers (plus ``ChannelShuffle``) on random inputs.

    Each spatial rank (1-d, 2-d, 3-d) has its own random input tensor and its
    own ``ModuleList`` of layers; all layers are built for 4 channels.
    """

    def __init__(self) -> None:
        super().__init__()

        # --- 1-d normalization ---
        self.input1d = torch.randn(1, 4, 50)
        norms_1d = [
            nn.BatchNorm1d(num_features=4),
            nn.InstanceNorm1d(num_features=4),
        ]
        self.module1d = nn.ModuleList(norms_1d)

        # --- 2-d normalization ---
        self.input2d = torch.randn(1, 4, 30, 10)
        norms_2d = [
            nn.BatchNorm2d(num_features=4),
            nn.GroupNorm(num_groups=4, num_channels=4),
            nn.InstanceNorm2d(num_features=4),
            nn.LayerNorm(normalized_shape=[4, 30, 10]),
            nn.LocalResponseNorm(size=2),
        ]
        self.module2d = nn.ModuleList(norms_2d)

        # --- 3-d normalization ---
        self.input3d = torch.randn(1, 4, 10, 4, 4)
        norms_3d = [
            nn.BatchNorm3d(num_features=4),
            nn.InstanceNorm3d(num_features=4),
            nn.ChannelShuffle(groups=2),
        ]
        self.module3d = nn.ModuleList(norms_3d)

    def forward(self):
        """Apply every layer to its input.

        Returns:
            The number of per-rank result lists that were built (always 3).
        """
        normed_1d = [norm(self.input1d) for _, norm in enumerate(self.module1d)]
        normed_2d = [norm(self.input2d) for _, norm in enumerate(self.module2d)]
        normed_3d = [norm(self.input3d) for _, norm in enumerate(self.module3d)]
        grouped = (normed_1d, normed_2d, normed_3d)
        return len(grouped)
172
173
174class NNActivationModule(torch.nn.Module):
175    def __init__(self) -> None:
176        super().__init__()
177        self.activations = nn.ModuleList(
178            [
179                nn.ELU(),
180                nn.Hardshrink(),
181                nn.Hardsigmoid(),
182                nn.Hardtanh(),
183                nn.Hardswish(),
184                nn.LeakyReLU(),
185                nn.LogSigmoid(),
186                # nn.MultiheadAttention(),
187                nn.PReLU(),
188                nn.ReLU(),
189                nn.ReLU6(),
190                nn.RReLU(),
191                nn.SELU(),
192                nn.CELU(),
193                nn.GELU(),
194                nn.Sigmoid(),
195                nn.SiLU(),
196                nn.Mish(),
197                nn.Softplus(),
198                nn.Softshrink(),
199                nn.Softsign(),
200                nn.Tanh(),
201                nn.Tanhshrink(),
202                # nn.Threshold(0.1, 20),
203                nn.GLU(),
204                nn.Softmin(),
205                nn.Softmax(),
206                nn.Softmax2d(),
207                nn.LogSoftmax(),
208                # nn.AdaptiveLogSoftmaxWithLoss(),
209            ]
210        )
211
212    def forward(self):
213        input = torch.randn(2, 3, 4)
214        return len(([module(input) for i, module in enumerate(self.activations)],))
215
216
class NNRecurrentModule(torch.nn.Module):
    """Run the RNN, GRU and LSTM layers and their single-step cell variants.

    Each ``ModuleList`` holds a two-layer sequence module at index 0 and the
    matching cell module at index 1.
    """

    def __init__(self) -> None:
        super().__init__()
        rnn_pair = [
            nn.RNN(input_size=4, hidden_size=8, num_layers=2),
            nn.RNNCell(input_size=4, hidden_size=8),
        ]
        self.rnn = nn.ModuleList(rnn_pair)

        gru_pair = [
            nn.GRU(input_size=4, hidden_size=8, num_layers=2),
            nn.GRUCell(input_size=4, hidden_size=8),
        ]
        self.gru = nn.ModuleList(gru_pair)

        lstm_pair = [
            nn.LSTM(input_size=4, hidden_size=8, num_layers=2),
            nn.LSTMCell(input_size=4, hidden_size=8),
        ]
        self.lstm = nn.ModuleList(lstm_pair)

    def forward(self):
        """Run all six recurrent modules once on random data.

        The sequence modules receive the full ``(5, 3, 4)`` input and the
        ``(2, 3, 8)`` state tensors; the cell modules receive the first slice
        of each. Only the final (``LSTMCell``) result feeds the return value.

        Returns:
            ``len`` of the ``LSTMCell`` result.
        """
        seq = torch.randn(5, 3, 4)
        hidden = torch.randn(2, 3, 8)
        cell = torch.randn(2, 3, 8)

        # Each call overwrites the previous result; they run in a fixed order.
        last = self.rnn[0](seq, hidden)
        last = self.rnn[1](seq[0], hidden[0])
        last = self.gru[0](seq, hidden)
        last = self.gru[1](seq[0], hidden[0])
        last = self.lstm[0](seq, (hidden, cell))
        last = self.lstm[1](seq[0], (hidden[0], cell[0]))
        return len(last)
245
246
class NNTransformerModule(torch.nn.Module):
    """Run a full transformer, an encoder stack and a decoder stack.

    All three are built with ``d_model=2``, ``nhead=2`` and a single layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Built one after another in this order, then wrapped in a ModuleList.
        full_model = nn.Transformer(
            d_model=2,
            nhead=2,
            num_encoder_layers=1,
            num_decoder_layers=1,
        )
        encoder_stack = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=2, nhead=2),
            num_layers=1,
        )
        decoder_stack = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=2, nhead=2),
            num_layers=1,
        )
        self.transformers = nn.ModuleList([full_model, encoder_stack, decoder_stack])

    def forward(self):
        """Run the three modules on random ``(1, 16, 2)`` tensors.

        Only the final (decoder) result feeds the return value.

        Returns:
            ``len`` of the decoder output.
        """
        src = torch.rand(1, 16, 2)
        tgt = torch.rand((1, 16, 2))

        # Each call overwrites the previous result.
        out = self.transformers[0](src, tgt)
        out = self.transformers[1](src)
        out = self.transformers[2](src, tgt)
        return len(out)
271
272
class NNLinearModule(torch.nn.Module):
    """Run the ``Identity``, ``Linear`` and ``Bilinear`` layers on a random batch."""

    def __init__(self) -> None:
        super().__init__()
        linear_layers = [
            nn.Identity(54),
            nn.Linear(in_features=20, out_features=20),
            nn.Bilinear(in1_features=20, in2_features=20, out_features=40),
            # nn.LazyLinear(20, 30),
        ]
        self.linears = nn.ModuleList(linear_layers)

    def forward(self):
        """Run the three layers on a random ``(32, 20)`` batch.

        Only the final (``Bilinear``) result feeds the return value.

        Returns:
            ``len`` of the ``Bilinear`` output.
        """
        batch = torch.randn(32, 20)

        # Each call overwrites the previous result.
        out = self.linears[0](batch)
        out = self.linears[1](batch)
        out = self.linears[2](batch, batch)
        return len(out)
291
292
class NNDropoutModule(torch.nn.Module):
    """Call the functional dropout variants on random inputs."""

    def forward(self):
        """Run each dropout function once.

        Returns:
            The number of dropout results produced (5).
        """
        a = torch.randn(8, 4)
        b = torch.randn(8, 4, 4, 4)
        c = torch.randn(8, 4, 4, 4, 4)
        # ``len`` accepts exactly one argument, so the results are gathered
        # into a single tuple, as the ModuleList-based classes in this file do.
        results = (
            F.dropout(a),
            F.dropout2d(b),
            F.dropout3d(c),
            F.alpha_dropout(a),
            F.feature_alpha_dropout(c),
        )
        return len(results)
305
306
class NNSparseModule(torch.nn.Module):
    """Call the functional embedding, embedding-bag and one-hot ops."""

    def forward(self):
        """Run each sparse op once on small fixed index tensors.

        Returns:
            The number of results produced (3).
        """
        # Local names avoid shadowing the ``input`` builtin.
        indices_2d = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        indices_flat = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        embedding_matrix = torch.rand(10, 3)
        offsets = torch.tensor([0, 4])
        # ``len`` accepts exactly one argument, so the results are gathered
        # into a single tuple before counting.
        results = (
            F.embedding(indices_2d, embedding_matrix),
            F.embedding_bag(indices_flat, embedding_matrix, offsets),
            F.one_hot(torch.arange(0, 5) % 3, num_classes=5),
        )
        return len(results)
318
319
class NNDistanceModule(torch.nn.Module):
    """Call the functional distance and similarity ops on random inputs."""

    def forward(self):
        """Run each distance function once on ``(8, 4)`` random tensors.

        Returns:
            The number of results produced (3).
        """
        a = torch.randn(8, 4)
        b = torch.randn(8, 4)
        # ``len`` accepts exactly one argument, so the results are gathered
        # into a single tuple before counting.
        results = (
            F.pairwise_distance(a, b),
            F.cosine_similarity(a, b),
            F.pdist(a),
        )
        return len(results)
329
330
class NNLossFunctionModule(torch.nn.Module):
    """Call a broad set of ``torch.nn.functional`` loss functions on small inputs."""

    def __init__(self) -> None:
        super().__init__()
        # Fixed input/target pair used by the multilabel and multi-margin losses.
        self.x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
        self.y = torch.LongTensor([[3, 0, -1, 1]])

    def forward(self):
        """Evaluate every listed loss once.

        Returns:
            The number of loss values computed (19).
        """
        a = torch.randn(3, 2)
        b = torch.rand(3, 2)
        c = torch.rand(3)
        # Inputs for the CTC loss.
        log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
        targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
        input_lengths = torch.full((16,), 50, dtype=torch.long)
        target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
        # ``len`` accepts exactly one argument, so the losses are gathered
        # into a single tuple before counting.
        losses = (
            F.binary_cross_entropy(torch.sigmoid(a), b),
            # NOTE(review): the "logits" passed here have already been through
            # sigmoid; kept as in the original call -- confirm it is intended.
            F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
            F.poisson_nll_loss(a, b),
            F.cosine_embedding_loss(a, b, c),
            F.cross_entropy(a, b),
            F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
            # F.gaussian_nll_loss(a, b, torch.ones(5, 1)), # ENTER is not supported in mobile module
            F.hinge_embedding_loss(a, b),
            F.kl_div(a, b),
            F.l1_loss(a, b),
            F.mse_loss(a, b),
            F.margin_ranking_loss(c, c, c),
            F.multilabel_margin_loss(self.x, self.y),
            F.multilabel_soft_margin_loss(self.x, self.y),
            F.multi_margin_loss(self.x, torch.tensor([3])),
            F.nll_loss(a, torch.tensor([1, 0, 1])),
            F.huber_loss(a, b),
            F.smooth_l1_loss(a, b),
            F.soft_margin_loss(a, b),
            F.triplet_margin_loss(a, b, -b),
            # F.triplet_margin_with_distance_loss(a, b, -b), # can't take variable number of arguments
        )
        return len(losses)
368
369
class NNVisionModule(torch.nn.Module):
    """Run pixel-shuffle, upsampling and grid-sample ops on random inputs."""

    def __init__(self) -> None:
        super().__init__()
        # Shared 4-d input fed to every module in ``vision_modules``.
        self.input = torch.randn(1, 4, 9, 9)
        self.vision_modules = nn.ModuleList(
            [
                nn.PixelShuffle(2),
                nn.PixelUnshuffle(3),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Upsample(scale_factor=2, mode="bilinear"),
                nn.Upsample(scale_factor=2, mode="bicubic"),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.UpsamplingBilinear2d(scale_factor=2),
            ]
        )
        # These two modes are kept separate and get their own 3-d and 5-d
        # inputs in ``forward``.
        self.linear_sample = nn.Upsample(scale_factor=2, mode="linear")
        self.trilinear_sample = nn.Upsample(scale_factor=2, mode="trilinear")

    def forward(self):
        """Run every vision op once.

        All modules in ``vision_modules`` are executed on ``self.input``; only
        the last one's output is kept alongside the linear, trilinear and
        grid-sample results.

        Returns:
            The number of results gathered (4).
        """
        # Local name avoids shadowing the ``input`` builtin.
        grid_source = torch.randn(1, 3, 16, 16)
        for module in self.vision_modules:
            r = module(self.input)
        # ``len`` accepts exactly one argument, so the results are gathered
        # into a single tuple before counting.
        results = (
            r,
            self.linear_sample(torch.randn(4, 9, 9)),
            self.trilinear_sample(torch.randn(1, 3, 4, 9, 9)),
            F.grid_sample(grid_source, torch.ones(1, 4, 4, 2)),
        )
        return len(results)
398
399
class NNShuffleModule(torch.nn.Module):
    """Run ``nn.ChannelShuffle`` with two groups on a random tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.shuffle = nn.ChannelShuffle(groups=2)

    def forward(self):
        """Shuffle a fresh ``(1, 4, 2, 2)`` random tensor.

        Returns:
            ``len`` of the shuffled tensor.
        """
        sample = torch.randn(1, 4, 2, 2)
        shuffled = self.shuffle(sample)
        return len(shuffled)
409
410
class NNUtilsModule(torch.nn.Module):
    """Run ``nn.Unflatten`` (behind a ``Linear`` layer) and ``pad_sequence``."""

    def __init__(self) -> None:
        super().__init__()
        # Linear(50 -> 50) followed by unflattening dim 1 into (2, 5, 5).
        self.flatten = nn.Sequential(nn.Linear(50, 50), nn.Unflatten(1, (2, 5, 5)))

    def forward(self):
        """Pad two variable-length sequences and run the unflatten pipeline.

        Returns:
            The number of results gathered (2).
        """
        a = [torch.tensor([1, 2, 3]), torch.tensor([3, 4])]
        b = nn.utils.rnn.pad_sequence(a, batch_first=True)
        # c = nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=torch.tensor([3, 2]))
        features = torch.randn(2, 50)
        # ``len`` accepts exactly one argument, so the results are gathered
        # into a single tuple before counting.
        results = (
            self.flatten(features),
            b,
        )
        return len(results)
425