xref: /aosp_15_r20/external/pytorch/torch/_lazy/extract_compiled_graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import dataclasses
4import itertools
5import os
6from typing import Any, Callable, Dict, List
7
8import torch
9import torch._lazy as lazy
10import torch._lazy.metrics as metrics
11from torch import fx
12from torch._lazy import computation, debug as lazy_debug
13from torch._lazy.tensor_factory_functions import tensor_factory_functions
14
15
# Setting the ``debug_extract_compiled_graph`` environment variable (to any
# value, even an empty string) enables the diagnostic prints in
# extract_compiled_graph.
debug = "debug_extract_compiled_graph" in os.environ
17
18
@dataclasses.dataclass
class GraphInputMatcher:
    """
    Rebuild the inputs of a traced TS/XLA graph for a new call.

    After lazy tracing, some graph inputs correspond to arguments of the traced
    method. Those must be replaced by the arguments of the current call, while
    every other graph input keeps the value captured during tracing.

    ``tensor_id_to_arg_idx`` maps a lazy tensor id to the position of the method
    argument it came from. ``graph_input_tensor_ids`` and
    ``graph_input_ivalues`` are parallel lists giving the tensor id and the
    traced value for each graph input.
    """

    # tensor id -> index into the positional arguments of the call.
    tensor_id_to_arg_idx: Dict[int, int]
    # One tensor id per graph input, in graph-input order.
    graph_input_tensor_ids: List[int]
    # Traced value for each graph input (parallel to graph_input_tensor_ids).
    # Inputs whose id is absent from tensor_id_to_arg_idx (most likely constant
    # tensors) reuse this traced value. Inputs whose id is present are instead
    # taken from the call's arguments.
    graph_input_ivalues: List[Any]

    def __call__(self, args):
        """Return the graph inputs to use when the method is called with ``args``."""
        arg_position_of = self.tensor_id_to_arg_idx

        def _select(tensor_id, traced_value):
            position = arg_position_of.get(tensor_id)
            # Not a method argument: keep the value captured at trace time.
            return traced_value if position is None else args[position]

        return [
            _select(tensor_id, traced_value)
            for tensor_id, traced_value in zip(
                self.graph_input_tensor_ids, self.graph_input_ivalues
            )
        ]
53
54
class ReturnValueHandler:
    r"""
    Re-expand the deduplicated outputs of a compiled LTC graph.

    When ``ltc_sync_multi`` is called on several tensors, the compiled graph
    contains an output only for each unique tensor: if a tensor appears more
    than once in the input to ``_ltc_sync_multi``, only its first occurrence
    matters.

    From Python, however, callers still expect one value per original
    position, duplicates included, even though the TS graph deduplicated the
    output. For example, for the method::

      def forward(self, a):
        return a, a

    the TS graph captured by LTC returns a single tensor, but the Python
    method is expected to return two.

    This class deduplicates the lazy tensors (by Python object identity) up
    front and remembers where each one appeared, so the eager results can be
    duplicated back into place later.
    """

    def __init__(self, lazy_out_list):
        # index[k] holds every position in ``lazy_out_list`` occupied by the
        # k-th distinct object, with distinct objects in first-appearance order.
        self.index: List[List[int]] = []
        self.total_count = len(lazy_out_list)

        slot_of: Dict[int, int] = {}
        for position, tensor in enumerate(lazy_out_list):
            key = id(tensor)
            if key not in slot_of:
                slot_of[key] = len(self.index)
                self.index.append([])
            self.index[slot_of[key]].append(position)

    def duplicate_eager_tensors(self, eager_tensor_list):
        """
        Place each unique eager result at every position its lazy tensor held.

        ``eager_tensor_list`` must contain one entry per distinct lazy tensor,
        in first-appearance order. Returns a list of length ``total_count``.
        """
        expanded = [None] * self.total_count
        assert len(eager_tensor_list) == len(self.index)

        for slot, eager_tensor in enumerate(eager_tensor_list):
            for position in self.index[slot]:
                expanded[position] = eager_tensor
        return expanded
95
96
def force_lazy_device(model: fx.GraphModule):
    """
    Rewrite device arguments in an FX graph so nodes target the lazy device.

    Factory methods in an FX graph may create tensors on a specific eager
    device. Left alone, those eager tensors would be mixed with lazy tensors
    and cause a crash. This function edits ``model``'s graph in place:

    * every top-level positional or keyword argument that is a
      ``torch.device`` is replaced with a lazy device with the same index;
    * a call to one of ``tensor_factory_functions`` that has no
      ``torch.device`` argument at all gets an explicit
      ``device=torch.device("lazy")`` keyword argument.

    ``model.recompile()`` is called at the end so the generated code matches
    the edited graph.
    """

    def _as_lazy(value):
        # Only real torch.device objects are rewritten; any other value
        # (including device strings) is passed through unchanged.
        if not isinstance(value, torch.device):
            return value
        return torch.device("lazy", index=value.index)

    def _mentions_device(args, kwargs):
        for value in itertools.chain(args, kwargs.values()):
            if isinstance(value, torch.device):
                return True
        return False

    for node in model.graph.nodes:
        node.args = tuple(map(_as_lazy, node.args))
        node.kwargs = {key: _as_lazy(val) for key, val in node.kwargs.items()}

        # For torchbench models like yolov3 and hf_Bart, dynamo generates FX graphs
        # that return eager tensors on the default device
        # (see https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolov3,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # To force those tensors onto the lazy device we cannot simply override
        # the device argument, since there is no explicit device argument.
        # Instead, for the covered tensor factory methods, we explicitly add
        # a lazy device argument.
        #
        # TODO: This solution is not ideal since we may miss some factory methods.
        # In the future, when we support lazy mode, this can be replaced by that.
        if node.target not in tensor_factory_functions:
            continue
        if _mentions_device(node.args, node.kwargs):
            continue
        patched = dict(node.kwargs)  # node.kwargs is immutable; edit a mutable copy.
        patched["device"] = torch.device("lazy")
        node.kwargs = patched

    model.recompile()
138
139
def get_fallback_ops():
    """
    Return every non-zero ``aten::`` LTC counter as a ``"name=value"`` string.

    Counters whose name contains ``aten::`` are treated as fallback counters.
    Those with a positive value are reported, in the order
    ``metrics.counter_names()`` lists them.
    """
    aten_counts = (
        (name, int(metrics.counter_value(name)))
        for name in metrics.counter_names()
        if "aten::" in name
    )
    return [f"{name}={count}" for name, count in aten_counts if count > 0]
150
151
def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
    """
    Optimize an eager model with LTC and return a wrapper that executes the
    compiled graph directly, without retracing.

    The wrapper does not re-check its inputs. It relies on other mechanisms,
    such as TorchDynamo guards, to guarantee it is only called when that is
    safe.

    Args:
        model: FX graph module to trace. It is deep-copied before being moved
            to the lazy device, so ``model`` itself is not modified.
        example_inputs: example tensors used for the trace. Each one is moved
            to the lazy device with ``.to(device="lazy")``.

    Returns:
        ``optimized_mod(*args)``. It runs the cached graph with ``args`` in
        place of the traced example inputs. It copies a result back into an
        argument whenever the graph returned a different tensor object for
        that argument. It then returns the remaining results (the model
        outputs).

    Raises:
        RuntimeError: if ``get_fallback_ops`` reports any non-zero ``aten::``
            counter after tracing, i.e. part of the model fell back.
    """
    lazy_args = [arg.to(device="lazy") for arg in example_inputs]
    args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
    tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
    lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
    force_lazy_device(lazy_model)

    # Run the lazy trace, which lets us extract the compiled graph later.
    # Metrics are reset first so only counters from this trace are inspected,
    # and again afterwards (presumably so they do not leak into later reports).
    metrics.reset()
    lazy_out = lazy_model(*lazy_args)
    fallback_ops = get_fallback_ops()
    metrics.reset()

    if fallback_ops:
        raise RuntimeError(
            f"Fail to extract the compiled graph because of fallback: {','.join(fallback_ops)}"
        )

    if not isinstance(lazy_out, (tuple, list)):
        lazy_out = (lazy_out,)

    # Example inputs come first: optimized_mod treats the first len(args)
    # results as (possibly updated) inputs and returns the rest as outputs.
    args_and_out = tuple(lazy_args) + tuple(lazy_out)
    return_value_handler = ReturnValueHandler(args_and_out)
    if debug:
        print("Fx code:\n", model.code)
        print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))

    # TODO: this part is TS backend specific for now and will be generalized to
    # support XLA
    (
        graph_input_tensor_ids,
        graph_input_ivalues,
    ) = computation.get_tensors_ts_device_data_node(args_and_out)
    assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
    graph_input_matcher = GraphInputMatcher(
        tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
    )

    graph_hash = computation.get_graph_hash(args_and_out)

    if debug:
        print("graph_hash", graph_hash)
        print(f"args_tensor_ids {args_tensor_ids}")
        print("tensor ids from device data:", graph_input_tensor_ids)

    # Sync the list of output tensors so the computation graph for these
    # tensors is cached. That computation graph can be retrieved by graph
    # hash later.
    lazy.sync_multi(args_and_out, [])

    # The closure below needs only the number of traced tensors. Capturing
    # ``args_and_out`` itself would keep every traced lazy input/output tensor
    # alive for as long as the returned wrapper exists.
    num_args_and_out = len(args_and_out)

    def optimized_mod(*args):
        if num_args_and_out == 0:
            return ()
        graph_input = graph_input_matcher(args)
        res = return_value_handler.duplicate_eager_tensors(
            computation.run_cached_graph(graph_hash, graph_input)
        )

        assert len(res) == num_args_and_out
        for i, arg in enumerate(args):
            # Only copy the tensors that were updated in place (the graph
            # returned a different object for them).
            if arg is not res[i]:
                arg.copy_(res[i])

        # Skip the args; the remaining results are the model outputs.
        return res[len(args) :]

    return optimized_mod
226