xref: /aosp_15_r20/external/pytorch/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, Optional
3import torch
4from torch._ops import OpOverload, HigherOrderOperator
5from torch._export.error import InternalError
6from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
7
8
9__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
10
11
# Explicit op-to-op overrides. ReplaceViewOpsWithViewCopyOpsPass.call_operator
# checks this table first: an op found here is replaced by the mapped op
# directly, and the schema-based ``<name>_copy`` lookup performed by
# get_view_copy_of_view_op is skipped for it.
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
    # NOTE(review): presumably listed here because the schema-based lookup
    # would not rewrite this op on its own -- confirm against its schema.
    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
}
15
16
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
    """Report whether ``schema`` describes a view operator.

    Only the first argument is inspected. The op counts as a view when that
    argument carries alias information and the alias is not flagged as being
    written to.

    Args:
        schema: The function schema of the operator to classify.

    Returns:
        ``True`` if the first argument is aliased without being written to;
        ``False`` otherwise, including when the schema has no arguments.
    """
    params = schema.arguments
    if not params:
        # Nothing to alias: an op without arguments cannot be a view.
        return False

    first_alias = params[0].alias_info
    if first_alias is None:
        return False

    # A written-to alias is excluded; only read-only aliases qualify.
    return not first_alias.is_write
22
23
def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
    """Look up the ``*_copy`` overload that corresponds to an aten view op.

    The copy variant is resolved by name: ``aten::<name>`` maps to
    ``torch.ops.aten.<name>_copy``, using the same overload name as the input
    schema (an empty overload name is treated as ``"default"``).

    Args:
        schema: The function schema of the operator to translate.

    Returns:
        The matching view-copy overload, or ``None`` when ``schema`` is not a
        view op (per ``is_view_op``) or its name does not start with
        ``"aten::"``.

    Raises:
        InternalError: If ``torch.ops.aten`` has no ``<name>_copy`` attribute,
            or that attribute lacks the required overload.
    """
    if not (is_view_op(schema) and schema.name.startswith("aten::")):
        return None

    base_name = schema.name.split("::")[1]
    overload_name = schema.overload_name or "default"

    # A private sentinel distinguishes "attribute absent" from any real value.
    absent = object()

    copy_packet = getattr(torch.ops.aten, base_name + "_copy", absent)
    if copy_packet is absent:
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    copy_overload = getattr(copy_packet, overload_name, absent)
    if copy_overload is absent:
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    return copy_overload
44
45
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Swap view operators for their view-copy equivalents.

    View ops are kept around while the exported program is functionalized,
    for efficiency. Backends that need AOT memory planning expect pure
    functional operators instead, so this pass replaces each view op with the
    corresponding view copy op.
    """

    def call_operator(self, op, args, kwargs, meta):
        """Forward the call to the base pass, substituting ``op`` if needed.

        The operator passed on is chosen in this order:

        1. the entry for ``op`` in ``_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS``;
        2. ``op`` itself when it is a ``HigherOrderOperator``;
        3. the result of ``get_view_copy_of_view_op(op._schema)`` when that
           result is truthy;
        4. otherwise ``op`` unchanged.

        ``args``, ``kwargs`` and ``meta`` are passed through untouched.
        """
        target = op
        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
            target = _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]
        elif not isinstance(op, HigherOrderOperator):
            # Higher-order operators are never inspected via ``_schema``.
            view_copy = get_view_copy_of_view_op(op._schema)
            if view_copy:
                target = view_copy

        return super().call_operator(target, args, kwargs, meta)
66