# Owner(s): ["oncall: package/deploy"]

"""Small ``torch.nn.Module`` fixtures composed of submodules and tensor attributes.

``SimpleTest`` additionally calls ``a_non_torch_leaf``, a plain Python function
registered with ``torch.fx.wrap`` so that FX symbolic tracing records each call
to it as a single ``call_function`` node instead of tracing through its body.
"""

import torch
from torch.fx import wrap


# Decorator form of ``torch.fx.wrap``: registers the name "a_non_torch_leaf"
# in this module's globals, exactly like ``wrap("a_non_torch_leaf")`` would.
# Like the string form, it must be applied at module top level.
@wrap
def a_non_torch_leaf(a, b):
    """Return ``a + b``; kept opaque (not traced into) by ``torch.fx``."""
    return a + b


class ModWithSubmod(torch.nn.Module):
    """Forward the input unchanged to a single held submodule."""

    def __init__(self, script_mod):
        super().__init__()
        self.script_mod = script_mod

    def forward(self, x):
        result = self.script_mod(x)
        return result


class ModWithTensor(torch.nn.Module):
    """Multiply the input by a tensor held via plain attribute assignment."""

    def __init__(self, tensor):
        super().__init__()
        # Ordinary attribute assignment, not ``register_buffer``.
        self.tensor = tensor

    def forward(self, x):
        factor = self.tensor
        return factor * x


class ModWithSubmodAndTensor(torch.nn.Module):
    """Run one submodule on the input, then add a held tensor to its output."""

    def __init__(self, tensor, sub_mod):
        super().__init__()
        self.tensor = tensor
        self.sub_mod = sub_mod

    def forward(self, x):
        sub_out = self.sub_mod(x)
        return sub_out + self.tensor


class ModWithTwoSubmodsAndTensor(torch.nn.Module):
    """Sum the outputs of two submodules, then add a held tensor."""

    def __init__(self, tensor, sub_mod_0, sub_mod_1):
        super().__init__()
        self.tensor = tensor
        # Registration order defines named_children() order: sub_mod_0 first.
        self.sub_mod_0 = sub_mod_0
        self.sub_mod_1 = sub_mod_1

    def forward(self, x):
        # Left-to-right: (sub_mod_0(x) + sub_mod_1(x)) + tensor.
        partial = self.sub_mod_0(x) + self.sub_mod_1(x)
        return partial + self.tensor


class ModWithMultipleSubmods(torch.nn.Module):
    """Sum the outputs of two submodules applied to the same input."""

    def __init__(self, mod1, mod2):
        super().__init__()
        self.mod1 = mod1
        self.mod2 = mod2

    def forward(self, x):
        first = self.mod1(x)
        second = self.mod2(x)
        return first + second


class SimpleTest(torch.nn.Module):
    """Apply the FX-wrapped leaf to ``(x, x)``, shift by 3.0, then ReLU."""

    def forward(self, x):
        summed = a_non_torch_leaf(x, x)
        shifted = summed + 3.0
        return torch.relu(shifted)