# Owner(s): ["module: inductor"]

import os
import shutil
import tempfile

import torch
import torch._export
import torch._inductor
import torch.export._trace
import torch.fx._pytree as fx_pytree
from torch.testing._internal.common_utils import IS_FBCODE
from torch.utils import _pytree as pytree


class WrapperModule(torch.nn.Module):
    """Wraps a bare callable so it can be exported like an nn.Module."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class AOTIRunnerUtil:
    """Test helpers for compiling, loading, and running models with AOT Inductor."""

    @staticmethod
    def compile(
        model,
        example_inputs,
        options=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
    ):
        if not isinstance(model, torch.nn.Module):
            model = WrapperModule(model)
        # The exact API is subject to change
        if torch._inductor.config.is_predispatch:
            ep = torch.export._trace._export(
                model, example_inputs, dynamic_shapes=dynamic_shapes, pre_dispatch=True
            )
            gm = ep.module()
        else:
            gm = torch.export._trace._export_to_torch_ir(
                model,
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                disable_constraint_solver=disable_constraint_solver,
                # Disable this flag; we can instead rely on the
                # dynamo_flat_name_to_original_fqn mapping that comes from Dynamo.
                restore_fqn=False,
            )

        if IS_FBCODE:
            from deeplearning.aot_inductor.extern_node_thrift_serializer import (
                thrift_serializer,
            )

            if options is None:
                options = {}
            options["extern_node_serializer"] = thrift_serializer

        with torch.no_grad():
            so_path = torch._inductor.aot_compile(gm, example_inputs, options=options)  # type: ignore[arg-type]

        return so_path

    @staticmethod
    def load_runner(device, so_path):
        if IS_FBCODE:
            from .fb import test_aot_inductor_model_runner_pybind

            with tempfile.TemporaryDirectory() as temp_dir:
                # Copy the *.so file to a unique path just before loading to
                # avoid stale dlopen handles when an updated *.so at the same
                # path is loaded repeatedly within a test.
                temp_so_path = os.path.join(temp_dir, "model.so")
                shutil.copy(so_path, temp_so_path)

                # Also copy over the serialized extern_kernel_nodes for custom ops.
                extern_kernel_nodes_path = f"{so_path[:-3]}.json"
                if os.path.isfile(extern_kernel_nodes_path):
                    temp_extern_kernel_nodes_path = os.path.join(temp_dir, "model.json")
                    shutil.copy(extern_kernel_nodes_path, temp_extern_kernel_nodes_path)

                return test_aot_inductor_model_runner_pybind.Runner(
                    temp_so_path, device == "cpu"
                )
        else:
            return (
                torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
                if device == "cpu"
                else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
            )

    @staticmethod
    def load(device, so_path):
        # TODO: unify fbcode and oss behavior to only use torch._export.aot_load
        if IS_FBCODE:
            runner = AOTIRunnerUtil.load_runner(device, so_path)

            def optimized(*args, **kwargs):
                call_spec = runner.get_call_spec()
                in_spec = pytree.treespec_loads(call_spec[0])
                out_spec = pytree.treespec_loads(call_spec[1])
                flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
                # Keep only tensor inputs; the pybind runner expects a flat
                # list of tensors.
                flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
                flat_outputs = runner.run(flat_inputs)
                return pytree.tree_unflatten(flat_outputs, out_spec)

            return optimized
        else:
            return torch._export.aot_load(so_path, device)

    @staticmethod
    def run(
        device,
        model,
        example_inputs,
        options=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
    ):
        so_path = AOTIRunnerUtil.compile(
            model,
            example_inputs,
            options=options,
            dynamic_shapes=dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
        )
        optimized = AOTIRunnerUtil.load(device, so_path)
        return optimized(*example_inputs)

    @staticmethod
    def run_multiple(
        device,
        model,
        list_example_inputs,
        options=None,
        dynamic_shapes=None,
    ):
        # Compile once against the first set of inputs, then run every set
        # through the same loaded artifact.
        so_path = AOTIRunnerUtil.compile(
            model,
            list_example_inputs[0],
            options=options,
            dynamic_shapes=dynamic_shapes,
        )
        optimized = AOTIRunnerUtil.load(device, so_path)
        return [optimized(*example_inputs) for example_inputs in list_example_inputs]
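

# A minimal usage sketch, assuming a PyTorch build with AOTInductor support
# (and, for the "cuda" case, a visible GPU). The SimpleLinear module and the
# example inputs below are illustrative assumptions, not part of the test
# utilities; real tests import AOTIRunnerUtil and supply their own models.
if __name__ == "__main__":

    class SimpleLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.linear(x)

    # Compile ahead-of-time, load the resulting shared library, and run it.
    output = AOTIRunnerUtil.run("cpu", SimpleLinear(), (torch.randn(2, 8),))
    print(output.shape)  # expected: torch.Size([2, 4])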