# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch.fx.node import Argument, Target
from torch.library import Library

lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")

log = logging.getLogger(__name__)


def get_target_version(versioned_upgrader_name: str) -> int:
    """Returns the version this upgrader upgrades to. For example, div_Scalar_0_3 is the name of an upgrader that
    applies to div.Scalar of versions 0 to 3 and upgrades it to version 4."""
    if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name):
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    return int(versioned_upgrader_name.split("_")[-1]) + 1

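# Illustrative usage (a minimal sketch; the upgrader name below is hypothetical and only
# follows the naming convention described in the docstring above):
#
#   get_target_version("div_Scalar_0_3")  # -> 4 (upgrades versions 0..3 to version 4)
#   get_target_version("div_Scalar")      # raises RuntimeError (no version suffix)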

def get_upgraders() -> Dict[str, Tuple[str, str]]:
    """Gets the upgraders entry map and the operator version map and merges them into one dict."""
    upgraders = torch._C._get_upgraders_entry_map()
    op_version_map = torch._C._get_operator_version_map()
    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
    for opname, entry_list in op_version_map.items():
        if not entry_list:
            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
        entry = entry_list[0]
        old_schema = entry.old_schema
        upgrader_name = entry.upgrader_name
        upgrader_str = upgraders.get(upgrader_name, None)
        if not upgrader_str:
            raise RuntimeError(
                f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}"
            )
        output[upgrader_name] = (old_schema, upgrader_str)
    return output

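# The merged mapping returned by get_upgraders() has roughly this shape (a sketch; the
# concrete entries depend on the operator version map in the installed PyTorch build):
#
#   {
#       "div__Scalar_0_3": ("<old schema string>", "<upgrader function source text>"),
#       ...
#   }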

class GraphModuleOpUpgrader:
    """Upgrades old versions of ops in a given GraphModule, provided all the required upgraders are available.
    To use it, retrieve upgraders from somewhere (the TorchScript API or a new API) and pass them into this
    upgrader. In __init__() it does the following:
    1. parse the upgrader list and reorder it for upgrading purposes.
    2. register old versions of operators as custom ops.
    3. prepare upgrader passes.

    The `upgrade()` API then runs these upgrader passes.

    An example of the op_upgraders input:
    {
        "aten::div__Scalar_0_3": (                              # versioned op name
            "div_.Scalar(self: Tensor, other: Scalar)",         # old schema
            '''
            def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:     # upgrader in literal string
              if (self.is_floating_point() or isinstance(other, float)):
                return self.true_divide_(other)
              return self.divide_(other, rounding_mode='trunc')
            ''',
        ),
    },

    Note that we require the upgrader function to be runnable in Python (a stricter requirement than for the
    original TorchScript upgraders).
    """

    class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse):
        def __init__(self, old_target: Target, new_target: Target):
            super().__init__()
            self.old_target = old_target
            self.new_target = new_target

        def call_operator(
            self,
            op,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
            meta: NodeMetadata,
        ) -> ProxyValue:
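            # If this node calls the old target, re-emit the call with the new target;
            # otherwise pass the call through unchanged.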
            if op == self.old_target:
                return super().call_operator(self.new_target, args, kwargs, meta)
            return super().call_operator(op, args, kwargs, meta)

    def __init__(
        self,
        compiler_opset_version: Optional[Dict[str, int]] = None,
        model_opset_version: Optional[Dict[str, int]] = None,
        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
    ):
        self.op_upgraders: Dict[str, Tuple[str, str]] = (
            get_upgraders() if not op_upgraders else op_upgraders
        )
        self.compiler_opset_version = (
            compiler_opset_version if compiler_opset_version else {}
        )
        self.model_opset_version = model_opset_version if model_opset_version else {}
        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = (
            GraphModuleOpUpgrader._populate_passes(
                self._parse_upgraders(self.op_upgraders)
            )
        )

    def _parse_upgraders(
        self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None
    ) -> List[Tuple[str, str]]:
        """Reorders op_upgraders by version number and returns an ordered list of tuples, each containing an old op
        schema and the corresponding upgrader function string literal."""
        # TODO(larryliu0820): Add support for custom ops
        op_namespace = "aten"
        if (
            not op_upgraders
            or op_namespace not in self.model_opset_version
            or op_namespace not in self.compiler_opset_version
        ):
            return []
        model_ver = self.model_opset_version[op_namespace]
        curr_ver = self.compiler_opset_version[op_namespace]
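        # Illustrative example (hypothetical versions): with model_ver == 3 and curr_ver == 4,
        # the loop below collects exactly one upgrader, the one keyed by target version 4
        # (e.g. "div__Scalar_0_3").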

        # key is the target version. div__Scalar_0_3 should have a key of 4.
        versioned_upgraders: Dict[int, Tuple[str, str]] = {
            get_target_version(name): v for name, v in op_upgraders.items()
        }
        target_upgraders: List[Tuple[str, str]] = []
        # we need all upgraders from model_ver + 1 to curr_ver, inclusive
        for ver in range(model_ver + 1, curr_ver + 1):
            if ver in versioned_upgraders:
                target_upgraders.append(versioned_upgraders[ver])
            else:
                # we may be able to get away with a missing upgrader, if the operator it covers is absent from the
                # given graph module.
                log.warning("Missing an upgrader to upgrade to version %s.", ver)

        return target_upgraders

    @staticmethod
    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
        """Given a list of upgraders, loops through it from lower version to higher version and creates passes for all
        upgraders. Uses the torch.Library API to register old ops. The op name will be
        <name>_<valid_from_ver>_<valid_till_ver>. Upgraders are registered as CompositeImplicitAutograd kernels. For example:

        lib = Library("aten", "FRAGMENT")
        lib.define(old_schema)

        impl_lib = Library("aten", "IMPL")
        impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")

        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
        upgrader function literal text.
        @:return upgrader passes; order matters
        """

        upgrader_passes = []

        def register_old_op(name: str, schema: str, impl_str: str):
            """Registers an old version of an operator, using the given name as the old op name."""
            lib.define(schema)
            try:
                exec(impl_str)
            except Exception as e:
                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
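            # exec() above defined the upgrader function in this local namespace; register it as
            # the kernel implementation of the old op.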
            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")

        for schema, upgrader_str in upgraders:
            upgrader_name = upgrader_str.split("(")[0].split(" ")[-1]
            op_name = schema.split("(")[0].split("::")[-1]
            schema = schema.replace(op_name, upgrader_name)
            try:
                register_old_op(
                    name=upgrader_name, schema=schema, impl_str=upgrader_str
                )
            except RuntimeError as e:
                if "with the same name and overload name multiple times" in str(e):
                    print(f"Registering {upgrader_name} multiple times")
                else:
                    raise RuntimeError from e
            old_op_target = getattr(torch.ops.aten, upgrader_name).default
            # for example, the operator instance of "aten::div" is torch.ops.aten.div.default. We need to append
            # "default" at the end.
            op_name, overload_name = (
                (op_name, "default")
                if "." not in op_name
                else tuple(op_name.split(".")[:2])
            )
            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
            # Note that the graph refers to ops by their current names, but the nodes are actually the old
            # versions, so the pass retargets them to the registered old ops (whose kernels are the upgraders).
            upgrader_passes.append(
                GraphModuleOpUpgrader.UpgraderPass(
                    old_target=new_op_target, new_target=old_op_target
                )
            )

        return upgrader_passes

    def upgrade(self, exported_program):
        return exported_program
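
# Example end-to-end usage (a minimal sketch; the opset version numbers are hypothetical, and
# note that upgrade() currently returns the exported program unchanged):
#
#   upgrader = GraphModuleOpUpgrader(
#       compiler_opset_version={"aten": 4},
#       model_opset_version={"aten": 3},
#   )
#   upgraded_program = upgrader.upgrade(exported_program)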