# mypy: ignore-errors

# Owner(s): ["oncall: distributed"]

import torch
import torch.nn as nn


class UnitModule(nn.Module):
    # A basic 100-feature block: Linear -> Sequential(ReLU, Linear, ReLU) -> Linear.
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.l2(self.seq(self.l1(x)))


class CompositeModel(nn.Module):
    # Composes two `UnitModule`s between an input and an output linear layer.
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.l2(self.u2(self.u1(self.l1(x))))


class UnitParamModule(nn.Module):
    # Like `UnitModule`, but multiplies the block output with a parameter
    # registered directly on this module.
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.mm(self.seq(self.l(x)), self.p)


class CompositeParamModel(nn.Module):
    # Composes two `UnitModule`s and additionally registers a parameter and a
    # persistent buffer directly on the root; the buffer does not participate
    # in `forward`.
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.p = nn.Parameter(torch.randn((100, 100), device=device))
        self.register_buffer(
            "buffer", torch.randn((100, 100), device=device), persistent=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.u2(self.u1(self.l(x)))
        b = self.p
        return torch.mm(a, b)


class FakeSequential(nn.Module):
    # Behaves like `nn.Sequential` without subclassing it, so that a module
    # wrap policy keyed on `nn.Sequential` skips this container; this achieves
    # the desired nested wrapping in the tests.
    def __init__(self, *modules: nn.Module) -> None:
        super().__init__()
        # Use `nn.ModuleList` (not a plain Python list) so the submodules are
        # registered and visible to `named_modules()` / `parameters()`.
        self._module_sequence = nn.ModuleList(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self._module_sequence:
            x = module(x)
        return x


class NestedSequentialModel(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # This nested structure exercises traversal order to catch differences
        # between valid traversals (e.g. BFS and DFS variations).
        self.seq1 = nn.Sequential(
            nn.Linear(1, 1, device=device),
            FakeSequential(
                nn.Linear(1, 1, device=device),
                nn.ReLU(),
                FakeSequential(
                    nn.Linear(1, 1, device=device),
                ),
                nn.ReLU(),
            ),
            nn.Linear(1, 2, device=device),
        )
        self.lin = nn.Linear(2, 2, device=device)
        self.seq2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2, 3, device=device),
            FakeSequential(
                nn.Linear(3, 2, bias=False, device=device),
                nn.Linear(2, 4, bias=False, device=device),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.lin(self.seq1(x)))
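

# ---------------------------------------------------------------------------
# Illustrative usage sketches. These are hedged examples rather than part of
# the test utilities above: the `_demo_*` helpers are hypothetical names, and
# the checks assume plain CPU execution with no process-group initialization.
# ---------------------------------------------------------------------------


# A minimal sketch of how `CompositeParamModel` might be smoke-tested: run a
# CPU forward pass and confirm that the directly registered parameter and
# persistent buffer show up in the module's state.
def _demo_composite_param_model() -> None:
    model = CompositeParamModel(torch.device("cpu"))
    out = model(torch.randn(8, 100))
    assert out.shape == (8, 100)
    param_names = {name for name, _ in model.named_parameters()}
    assert "p" in param_names
    assert "buffer" in dict(model.named_buffers())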
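

# A sketch of why `FakeSequential` exists: an FSDP module wrap policy keyed on
# `nn.Sequential` (e.g. `ModuleWrapPolicy({nn.Sequential})` from
# `torch.distributed.fsdp.wrap`) matches `seq1` and `seq2` but not the
# `FakeSequential` containers, even though they behave identically in
# `forward`. The isinstance checks below mirror that matching without needing
# any distributed setup.
def _demo_fake_sequential_visibility() -> None:
    model = NestedSequentialModel(torch.device("cpu"))
    real = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Sequential)
    ]
    # Only the true `nn.Sequential` containers would be matched by the policy.
    assert real == ["seq1", "seq2"]
    # All three `FakeSequential` containers (one in `seq1`, one nested inside
    # it, and one in `seq2`) are registered submodules, yet none matches.
    fakes = [
        name
        for name, module in model.named_modules()
        if isinstance(module, FakeSequential)
    ]
    assert len(fakes) == 3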
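

# A quick shape sketch for `NestedSequentialModel`: features go 1 -> 2 across
# `seq1`, stay at 2 through `lin`, then 2 -> 3 -> 2 -> 4 across `seq2`.
def _demo_nested_sequential_shapes() -> None:
    model = NestedSequentialModel(torch.device("cpu"))
    out = model(torch.randn(5, 1))
    assert out.shape == (5, 4)


if __name__ == "__main__":
    _demo_composite_param_model()
    _demo_fake_sequential_visibility()
    _demo_nested_sequential_shapes()
    print("all usage sketches passed")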