xref: /aosp_15_r20/external/pytorch/test/cpp/aoti_inference/test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2from torch._export import aot_compile
3from torch.export import Dim
4
5
# Fix the global RNG so the randomly drawn weights and example inputs below
# (and therefore the recorded reference outputs) are reproducible run to run.
torch.manual_seed(seed=1337)
7
8
class Net(torch.nn.Module):
    """Small AOTI test model computing ``x @ (relu(w_pre.T) + w_add)``.

    ``w_pre`` and ``w_add`` are random 4x4 tensors kept as plain attributes
    (not ``nn.Parameter``s and not registered buffers). The whole weight-side
    subexpression in ``forward`` depends only on these two tensors, never on
    the input ``x``.

    Args:
        device: Device on which both weight tensors are created.
    """

    def __init__(self, device):
        super().__init__()
        # Drawn from the global RNG in this order: w_pre first, then w_add.
        self.w_pre = torch.randn(4, 4, device=device)
        self.w_add = torch.randn(4, 4, device=device)

    def forward(self, x):
        # Weight-only part: transpose, clamp negatives to zero, add offset.
        effective_weight = (
            torch.nn.functional.relu(self.w_pre.transpose(0, 1)) + self.w_add
        )
        return torch.matmul(x, effective_weight)
20
21
class NetWithTensorConstants(torch.nn.Module):
    """AOTI test model that gathers a fixed subset of rows from ``w * x * y``.

    ``w`` is a random ``(30, 1)`` tensor stored as a plain attribute.
    ``forward`` multiplies it elementwise with ``x`` and ``y`` and returns
    rows 0-15 and 17 (row 16 is skipped) via list-based advanced indexing.

    Args:
        device: Device on which ``w`` is created. Defaults to ``"cuda"``,
            the placement that used to be hard-coded; pass ``"cpu"`` to build
            and run the model without a GPU.
    """

    def __init__(self, device="cuda") -> None:
        super().__init__()
        self.w = torch.randn(30, 1, device=device)

    def forward(self, x, y):
        z = self.w * x * y
        # NOTE(review): the Python-list index (advanced indexing) presumably
        # becomes an extra tensor constant in the compiled graph; skipping
        # row 16 also keeps it from being expressible as a plain slice.
        # Confirm before changing this literal.
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]
30
31
# Two independent dicts, filled in by the generator functions and then
# serialized to .pt files for the C++ test code.
data, data_with_tensor_constants = {}, {}
34
35
# Basic AOTI model test generation.
def generate_basic_tests():
    """AOT-compile ``Net`` for every supported device/configuration.

    CPU is always covered; CUDA is added when it is available. On CUDA the
    model is compiled twice, with and without
    ``aot_inductor.use_runtime_constant_folding``; on CPU only without it.
    Dimension 0 of the example input ``x`` is exported as dynamic within
    ``[1, 1024]``.

    For each configuration the module-level ``data`` dict receives, under
    keys suffixed with ``<device>`` (plus ``_use_runtime_constant_folding``
    when that option is on): the path returned by ``aot_compile``, the example
    input, the eager reference output, and both weight tensors.
    """
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    for device in devices:
        for use_runtime_constant_folding in [True, False]:
            # Runtime constant folding is only exercised for CUDA.
            if use_runtime_constant_folding and device == "cpu":
                continue

            # Model before input: the global RNG draw order fixes the values
            # that end up in ``data``.
            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            # Clear Dynamo's state before each compilation.
            torch._dynamo.reset()
            with torch.no_grad():
                dim0_x = Dim("dim0_x", min=1, max=1024)
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes={"x": {0: dim0_x}},
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = (
                f"{device}_use_runtime_constant_folding"
                if use_runtime_constant_folding
                else f"{device}"
            )
            data[f"model_so_path_{suffix}"] = model_so_path
            data[f"inputs_{suffix}"] = [x]
            data[f"outputs_{suffix}"] = [ref_output]
            data[f"w_pre_{suffix}"] = model.w_pre
            data[f"w_add_{suffix}"] = model.w_add
73
74
# AOTI model that creates additional tensor constants during AOT tracing/compilation.
def generate_test_with_additional_tensors():
    """AOT-compile ``NetWithTensorConstants`` on CUDA and record the results.

    Stores the path returned by ``aot_compile``, the example inputs, the eager
    reference output and the model's ``w`` tensor in the module-level
    ``data_with_tensor_constants`` dict. When CUDA is unavailable it returns
    immediately and leaves that dict untouched.
    """
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    # Drawn after the model's weight, x first and then y (RNG order matters).
    x, y = (torch.randn((30, 1), device="cuda") for _ in range(2))
    example_inputs = (x, y)
    with torch.no_grad():
        ref_output = model(*example_inputs)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, example_inputs)

    data_with_tensor_constants["model_so_path"] = model_so_path
    data_with_tensor_constants["inputs"] = list(example_inputs)
    data_with_tensor_constants["outputs"] = [ref_output]
    data_with_tensor_constants["w"] = model.w
98
99
# Populate `data` and `data_with_tensor_constants` (in this order) before
# they are saved further down.
for _generate in (generate_basic_tests, generate_test_with_additional_tensors):
    _generate()
del _generate
102
103
# Use this to pass tensors (and other values) to the C++ code.
class Serializer(torch.nn.Module):
    """Expose every entry of ``data`` as an attribute of a scriptable module.

    Scripting and saving an instance produces a TorchScript archive in which
    each value is reachable by its dict key used as the attribute name.
    """

    def __init__(self, data):
        super().__init__()
        for name, value in data.items():
            setattr(self, name, value)
110
111
# Save each dict as a TorchScript archive for the C++ test code.
for _payload, _path in (
    (data, "data.pt"),
    (data_with_tensor_constants, "data_with_tensor_constants.pt"),
):
    torch.jit.script(Serializer(_payload)).save(_path)
del _payload, _path
116