1import torch 2from torch._export import aot_compile 3from torch.export import Dim 4 5 6torch.manual_seed(1337) 7 8 9class Net(torch.nn.Module): 10 def __init__(self, device): 11 super().__init__() 12 self.w_pre = torch.randn(4, 4, device=device) 13 self.w_add = torch.randn(4, 4, device=device) 14 15 def forward(self, x): 16 w_transpose = torch.transpose(self.w_pre, 0, 1) 17 w_relu = torch.nn.functional.relu(w_transpose) 18 w = w_relu + self.w_add 19 return torch.matmul(x, w) 20 21 22class NetWithTensorConstants(torch.nn.Module): 23 def __init__(self) -> None: 24 super().__init__() 25 self.w = torch.randn(30, 1, device="cuda") 26 27 def forward(self, x, y): 28 z = self.w * x * y 29 return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]] 30 31 32data = {} 33data_with_tensor_constants = {} 34 35 36# Basice AOTI model test generation. 37def generate_basic_tests(): 38 for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: 39 for use_runtime_constant_folding in [True, False]: 40 if device == "cpu" and use_runtime_constant_folding: 41 # We do not test runtime const folding for cpu mode. 42 continue 43 model = Net(device).to(device=device) 44 x = torch.randn((4, 4), device=device) 45 with torch.no_grad(): 46 ref_output = model(x) 47 48 torch._dynamo.reset() 49 with torch.no_grad(): 50 dim0_x = Dim("dim0_x", min=1, max=1024) 51 dynamic_shapes = {"x": {0: dim0_x}} 52 model_so_path = aot_compile( 53 model, 54 (x,), 55 dynamic_shapes=dynamic_shapes, 56 options={ 57 "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding 58 }, 59 ) 60 61 suffix = f"{device}" 62 if use_runtime_constant_folding: 63 suffix += "_use_runtime_constant_folding" 64 data.update( 65 { 66 f"model_so_path_{suffix}": model_so_path, 67 f"inputs_{suffix}": [x], 68 f"outputs_{suffix}": [ref_output], 69 f"w_pre_{suffix}": model.w_pre, 70 f"w_add_{suffix}": model.w_add, 71 } 72 ) 73 74 75# AOTI model which will create additional tensors during autograd. 
def generate_test_with_additional_tensors():
    """Compile ``NetWithTensorConstants`` and fill ``data_with_tensor_constants``.

    Does nothing when CUDA is unavailable, since the model and its inputs
    are created on the "cuda" device.
    """
    if torch.cuda.is_available():
        model = NetWithTensorConstants()
        x = torch.randn((30, 1), device="cuda")
        y = torch.randn((30, 1), device="cuda")
        with torch.no_grad():
            ref_output = model(x, y)

        torch._dynamo.reset()
        with torch.no_grad():
            model_so_path = aot_compile(model, (x, y))

        data_with_tensor_constants["model_so_path"] = model_so_path
        data_with_tensor_constants["inputs"] = [x, y]
        data_with_tensor_constants["outputs"] = [ref_output]
        data_with_tensor_constants["w"] = model.w


generate_basic_tests()
generate_test_with_additional_tensors()


# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    """Module whose attributes are the entries of the given dict."""

    def __init__(self, data):
        super().__init__()
        for key, value in data.items():
            setattr(self, key, value)


# Script each container module and write it to its own file.
for payload, filename in (
    (data, "data.pt"),
    (data_with_tensor_constants, "data_with_tensor_constants.pt"),
):
    torch.jit.script(Serializer(payload)).save(filename)