# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class OpSequencesAddConv2d(torch.nn.Module):
    """
    Module which includes sequences of Memory Format sensitive ops. forward
    runs [num_sequences] sequences of [ops_per_sequence] ops. Each sequence is
    followed by an add to separate the sequences.

    Every op is a bias-free ``torch.nn.Conv2d`` with one input channel, one
    output channel, a 3x3 kernel and a padding of 1.

    Args:
        num_sequences: Number of conv sequences to build.
        ops_per_sequence: Number of ``Conv2d`` layers in each sequence.

    Attributes:
        num_ops: Total number of conv layers
            (``num_sequences * ops_per_sequence``).
        num_sequences: Number of sequences, as passed to the constructor.
        op_sequence: ``ModuleList`` holding one inner ``ModuleList`` of
            ``Conv2d`` layers per sequence.
    """

    def __init__(self, num_sequences: int, ops_per_sequence: int) -> None:
        super().__init__()
        self.num_ops = num_sequences * ops_per_sequence
        self.num_sequences = num_sequences

        # Nested ModuleLists (rather than plain Python lists) so that every
        # conv layer is registered as a submodule of this module.
        self.op_sequence = torch.nn.ModuleList()
        for _ in range(num_sequences):
            inner = torch.nn.ModuleList()
            for _ in range(ops_per_sequence):
                inner.append(
                    torch.nn.Conv2d(
                        in_channels=1,
                        out_channels=1,
                        kernel_size=(3, 3),
                        padding=1,
                        bias=False,
                    )
                )
            self.op_sequence.append(inner)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run ``x`` through every conv sequence in order.

        After each sequence the running tensor is added to itself; the value
        returned is the final running tensor added to itself once more.

        Args:
            x: Input tensor; it must be accepted by a ``Conv2d`` with a single
                input channel.

        Returns:
            The tensor produced by the last add.
        """
        for seq in self.op_sequence:
            for op in seq:
                x = op(x)
            # The add marks the boundary between two consecutive sequences.
            x = x + x
        # One more add after the last sequence; with zero sequences this is
        # the only op that runs.
        return x + x