# Owner(s): ["module: inductor"]

import os
import shutil
import tempfile

import torch
import torch._export
import torch._inductor
import torch.export._trace
import torch.fx._pytree as fx_pytree
from torch.testing._internal.common_utils import IS_FBCODE
from torch.utils import _pytree as pytree


class WrapperModule(torch.nn.Module):
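    """Wraps a bare callable in an nn.Module so it can go through export."""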
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class AOTIRunnerUtil:
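    """Helpers for compiling, loading, and running models through AOTInductor."""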
    @staticmethod
    def compile(
        model,
        example_inputs,
        options=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
    ):
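        """Export ``model`` and AOT-compile it; returns the path to the .so."""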
        if not isinstance(model, torch.nn.Module):
            model = WrapperModule(model)
        # The exact API is subject to change
        if torch._inductor.config.is_predispatch:
            ep = torch.export._trace._export(
                model, example_inputs, dynamic_shapes=dynamic_shapes, pre_dispatch=True
            )
            gm = ep.module()
        else:
            gm = torch.export._trace._export_to_torch_ir(
                model,
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                disable_constraint_solver=disable_constraint_solver,
                # Disable this flag; we can instead rely on the
                # dynamo_flat_name_to_original_fqn mapping that comes from Dynamo.
                restore_fqn=False,
            )

        if IS_FBCODE:
            from deeplearning.aot_inductor.extern_node_thrift_serializer import (
                thrift_serializer,
            )

            if options is None:
                options = {}
            options["extern_node_serializer"] = thrift_serializer

        with torch.no_grad():
            so_path = torch._inductor.aot_compile(gm, example_inputs, options=options)  # type: ignore[arg-type]

        return so_path

    @staticmethod
    def load_runner(device, so_path):
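        """Create a model-container runner for the compiled library on ``device``."""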
        if IS_FBCODE:
            from .fb import test_aot_inductor_model_runner_pybind

            with tempfile.TemporaryDirectory() as temp_dir:
                # Copy the *.so file to a unique path just before loading
                # to avoid stale dlopen handles when an updated *.so from
                # the same path is loaded repeatedly in a test
                temp_so_path = os.path.join(temp_dir, "model.so")
                shutil.copy(so_path, temp_so_path)

                # We also need to copy over the serialized extern_kernel_nodes for custom ops
                extern_kernel_nodes_path = f"{so_path[:-3]}.json"
                if os.path.isfile(extern_kernel_nodes_path):
                    temp_extern_kernel_nodes_path = os.path.join(temp_dir, "model.json")
                    shutil.copy(extern_kernel_nodes_path, temp_extern_kernel_nodes_path)

                return test_aot_inductor_model_runner_pybind.Runner(
                    temp_so_path, device == "cpu"
                )
        else:
            return (
                torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
                if device == "cpu"
                else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
            )

    @staticmethod
    def load(device, so_path):
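        """Return a callable that runs the compiled library on ``device``."""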
        # TODO: unify fbcode and oss behavior to only use torch._export.aot_load
        if IS_FBCODE:
            runner = AOTIRunnerUtil.load_runner(device, so_path)

            def optimized(*args, **kwargs):
                call_spec = runner.get_call_spec()
                in_spec = pytree.treespec_loads(call_spec[0])
                out_spec = pytree.treespec_loads(call_spec[1])
                flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
                flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
                flat_outputs = runner.run(flat_inputs)
                return pytree.tree_unflatten(flat_outputs, out_spec)

            return optimized
        else:
            return torch._export.aot_load(so_path, device)

    @staticmethod
    def run(
        device,
        model,
        example_inputs,
        options=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
    ):
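        """Compile ``model``, load it for ``device``, and run it once on ``example_inputs``."""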
        so_path = AOTIRunnerUtil.compile(
            model,
            example_inputs,
            options=options,
            dynamic_shapes=dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
        )
        optimized = AOTIRunnerUtil.load(device, so_path)
        return optimized(*example_inputs)

    @staticmethod
    def run_multiple(
        device,
        model,
        list_example_inputs,
        options=None,
        dynamic_shapes=None,
    ):
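        """Compile once against the first set of inputs, then run the compiled
        model on every set in ``list_example_inputs``."""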
        so_path = AOTIRunnerUtil.compile(
            model,
            list_example_inputs[0],
            options=options,
            dynamic_shapes=dynamic_shapes,
        )
        optimized = AOTIRunnerUtil.load(device, so_path)
        list_output_tensors = []
        for example_inputs in list_example_inputs:
            list_output_tensors.append(optimized(*example_inputs))
        return list_output_tensors

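# A minimal usage sketch, left commented out so it never runs on import; the
# Linear model and inputs below are illustrative only, not part of this
# file's tests:
#
#   model = torch.nn.Linear(4, 4)
#   example_inputs = (torch.randn(2, 4),)
#   outputs = AOTIRunnerUtil.run("cpu", model, example_inputs)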