# xref: /aosp_15_r20/external/pytorch/test/package/package_a/test_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: package/deploy"]
2
3import torch
4from torch.fx import wrap
5
6
# Register `a_non_torch_leaf` (the function defined at the bottom of this
# module) with torch.fx by name; the string is resolved later, so the function
# does not need to exist yet at this point.
# NOTE(review): per the torch.fx.wrap docs this marks the function as a leaf
# that symbolic tracing records as a single call instead of tracing into, and
# wrap() inspects its caller's frame, so it must stay at module top level.
wrap("a_non_torch_leaf")
8
9
class ModWithSubmod(torch.nn.Module):
    """Module whose forward pass is delegated entirely to one wrapped callable.

    Args:
        script_mod: callable applied to the input in ``forward``. The name
            suggests a scripted module, but nothing here enforces that.
    """

    def __init__(self, script_mod):
        super().__init__()
        self.script_mod = script_mod

    def forward(self, x):
        """Return whatever ``self.script_mod(x)`` returns, unchanged."""
        result = self.script_mod(x)
        return result
17
18
class ModWithTensor(torch.nn.Module):
    """Module that multiplies its input by a tensor captured at construction.

    Args:
        tensor: value stored as ``self.tensor`` by direct attribute
            assignment; no explicit buffer/parameter registration call is
            made here.
    """

    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def forward(self, x):
        """Return ``self.tensor * x``, with the stored tensor as left operand."""
        product = self.tensor * x
        return product
26
27
class ModWithSubmodAndTensor(torch.nn.Module):
    """Module that runs a submodule and then adds a stored tensor.

    Args:
        tensor: value added to the submodule's output in ``forward``.
        sub_mod: callable applied to the input first.
    """

    def __init__(self, tensor, sub_mod):
        super().__init__()
        self.tensor = tensor
        self.sub_mod = sub_mod

    def forward(self, x):
        """Return ``self.sub_mod(x) + self.tensor``."""
        sub_out = self.sub_mod(x)
        return sub_out + self.tensor
36
37
class ModWithTwoSubmodsAndTensor(torch.nn.Module):
    """Module that sums the outputs of two submodules plus a stored tensor.

    Args:
        tensor: value added last in ``forward``.
        sub_mod_0: callable applied to the input first.
        sub_mod_1: callable applied to the input second.
    """

    def __init__(self, tensor, sub_mod_0, sub_mod_1):
        super().__init__()
        self.tensor = tensor
        # Assignment order is kept as-is; it fixes child registration order.
        self.sub_mod_0 = sub_mod_0
        self.sub_mod_1 = sub_mod_1

    def forward(self, x):
        """Return ``sub_mod_0(x) + sub_mod_1(x) + self.tensor``."""
        first = self.sub_mod_0(x)
        second = self.sub_mod_1(x)
        return first + second + self.tensor
47
48
class ModWithMultipleSubmods(torch.nn.Module):
    """Module that adds together the outputs of two submodules.

    Args:
        mod1: callable applied to the input first.
        mod2: callable applied to the input second.
    """

    def __init__(self, mod1, mod2):
        super().__init__()
        # Assignment order is kept as-is; it fixes child registration order.
        self.mod1 = mod1
        self.mod2 = mod2

    def forward(self, x):
        """Return ``self.mod1(x) + self.mod2(x)``."""
        left = self.mod1(x)
        right = self.mod2(x)
        return left + right
57
58
class SimpleTest(torch.nn.Module):
    """Small module that feeds its input through the module-level helper
    ``a_non_torch_leaf`` and then applies a torch op.
    """

    def forward(self, x):
        """Return ``torch.relu(a_non_torch_leaf(x, x) + 3.0)``."""
        # Call the helper through its global name (no local alias): the
        # module-level wrap("a_non_torch_leaf") call refers to it by name.
        combined = a_non_torch_leaf(x, x)
        shifted = combined + 3.0
        return torch.relu(shifted)
63
64
def a_non_torch_leaf(a, b):
    """Return ``a + b`` using only the ``+`` operator (no torch calls).

    This module registers the function by name with ``wrap`` at import time.
    """
    total = a + b
    return total
67