import torch
from torch.export import Dim


# custom op that loads the aot-compiled model
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)


class TensorSerializer(torch.nn.Module):
    """
    a module that exposes every entry of ``data`` as an attribute of itself
    """

    def __init__(self, data):
        super().__init__()
        # each key of the mapping becomes an attribute holding its value
        for name, value in data.items():
            setattr(self, name, value)


class SimpleModule(torch.nn.Module):
    """
    a small module to be compiled: a Linear(4, 6) layer followed by a ReLU
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 6)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # linear projection, then the activation
        return self.relu(self.fc(x))


class MyAOTIModule(torch.nn.Module):
    """
    a wrapper nn.Module whose forward method delegates to an instance of
    the MyAOTIClass custom class
    """

    def __init__(self, lib_path, device):
        super().__init__()
        self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(lib_path, device)

    def forward(self, *x):
        # the positional inputs are handed to the custom op as one tuple,
        # and whatever it returns is converted to a tuple
        return tuple(self.aoti_custom_op.forward(x))


def make_script_module(lib_path, device, *inputs):
    """
    wrap the compiled library at ``lib_path`` in a MyAOTIModule, run it once
    on ``inputs``, and return the result of tracing it with torch.jit.trace
    """
    wrapper = MyAOTIModule(lib_path, device)
    # sanity check: one eager call before tracing
    wrapper(*inputs)
    return torch.jit.trace(wrapper, inputs)


def compile_model(device, data):
    """
    aot-compile SimpleModule on ``device``, save the traced wrapper as
    ``script_model_{device}.pt``, and store the sample inputs and the
    reference output in ``data`` under ``inputs_{device}`` and
    ``outputs_{device}``
    """
    model = SimpleModule().to(device)
    sample = torch.randn((4, 4), device=device)
    inputs = (sample,)

    # dimension 0 of "x" is a dynamic batch dimension in [1, 1024]
    batch = Dim("batch", min=1, max=1024)
    shape_spec = {"x": {0: batch}}

    with torch.no_grad():
        # aot-compile the module into a .so pointed by so_path
        so_path = torch._export.aot_compile(
            model,
            inputs,
            dynamic_shapes=shape_spec,
        )
        traced = make_script_module(so_path, device, *inputs)

    traced.save(f"script_model_{device}.pt")

    # save sample inputs and ref output
    with torch.no_grad():
        expected = model(*inputs)
    data[f"inputs_{device}"] = list(inputs)
    data[f"outputs_{device}"] = [expected]

def main():
    """
    compile the model for "cpu" (and "cuda" when available), then script a
    TensorSerializer holding the collected tensors and save it as
    ``script_data.pt``
    """
    data = {}

    # cuda is only targeted when torch reports it as available
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    for device in devices:
        compile_model(device, data)

    serializer = TensorSerializer(data)
    torch.jit.script(serializer).save("script_data.pt")


if __name__ == "__main__":
    main()