# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging import re from collections import defaultdict from typing import Dict, List, Optional, Tuple import torch from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch._export.pass_infra.node_metadata import NodeMetadata from torch._export.pass_infra.proxy_value import ProxyValue from torch.fx.node import Argument, Target from torch.library import Library lib = Library("aten", "FRAGMENT") impl_lib = Library("aten", "IMPL") log = logging.getLogger(__name__) def get_target_version(versioned_upgrader_name: str) -> int: """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is upgrading to version 4.""" if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") return int(versioned_upgrader_name.split("_")[-1]) + 1 def get_upgraders() -> Dict[str, Tuple[str, str]]: """Getting upgraders entry map and operator version map and merge them into one dict.""" upgraders = torch._C._get_upgraders_entry_map() op_version_map = torch._C._get_operator_version_map() output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] for opname, entry_list in op_version_map.items(): if not entry_list: raise RuntimeError(f"Op version map has an empty entry for opname {opname}") entry = entry_list[0] old_schema = entry.old_schema upgrader_name = entry.upgrader_name upgrader_str = upgraders.get(upgrader_name, None) if not upgrader_str: raise RuntimeError( f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}" ) output[upgrader_name] = (old_schema, upgrader_str) return output class GraphModuleOpUpgrader: """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In __init__() it does the following: 1. parse the upgrader list and reorder for upgrading purpose. 2. register old versions of operators as custom ops. 3. prepare upgrader passes. In `upgrade()` API run these upgrader passes. An example of op_upgraders input: { "aten::div__Scalar_0_3": ( # versioned op name "div._Scalar(self: Tensor, other: Scalar)", # old schema ''' def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string if (self.is_floating_point() or isinstance(other, float)): return self.true_divide_(other) return self.divide_(other, rounding_mode='trunc') ''', ), }, Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the original TorchScript upgrader). """ class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): def __init__(self, old_target: Target, new_target: Target): super().__init__() self.old_target = old_target self.new_target = new_target def call_operator( self, op, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op == self.old_target: return super().call_operator(self.new_target, args, kwargs, meta) return super().call_operator(op, args, kwargs, meta) def __init__( self, compiler_opset_version: Optional[Dict[str, int]] = None, model_opset_version: Optional[Dict[str, int]] = None, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, ): self.op_upgraders: Dict[str, Tuple[str, str]] = ( get_upgraders() if not op_upgraders else op_upgraders ) self.compiler_opset_version = ( compiler_opset_version if compiler_opset_version else {} ) self.model_opset_version = model_opset_version if model_opset_version else {} self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = ( GraphModuleOpUpgrader._populate_passes( self._parse_upgraders(self.op_upgraders) ) ) def _parse_upgraders( self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None ) -> List[Tuple[str, str]]: """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well as the upgrader function string literal.""" # TODO(larryliu0820): Add support for custom ops op_namespace = "aten" if ( not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version ): return [] model_ver = self.model_opset_version[op_namespace] curr_ver = self.compiler_opset_version[op_namespace] # key is the target version. div__Scalar_0_3 should have a key of 4. versioned_upgraders: Dict[int, Tuple[str, str]] = { get_target_version(name): v for name, v in op_upgraders.items() } target_upgraders: List[Tuple[str, str]] = [] # we need all upgraders from model_ver + 1 to curr_ver, inclusively for ver in range(model_ver + 1, curr_ver + 1): if ver in versioned_upgraders: target_upgraders.append(versioned_upgraders[ver]) else: # we may be able to get away with missing upgraders, if that operator is missing from given graph # module. log.warning( "Missing an upgrader to upgrade to version {ver}.", extra={"ver": ver}, ) return target_upgraders @staticmethod def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: """Given a list of upgraders, loop through it from lower version to higher version and create passes for all upgraders. se torch.Library API to register old ops. Op name will be __. Register upgraders as CompositeImplicitAutograd kernels. For example: lib = Library("aten", "FRAGMENT") lib.define(old_schema) impl_lib = Library("aten", "IMPL") impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the upgrader function literal text. @:return upgrader passes, order matters """ upgrader_passes = [] def register_old_op(name: str, schema: str, impl_str: str): """Registers an old version operator using impl_name as old op name.""" lib.define(schema) try: exec(impl_str) except Exception as e: raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") for schema, upgrader_str in upgraders: upgrader_name = upgrader_str.split("(")[0].split(" ")[-1] op_name = schema.split("(")[0].split("::")[-1] schema = schema.replace(op_name, upgrader_name) try: register_old_op( name=upgrader_name, schema=schema, impl_str=upgrader_str ) except RuntimeError as e: if "with the same name and overload name multiple times" in str(e): print(f"Registering {upgrader_name} multiple times") else: raise RuntimeError from e old_op_target = getattr(torch.ops.aten, upgrader_name).default # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the # "default" at the end. op_name, overload_name = ( (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2]) ) new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) # Note that the graph will have op names in the graph, but actually they are of old versions. upgrader_passes.append( GraphModuleOpUpgrader.UpgraderPass( old_target=new_op_target, new_target=old_op_target ) ) return upgrader_passes def upgrade(self, exported_program): return exported_program