xref: /aosp_15_r20/external/pytorch/torch/export/_unlift.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import warnings
4from itertools import chain
5from typing import Any, Dict, List, Optional, Tuple
6
7import torch
8import torch.utils._pytree as pytree
9from torch._export.utils import _check_input_constraints_for_graph
10from torch.export.unflatten import _assign_attr, _AttrKind, _recursive_getattr
11from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
12
13from ._remove_effect_tokens_pass import _remove_effect_tokens
14from .exported_program import (
15    ExportedProgram,
16    ExportGraphSignature,
17    InputKind,
18    OutputKind,
19)
20
21
22@torch._dynamo.disable
23def _check_input_constraints_pre_hook(self, *args, **kwargs):
24    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)
25
26    if received_spec != self._in_spec:
27        raise ValueError(  # noqa: B904
28            "Trying to flatten user inputs with exported input tree spec: \n"
29            f"{self._in_spec}\n"
30            "but actually got inputs with tree spec of: \n"
31            f"{received_spec}"
32        )
33
34    return _check_input_constraints_for_graph(
35        [node for node in self.graph.nodes if node.op == "placeholder"],
36        flat_args_with_path,
37        self.range_constraints,
38    )
39
40
def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: List[Optional[str]],
) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
    """
    Replace placeholders that stand for lifted state with ``get_attr`` nodes.

    ``lifted_inputs`` is matched positionally against the graph's placeholder
    nodes. An entry of ``None`` leaves the placeholder untouched; any other
    entry is used as the ``get_attr`` target that takes over all uses of the
    placeholder, after which the placeholder is erased from the graph.

    Returns:
        A pair ``(unlifted_name_to_node, input_name_to_node)``: the first maps
        each non-``None`` entry of ``lifted_inputs`` to its new ``get_attr``
        node, the second maps the name of every kept placeholder to that
        placeholder node.
    """
    graph = gm.graph
    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholders)

    getattr_by_target = {}
    placeholder_by_name = {}

    for placeholder, target in zip(placeholders, lifted_inputs):
        if target is None:
            # A real call-time input: keep it as a placeholder.
            placeholder_by_name[placeholder.name] = placeholder
            continue

        with graph.inserting_after(placeholder):
            replacement = graph.get_attr(target)
            placeholder.replace_all_uses_with(replacement)
            # Grab the meta dict before erasing, then hand the very same dict
            # over to the replacement node.
            carried_meta = placeholder.meta
            graph.erase_node(placeholder)
            replacement.meta = carried_meta
            getattr_by_target[target] = replacement

    return getattr_by_target, placeholder_by_name
68
69
def _insert_copy_for_mutations(
    gm: torch.fx.GraphModule,
    mutated_outputs: List[Optional[str]],
    unlifted_name_to_node: Dict[str, torch.fx.Node],
    input_name_to_node: Dict[str, torch.fx.Node],
) -> None:
    """
    Re-express mutation outputs as in-place ``copy_`` calls.

    ``mutated_outputs`` is matched positionally against the flattened outputs
    of the graph. A ``None`` entry marks a user output that stays in the
    returned tuple. Any other entry names the node that was mutated (looked up
    first in ``unlifted_name_to_node``, then in ``input_name_to_node``); for it
    an ``aten.copy_`` call writing the returned value into that node is
    inserted before the output node. Finally the output node is replaced by a
    new one that returns only the user outputs.

    Raises:
        RuntimeError: if a mutated name is found in neither lookup table.
    """
    graph = gm.graph
    output_node = next((n for n in graph.nodes if n.op == "output"), None)
    assert output_node is not None

    flat_outputs, _ = pytree.tree_flatten(output_node.args)
    assert len(flat_outputs) == len(mutated_outputs)

    kept_outputs = []
    copy_for_return = {}
    for returned, mutated_node_name in zip(flat_outputs, mutated_outputs):
        if mutated_node_name is None:
            kept_outputs.append(returned)
            continue

        # Unlifted attributes take precedence over plain inputs.
        for table in (unlifted_name_to_node, input_name_to_node):
            if mutated_node_name in table:
                destination = table[mutated_node_name]
                break
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with graph.inserting_before(output_node):
            copy_for_return[returned] = graph.call_function(
                torch.ops.aten.copy_.default, (destination, returned)
            )

    # A user output that is also a mutation result is returned via its copy_.
    final_outputs = tuple(copy_for_return.get(n, n) for n in kept_outputs)

    with graph.inserting_before(output_node):
        # Only return user outputs
        replacement_output = graph.output(final_outputs)
        replacement_output.meta.update(output_node.meta)
        output_node.replace_all_uses_with(replacement_output)
        graph.erase_node(output_node)
121
122
123def _get_codegen(
124    in_spec: pytree.TreeSpec,
125    out_spec: Optional[pytree.TreeSpec],
126    forward_arg_names: Optional[List[str]] = None,
127) -> _PyTreeCodeGen:
128    """
129    Create the codegen for the graph module based on the in/out specs
130    """
131    if forward_arg_names:
132        names = forward_arg_names
133    else:
134        if (
135            in_spec.type == tuple
136            and in_spec.num_children == 2
137            and in_spec.children_specs[0].type == tuple
138            and in_spec.children_specs[1].type == dict
139        ):
140            # if in_spec contains the args (tuple) and kwargs (dict)
141            names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
142            # add kwarg names
143            names.extend(in_spec.children_specs[1].context)
144        else:
145            names = [f"arg_{i}" for i in range(in_spec.num_children)]
146
147    return _PyTreeCodeGen(
148        _PyTreeInfo(
149            names,
150            in_spec,
151            out_spec,
152        )
153    )
154
155
def _unlift(
    gm: torch.fx.GraphModule,
    lifted_inputs: List[Optional[str]],
    mutated_outputs: List[Optional[str]],
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    state_dict: Dict[str, Any],
    constants: Dict[str, Any],
    forward_arg_names: Optional[List[str]] = None,
):
    """
    Rewrite ``gm`` in place so lifted state is read via ``get_attr`` nodes and
    mutations are written back with ``copy_``, then return ``gm``.

    Args:
        gm: The graph module to rewrite. It is modified in place.

        lifted_inputs: A list matching the graph module's input nodes. For
        an input node that refers to a lifted parameter/buffer, this list
        contains the fqn of the corresponding attribute; otherwise it contains
        None. It is used to unlift the lifted parameters as get_attr nodes.

        mutated_outputs: A list matching the graph module's output nodes. For
        an output node that refers to a mutated buffer or user input, this
        list contains the name of the buffer or user input that needs to be
        mutated; otherwise it contains None. It is used to re-insert an
        inplace copy_ operator that copies the mutated value back to the
        original node.

        in_spec, out_spec, forward_arg_names: Forwarded to ``_get_codegen`` to
        build the codegen installed on the graph.

        state_dict, constants: Accepted for the caller's convenience; this
        function does not read them.
    """
    name_maps = _unlift_inputs_as_getattr(gm, lifted_inputs)
    # name_maps is (unlifted_name_to_node, input_name_to_node), in that order.
    _insert_copy_for_mutations(gm, mutated_outputs, *name_maps)

    graph = gm.graph
    graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
    graph.lint()
    gm.recompile()
    return gm
191
192
193def _register_attrs_to_new_gm(
194    new_gm: torch.fx.GraphModule,
195    graph_signature: ExportGraphSignature,
196    state_dict: Dict[str, Any],
197    constants: Dict[str, Any],
198) -> None:
199    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
200    for name in graph_signature.buffers:
201        if name in non_persistent_buffers:
202            persistent = False
203            value = constants[name]
204        else:
205            persistent = True
206            value = state_dict[name]
207        _assign_attr(
208            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
209        )
210    for name in graph_signature.parameters:
211        value = state_dict[name]
212        _assign_attr(
213            value,
214            new_gm,
215            name,
216            attr_kind=_AttrKind.PARAMETER,
217        )
218
219    for name in chain(
220        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
221    ):
222        value = constants[name]
223        _assign_attr(
224            value,
225            new_gm,
226            name,
227            attr_kind=_AttrKind.CONSTANT,
228        )
229
230
class _StatefulGraphModuleFactory(type):
    """
    Metaclass that gives its classes a private constructor.

    Calling a class directly raises ``TypeError``; instances can only be made
    through ``_create``.
    """

    def __call__(cls, *args, **kwargs):
        # Direct instantiation is always rejected, whatever the arguments.
        message = f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
        raise TypeError(message)

    def _create(cls, root, graph, range_constraints=None):
        # Skip the blocked __call__ above and go through the regular
        # construction path of the parent metaclass.
        instance = super().__call__(root, graph, range_constraints=range_constraints)
        return instance
247
248
class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    """
    ``torch.fx.GraphModule`` that additionally carries ``range_constraints``.

    The metaclass blocks direct construction; instances are made with
    ``_StatefulGraphModule._create(root, graph, range_constraints=...)``.
    """

    def __init__(self, root, graph, range_constraints=None):
        super().__init__(root, graph)
        # Any falsy value (None, empty container) is stored as an empty list.
        self.range_constraints = range_constraints if range_constraints else []
254
255
def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    # TODO(suo) this should not be optional, but is since we still have
    # capture_pre_autograd_graph
    graph_signature: Optional[ExportGraphSignature] = None,
):
    """
    Wrap ``plain_graph_module`` in a ``_StatefulGraphModule``.

    The new module shares ``plain_graph_module``'s graph, stores
    ``range_constraints`` and gets ``_check_input_constraints_pre_hook``
    registered as a forward pre-hook (with kwargs). When ``graph_signature``
    is given, lifted tensor constants are re-registered as constants and
    non-persistent buffers are re-registered with ``persistent=False``.

    Returns:
        The newly created ``_StatefulGraphModule``.
    """
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )
    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    if graph_signature is None:
        # Without a signature there is nothing to fix up.
        return stateful_gm

    # Fix up lifted tensor constants.
    # The fx.GraphModule() constructor silently turns a constant attribute of
    # plain_graph_module into a buffer in stateful_gm, which is inconsistent
    # with graph_signature. Undo that by removing each such buffer and
    # registering the value again with _assign_attr(attr_kind=CONSTANT).
    for constant_fqn in graph_signature.lifted_tensor_constants:
        constant_val = stateful_gm.get_buffer(constant_fqn)
        # A constant that requires gradient is probably a bug in user code,
        # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
        # Tensor constants need no gradients, so the value is detached. Users
        # who want gradients should register it as a parameter instead.
        if constant_val.requires_grad:
            warnings.warn(
                f"A model attribute `{constant_fqn}` requires gradient. "
                f"but it's not properly registered as a parameter. "
                f"torch.export will detach it and treat it as a constant tensor "
                f"but please register it as parameter instead."
            )
            constant_val = constant_val.detach()

        # Split the fqn into the owning submodule path and the attribute name.
        *owner_path, attr_name = constant_fqn.rsplit(".")
        owner = _recursive_getattr(stateful_gm, owner_path)
        delattr(owner, attr_name)
        _assign_attr(
            constant_val, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT
        )

    # Fix up non-persistent buffers. torch.fx does not distinguish between
    # persistent and non-persistent buffers, so that distinction has to be
    # restored here.
    for buffer_fqn in graph_signature.non_persistent_buffers:
        source_buffer = plain_graph_module.get_buffer(buffer_fqn)
        _assign_attr(
            source_buffer,
            stateful_gm,
            buffer_fqn,
            attr_kind=_AttrKind.BUFFER,
            persistent=False,
        )

    return stateful_gm
314
315
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
    """
    Turn an ``ExportedProgram`` into a stateful module.

    Steps, in order: remove effect tokens from ``ep``; build a fresh
    ``GraphModule`` over a deep copy of the graph; register the attributes
    named by the graph signature on it; unlift the lifted inputs and re-insert
    mutation copies via ``_unlift``; wrap the result with
    ``_create_stateful_graph_module``; and copy ``ep.graph_module.meta`` onto
    the result.
    """
    ep = _remove_effect_tokens(ep)
    signature = ep.graph_signature

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, signature, ep.state_dict, ep.constants)
    forward_arg_names = ep.graph_module.meta.get("forward_arg_names")

    # Input kinds that are module state rather than call-time arguments.
    lifted_kinds = (
        InputKind.BUFFER,
        InputKind.CONSTANT_TENSOR,
        InputKind.PARAMETER,
        InputKind.CUSTOM_OBJ,
    )
    # Output kinds that write back into a buffer or a user input.
    mutation_kinds = (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)

    lifted_inputs: List[Optional[str]] = [
        spec.target if spec.kind in lifted_kinds else None
        for spec in signature.input_specs
    ]
    mutated_outputs: List[Optional[str]] = [
        spec.target if spec.kind in mutation_kinds else None
        for spec in signature.output_specs
    ]

    call_spec = ep.call_spec
    new_gm = _unlift(
        new_gm,
        lifted_inputs,
        mutated_outputs,
        call_spec.in_spec,
        call_spec.out_spec,
        ep.state_dict,
        ep.constants,
        forward_arg_names=forward_arg_names,
    )

    unlift_gm = _create_stateful_graph_module(
        new_gm, ep.range_constraints, ep.graph_signature
    )
    unlift_gm.meta.update(ep.graph_module.meta)
    return unlift_gm
362