xref: /aosp_15_r20/external/executorch/exir/serde/upgrade.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Workerimport re
9*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
14*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_infra.node_metadata import NodeMetadata
15*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_infra.proxy_value import ProxyValue
16*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.node import Argument, Target
17*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import Library
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Workerlib = Library("aten", "FRAGMENT")
20*523fa7a6SAndroid Build Coastguard Workerimpl_lib = Library("aten", "IMPL")
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Workerlog = logging.getLogger(__name__)
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker
def get_target_version(versioned_upgrader_name: str) -> int:
    """Return the operator version that an upgrader upgrades *to*.

    Upgrader names end in ``_<valid_from_ver>_<valid_till_ver>``. For example, ``div_Scalar_0_3``
    applies to ``div.Scalar`` versions 0 through 3, so its target version is ``3 + 1 == 4``.

    Raises:
        RuntimeError: if the name does not end in two underscore-separated integers.
    """
    # The capture group is exactly the final "_"-separated component, because digits never contain "_".
    match = re.match(r"^.*_[0-9]+_([0-9]+)$", versioned_upgrader_name)
    if match is None:
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")
    valid_till_version = int(match.group(1))
    return valid_till_version + 1
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Workerdef get_upgraders() -> Dict[str, Tuple[str, str]]:
35*523fa7a6SAndroid Build Coastguard Worker    """Getting upgraders entry map and operator version map and merge them into one dict."""
36*523fa7a6SAndroid Build Coastguard Worker    upgraders = torch._C._get_upgraders_entry_map()
37*523fa7a6SAndroid Build Coastguard Worker    op_version_map = torch._C._get_operator_version_map()
38*523fa7a6SAndroid Build Coastguard Worker    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
39*523fa7a6SAndroid Build Coastguard Worker    for opname, entry_list in op_version_map.items():
40*523fa7a6SAndroid Build Coastguard Worker        if not entry_list:
41*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
42*523fa7a6SAndroid Build Coastguard Worker        entry = entry_list[0]
43*523fa7a6SAndroid Build Coastguard Worker        old_schema = entry.old_schema
44*523fa7a6SAndroid Build Coastguard Worker        upgrader_name = entry.upgrader_name
45*523fa7a6SAndroid Build Coastguard Worker        upgrader_str = upgraders.get(upgrader_name, None)
46*523fa7a6SAndroid Build Coastguard Worker        if not upgrader_str:
47*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(
48*523fa7a6SAndroid Build Coastguard Worker                f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}"
49*523fa7a6SAndroid Build Coastguard Worker            )
50*523fa7a6SAndroid Build Coastguard Worker        output[upgrader_name] = (old_schema, upgrader_str)
51*523fa7a6SAndroid Build Coastguard Worker    return output
52*523fa7a6SAndroid Build Coastguard Worker
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Workerclass GraphModuleOpUpgrader:
55*523fa7a6SAndroid Build Coastguard Worker    """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available.
56*523fa7a6SAndroid Build Coastguard Worker    To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In
57*523fa7a6SAndroid Build Coastguard Worker    __init__() it does the following:
58*523fa7a6SAndroid Build Coastguard Worker    1. parse the upgrader list and reorder for upgrading purpose.
59*523fa7a6SAndroid Build Coastguard Worker    2. register old versions of operators as custom ops.
60*523fa7a6SAndroid Build Coastguard Worker    3. prepare upgrader passes.
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker    In `upgrade()` API run these upgrader passes.
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker    An example of op_upgraders input:
65*523fa7a6SAndroid Build Coastguard Worker    {
66*523fa7a6SAndroid Build Coastguard Worker        "aten::div__Scalar_0_3": (                              # versioned op name
67*523fa7a6SAndroid Build Coastguard Worker            "div._Scalar(self: Tensor, other: Scalar)",         # old schema
68*523fa7a6SAndroid Build Coastguard Worker            '''
69*523fa7a6SAndroid Build Coastguard Worker            def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:     # upgrader in literal string
70*523fa7a6SAndroid Build Coastguard Worker              if (self.is_floating_point() or isinstance(other, float)):
71*523fa7a6SAndroid Build Coastguard Worker                return self.true_divide_(other)
72*523fa7a6SAndroid Build Coastguard Worker              return self.divide_(other, rounding_mode='trunc')
73*523fa7a6SAndroid Build Coastguard Worker            ''',
74*523fa7a6SAndroid Build Coastguard Worker        ),
75*523fa7a6SAndroid Build Coastguard Worker    },
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker    Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
78*523fa7a6SAndroid Build Coastguard Worker    original TorchScript upgrader).
79*523fa7a6SAndroid Build Coastguard Worker    """
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker    class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
82*523fa7a6SAndroid Build Coastguard Worker        def __init__(self, old_target: Target, new_target: Target):
83*523fa7a6SAndroid Build Coastguard Worker            super().__init__()
84*523fa7a6SAndroid Build Coastguard Worker            self.old_target = old_target
85*523fa7a6SAndroid Build Coastguard Worker            self.new_target = new_target
86*523fa7a6SAndroid Build Coastguard Worker
87*523fa7a6SAndroid Build Coastguard Worker        def call_operator(
88*523fa7a6SAndroid Build Coastguard Worker            self,
89*523fa7a6SAndroid Build Coastguard Worker            op,
90*523fa7a6SAndroid Build Coastguard Worker            args: Tuple[Argument, ...],
91*523fa7a6SAndroid Build Coastguard Worker            kwargs: Dict[str, Argument],
92*523fa7a6SAndroid Build Coastguard Worker            meta: NodeMetadata,
93*523fa7a6SAndroid Build Coastguard Worker        ) -> ProxyValue:
94*523fa7a6SAndroid Build Coastguard Worker            if op == self.old_target:
95*523fa7a6SAndroid Build Coastguard Worker                return super().call_operator(self.new_target, args, kwargs, meta)
96*523fa7a6SAndroid Build Coastguard Worker            return super().call_operator(op, args, kwargs, meta)
97*523fa7a6SAndroid Build Coastguard Worker
98*523fa7a6SAndroid Build Coastguard Worker    def __init__(
99*523fa7a6SAndroid Build Coastguard Worker        self,
100*523fa7a6SAndroid Build Coastguard Worker        compiler_opset_version: Optional[Dict[str, int]] = None,
101*523fa7a6SAndroid Build Coastguard Worker        model_opset_version: Optional[Dict[str, int]] = None,
102*523fa7a6SAndroid Build Coastguard Worker        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
103*523fa7a6SAndroid Build Coastguard Worker    ):
104*523fa7a6SAndroid Build Coastguard Worker        self.op_upgraders: Dict[str, Tuple[str, str]] = (
105*523fa7a6SAndroid Build Coastguard Worker            get_upgraders() if not op_upgraders else op_upgraders
106*523fa7a6SAndroid Build Coastguard Worker        )
107*523fa7a6SAndroid Build Coastguard Worker        self.compiler_opset_version = (
108*523fa7a6SAndroid Build Coastguard Worker            compiler_opset_version if compiler_opset_version else {}
109*523fa7a6SAndroid Build Coastguard Worker        )
110*523fa7a6SAndroid Build Coastguard Worker        self.model_opset_version = model_opset_version if model_opset_version else {}
111*523fa7a6SAndroid Build Coastguard Worker        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = (
112*523fa7a6SAndroid Build Coastguard Worker            GraphModuleOpUpgrader._populate_passes(
113*523fa7a6SAndroid Build Coastguard Worker                self._parse_upgraders(self.op_upgraders)
114*523fa7a6SAndroid Build Coastguard Worker            )
115*523fa7a6SAndroid Build Coastguard Worker        )
116*523fa7a6SAndroid Build Coastguard Worker
117*523fa7a6SAndroid Build Coastguard Worker    def _parse_upgraders(
118*523fa7a6SAndroid Build Coastguard Worker        self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None
119*523fa7a6SAndroid Build Coastguard Worker    ) -> List[Tuple[str, str]]:
120*523fa7a6SAndroid Build Coastguard Worker        """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well
121*523fa7a6SAndroid Build Coastguard Worker        as the upgrader function string literal."""
122*523fa7a6SAndroid Build Coastguard Worker        # TODO(larryliu0820): Add support for custom ops
123*523fa7a6SAndroid Build Coastguard Worker        op_namespace = "aten"
124*523fa7a6SAndroid Build Coastguard Worker        if (
125*523fa7a6SAndroid Build Coastguard Worker            not op_upgraders
126*523fa7a6SAndroid Build Coastguard Worker            or op_namespace not in self.model_opset_version
127*523fa7a6SAndroid Build Coastguard Worker            or op_namespace not in self.compiler_opset_version
128*523fa7a6SAndroid Build Coastguard Worker        ):
129*523fa7a6SAndroid Build Coastguard Worker            return []
130*523fa7a6SAndroid Build Coastguard Worker        model_ver = self.model_opset_version[op_namespace]
131*523fa7a6SAndroid Build Coastguard Worker        curr_ver = self.compiler_opset_version[op_namespace]
132*523fa7a6SAndroid Build Coastguard Worker
133*523fa7a6SAndroid Build Coastguard Worker        # key is the target version. div__Scalar_0_3 should have a key of 4.
134*523fa7a6SAndroid Build Coastguard Worker        versioned_upgraders: Dict[int, Tuple[str, str]] = {
135*523fa7a6SAndroid Build Coastguard Worker            get_target_version(name): v for name, v in op_upgraders.items()
136*523fa7a6SAndroid Build Coastguard Worker        }
137*523fa7a6SAndroid Build Coastguard Worker        target_upgraders: List[Tuple[str, str]] = []
138*523fa7a6SAndroid Build Coastguard Worker        # we need all upgraders from model_ver + 1 to curr_ver, inclusively
139*523fa7a6SAndroid Build Coastguard Worker        for ver in range(model_ver + 1, curr_ver + 1):
140*523fa7a6SAndroid Build Coastguard Worker            if ver in versioned_upgraders:
141*523fa7a6SAndroid Build Coastguard Worker                target_upgraders.append(versioned_upgraders[ver])
142*523fa7a6SAndroid Build Coastguard Worker            else:
143*523fa7a6SAndroid Build Coastguard Worker                # we may be able to get away with missing upgraders, if that operator is missing from given graph
144*523fa7a6SAndroid Build Coastguard Worker                # module.
145*523fa7a6SAndroid Build Coastguard Worker                log.warning(
146*523fa7a6SAndroid Build Coastguard Worker                    "Missing an upgrader to upgrade to version {ver}.",
147*523fa7a6SAndroid Build Coastguard Worker                    extra={"ver": ver},
148*523fa7a6SAndroid Build Coastguard Worker                )
149*523fa7a6SAndroid Build Coastguard Worker
150*523fa7a6SAndroid Build Coastguard Worker        return target_upgraders
151*523fa7a6SAndroid Build Coastguard Worker
152*523fa7a6SAndroid Build Coastguard Worker    @staticmethod
153*523fa7a6SAndroid Build Coastguard Worker    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
154*523fa7a6SAndroid Build Coastguard Worker        """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
155*523fa7a6SAndroid Build Coastguard Worker        upgraders. se torch.Library API to register old ops. Op name will be
156*523fa7a6SAndroid Build Coastguard Worker        <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker        lib = Library("aten", "FRAGMENT")
159*523fa7a6SAndroid Build Coastguard Worker        lib.define(old_schema)
160*523fa7a6SAndroid Build Coastguard Worker
161*523fa7a6SAndroid Build Coastguard Worker        impl_lib = Library("aten", "IMPL")
162*523fa7a6SAndroid Build Coastguard Worker        impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")
163*523fa7a6SAndroid Build Coastguard Worker
164*523fa7a6SAndroid Build Coastguard Worker        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
165*523fa7a6SAndroid Build Coastguard Worker        upgrader function literal text.
166*523fa7a6SAndroid Build Coastguard Worker        @:return upgrader passes, order matters
167*523fa7a6SAndroid Build Coastguard Worker        """
168*523fa7a6SAndroid Build Coastguard Worker
169*523fa7a6SAndroid Build Coastguard Worker        upgrader_passes = []
170*523fa7a6SAndroid Build Coastguard Worker
171*523fa7a6SAndroid Build Coastguard Worker        def register_old_op(name: str, schema: str, impl_str: str):
172*523fa7a6SAndroid Build Coastguard Worker            """Registers an old version operator using impl_name as old op name."""
173*523fa7a6SAndroid Build Coastguard Worker            lib.define(schema)
174*523fa7a6SAndroid Build Coastguard Worker            try:
175*523fa7a6SAndroid Build Coastguard Worker                exec(impl_str)
176*523fa7a6SAndroid Build Coastguard Worker            except Exception as e:
177*523fa7a6SAndroid Build Coastguard Worker                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
178*523fa7a6SAndroid Build Coastguard Worker            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")
179*523fa7a6SAndroid Build Coastguard Worker
180*523fa7a6SAndroid Build Coastguard Worker        for schema, upgrader_str in upgraders:
181*523fa7a6SAndroid Build Coastguard Worker            upgrader_name = upgrader_str.split("(")[0].split(" ")[-1]
182*523fa7a6SAndroid Build Coastguard Worker            op_name = schema.split("(")[0].split("::")[-1]
183*523fa7a6SAndroid Build Coastguard Worker            schema = schema.replace(op_name, upgrader_name)
184*523fa7a6SAndroid Build Coastguard Worker            try:
185*523fa7a6SAndroid Build Coastguard Worker                register_old_op(
186*523fa7a6SAndroid Build Coastguard Worker                    name=upgrader_name, schema=schema, impl_str=upgrader_str
187*523fa7a6SAndroid Build Coastguard Worker                )
188*523fa7a6SAndroid Build Coastguard Worker            except RuntimeError as e:
189*523fa7a6SAndroid Build Coastguard Worker                if "with the same name and overload name multiple times" in str(e):
190*523fa7a6SAndroid Build Coastguard Worker                    print(f"Registering {upgrader_name} multiple times")
191*523fa7a6SAndroid Build Coastguard Worker                else:
192*523fa7a6SAndroid Build Coastguard Worker                    raise RuntimeError from e
193*523fa7a6SAndroid Build Coastguard Worker            old_op_target = getattr(torch.ops.aten, upgrader_name).default
194*523fa7a6SAndroid Build Coastguard Worker            # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the
195*523fa7a6SAndroid Build Coastguard Worker            # "default" at the end.
196*523fa7a6SAndroid Build Coastguard Worker            op_name, overload_name = (
197*523fa7a6SAndroid Build Coastguard Worker                (op_name, "default")
198*523fa7a6SAndroid Build Coastguard Worker                if "." not in op_name
199*523fa7a6SAndroid Build Coastguard Worker                else tuple(op_name.split(".")[:2])
200*523fa7a6SAndroid Build Coastguard Worker            )
201*523fa7a6SAndroid Build Coastguard Worker            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
202*523fa7a6SAndroid Build Coastguard Worker            # Note that the graph will have op names in the graph, but actually they are of old versions.
203*523fa7a6SAndroid Build Coastguard Worker            upgrader_passes.append(
204*523fa7a6SAndroid Build Coastguard Worker                GraphModuleOpUpgrader.UpgraderPass(
205*523fa7a6SAndroid Build Coastguard Worker                    old_target=new_op_target, new_target=old_op_target
206*523fa7a6SAndroid Build Coastguard Worker                )
207*523fa7a6SAndroid Build Coastguard Worker            )
208*523fa7a6SAndroid Build Coastguard Worker
209*523fa7a6SAndroid Build Coastguard Worker        return upgrader_passes
210*523fa7a6SAndroid Build Coastguard Worker
211*523fa7a6SAndroid Build Coastguard Worker    def upgrade(self, exported_program):
212*523fa7a6SAndroid Build Coastguard Worker        return exported_program
213