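"""
Builds the artifacts used by the aoti_inference C++ test: AOT-compiles
SimpleModule for each available device into a shared library, saves a
traced TorchScript wrapper as script_model_{device}.pt, and saves the
sample inputs and reference outputs as script_data.pt.
"""
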
import torch
from torch.export import Dim


# library defining the custom class that loads the aot-compiled model
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)


class TensorSerializer(torch.nn.Module):
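    """
    Stores each entry of a data dict as an attribute on the module so the
    whole bundle can be serialized with torch.jit.script and reloaded,
    e.g. by the C++ test.
    """
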
    def __init__(self, data):
        super().__init__()
        for key in data:
            setattr(self, key, data[key])


class SimpleModule(torch.nn.Module):
    """
    A simple module to be AOT-compiled.
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 6)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        a = self.fc(x)
        b = self.relu(a)
        return b


class MyAOTIModule(torch.nn.Module):
    """
    A wrapper nn.Module whose forward delegates to MyAOTIClass,
    the custom class that runs the AOT-compiled model.
    """

    def __init__(self, lib_path, device):
        super().__init__()
        self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(
            lib_path,
            device,
        )

    def forward(self, *x):
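        # the custom class takes the packed inputs and returns a sequence of outputs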
        outputs = self.aoti_custom_op.forward(x)
        return tuple(outputs)


def make_script_module(lib_path, device, *inputs):
    m = MyAOTIModule(lib_path, device)
    # sanity check that the compiled model runs before tracing
    m(*inputs)
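    # trace the wrapper so the custom-op call is baked into a ScriptModule
    # that can be serialized and later run from C++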
    return torch.jit.trace(m, inputs)


def compile_model(device, data):
    module = SimpleModule().to(device)
    x = torch.randn((4, 4), device=device)
    inputs = (x,)
    # declare a dynamic batch dimension, allowing sizes from 1 to 1024
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {
        "x": {0: batch_dim},
    }
    with torch.no_grad():
        # aot-compile the module into a shared library; lib_path is the
        # path of the generated .so
        lib_path = torch._export.aot_compile(
            module, inputs, dynamic_shapes=dynamic_shapes
        )
    script_module = make_script_module(lib_path, device, *inputs)
    aoti_script_model = f"script_model_{device}.pt"
    script_module.save(aoti_script_model)
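    # the saved module can be reloaded later (from Python or the C++ test)
    # once the custom class library is loaded, for example:
    #   loaded = torch.jit.load(aoti_script_model)
    #   loaded(torch.randn(8, 4, device=device))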

    # save sample inputs and ref output
    with torch.no_grad():
        ref_output = module(*inputs)
    data.update(
        {
            f"inputs_{device}": list(inputs),
            f"outputs_{device}": [ref_output],
        }
    )


def main():
    data = {}
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
        compile_model(device, data)
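    # bundle the sample inputs and reference outputs so the C++ test can
    # load them and compare against the compiled model's results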
    torch.jit.script(TensorSerializer(data)).save("script_data.pt")


if __name__ == "__main__":
    main()