xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/test_xnnpack_utils_classes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8
9
class OpSequencesAddConv2d(torch.nn.Module):
    """
    Test module built from back-to-back chains of Conv2d ops.

    ``op_sequence`` holds ``num_sequences`` chains, each made of
    ``ops_per_sequence`` single-channel 3x3 convolutions (padding=1, no bias).
    ``forward`` applies every chain in order, adds the running tensor to itself
    after each chain to separate one chain from the next, and performs one more
    self-add on the value it returns.

    Attributes set by the constructor:
        num_ops: ``num_sequences * ops_per_sequence``.
        num_sequences: the number of chains.
        op_sequence: ``ModuleList`` of ``ModuleList``s holding the convolutions.
    """

    def __init__(self, num_sequences, ops_per_sequence):
        super().__init__()
        self.num_ops = num_sequences * ops_per_sequence
        self.num_sequences = num_sequences

        # Every convolution in every chain is configured identically.
        conv_config = {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_size": (3, 3),
            "padding": 1,
            "bias": False,
        }

        # Convolutions are instantiated chain by chain, first op first, so the
        # parameter-initialization order matches the nested layout.
        self.op_sequence = torch.nn.ModuleList(
            [
                torch.nn.ModuleList(
                    [
                        torch.nn.Conv2d(**conv_config)
                        for _op_idx in range(ops_per_sequence)
                    ]
                )
                for _seq_idx in range(num_sequences)
            ]
        )

    def forward(self, x):
        """Run all conv chains over ``x``, self-adding after each chain and once at the end."""
        out = x
        for chain in self.op_sequence:
            for conv in chain:
                out = conv(out)
            # The add marks the boundary between consecutive chains.
            out = out + out
        return out + out
43