1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport logging 8*523fa7a6SAndroid Build Coastguard Workerimport re 9*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict 10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Workerimport torch 13*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse 14*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_infra.node_metadata import NodeMetadata 15*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.pass_infra.proxy_value import ProxyValue 16*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.node import Argument, Target 17*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import Library 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Workerlib = Library("aten", "FRAGMENT") 20*523fa7a6SAndroid Build Coastguard Workerimpl_lib = Library("aten", "IMPL") 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Workerlog = logging.getLogger(__name__) 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker 25*523fa7a6SAndroid Build Coastguard Workerdef get_target_version(versioned_upgrader_name: str) -> int: 26*523fa7a6SAndroid Build Coastguard Worker """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is 27*523fa7a6SAndroid Build Coastguard Worker upgrading to version 4.""" 28*523fa7a6SAndroid Build Coastguard Worker if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): 29*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker return int(versioned_upgrader_name.split("_")[-1]) + 1 32*523fa7a6SAndroid Build Coastguard Worker 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Workerdef get_upgraders() -> Dict[str, Tuple[str, str]]: 35*523fa7a6SAndroid Build Coastguard Worker """Getting upgraders entry map and operator version map and merge them into one dict.""" 36*523fa7a6SAndroid Build Coastguard Worker upgraders = torch._C._get_upgraders_entry_map() 37*523fa7a6SAndroid Build Coastguard Worker op_version_map = torch._C._get_operator_version_map() 38*523fa7a6SAndroid Build Coastguard Worker output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] 39*523fa7a6SAndroid Build Coastguard Worker for opname, entry_list in op_version_map.items(): 40*523fa7a6SAndroid Build Coastguard Worker if not entry_list: 41*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Op version map has an empty entry for opname {opname}") 42*523fa7a6SAndroid Build Coastguard Worker entry = entry_list[0] 43*523fa7a6SAndroid Build Coastguard Worker old_schema = entry.old_schema 44*523fa7a6SAndroid Build Coastguard Worker upgrader_name = entry.upgrader_name 45*523fa7a6SAndroid Build Coastguard Worker upgrader_str = upgraders.get(upgrader_name, None) 46*523fa7a6SAndroid Build Coastguard Worker if not upgrader_str: 47*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 48*523fa7a6SAndroid Build Coastguard Worker f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}" 49*523fa7a6SAndroid Build Coastguard Worker ) 50*523fa7a6SAndroid Build Coastguard Worker output[upgrader_name] = (old_schema, upgrader_str) 51*523fa7a6SAndroid Build Coastguard Worker return output 52*523fa7a6SAndroid Build Coastguard Worker 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Workerclass GraphModuleOpUpgrader: 55*523fa7a6SAndroid Build Coastguard Worker """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. 56*523fa7a6SAndroid Build Coastguard Worker To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In 57*523fa7a6SAndroid Build Coastguard Worker __init__() it does the following: 58*523fa7a6SAndroid Build Coastguard Worker 1. parse the upgrader list and reorder for upgrading purpose. 59*523fa7a6SAndroid Build Coastguard Worker 2. register old versions of operators as custom ops. 60*523fa7a6SAndroid Build Coastguard Worker 3. prepare upgrader passes. 61*523fa7a6SAndroid Build Coastguard Worker 62*523fa7a6SAndroid Build Coastguard Worker In `upgrade()` API run these upgrader passes. 63*523fa7a6SAndroid Build Coastguard Worker 64*523fa7a6SAndroid Build Coastguard Worker An example of op_upgraders input: 65*523fa7a6SAndroid Build Coastguard Worker { 66*523fa7a6SAndroid Build Coastguard Worker "aten::div__Scalar_0_3": ( # versioned op name 67*523fa7a6SAndroid Build Coastguard Worker "div._Scalar(self: Tensor, other: Scalar)", # old schema 68*523fa7a6SAndroid Build Coastguard Worker ''' 69*523fa7a6SAndroid Build Coastguard Worker def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string 70*523fa7a6SAndroid Build Coastguard Worker if (self.is_floating_point() or isinstance(other, float)): 71*523fa7a6SAndroid Build Coastguard Worker return self.true_divide_(other) 72*523fa7a6SAndroid Build Coastguard Worker return self.divide_(other, rounding_mode='trunc') 73*523fa7a6SAndroid Build Coastguard Worker ''', 74*523fa7a6SAndroid Build Coastguard Worker ), 75*523fa7a6SAndroid Build Coastguard Worker }, 76*523fa7a6SAndroid Build Coastguard Worker 77*523fa7a6SAndroid Build Coastguard Worker Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the 78*523fa7a6SAndroid Build Coastguard Worker original TorchScript upgrader). 79*523fa7a6SAndroid Build Coastguard Worker """ 80*523fa7a6SAndroid Build Coastguard Worker 81*523fa7a6SAndroid Build Coastguard Worker class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): 82*523fa7a6SAndroid Build Coastguard Worker def __init__(self, old_target: Target, new_target: Target): 83*523fa7a6SAndroid Build Coastguard Worker super().__init__() 84*523fa7a6SAndroid Build Coastguard Worker self.old_target = old_target 85*523fa7a6SAndroid Build Coastguard Worker self.new_target = new_target 86*523fa7a6SAndroid Build Coastguard Worker 87*523fa7a6SAndroid Build Coastguard Worker def call_operator( 88*523fa7a6SAndroid Build Coastguard Worker self, 89*523fa7a6SAndroid Build Coastguard Worker op, 90*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 91*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 92*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 93*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 94*523fa7a6SAndroid Build Coastguard Worker if op == self.old_target: 95*523fa7a6SAndroid Build Coastguard Worker return super().call_operator(self.new_target, args, kwargs, meta) 96*523fa7a6SAndroid Build Coastguard Worker return super().call_operator(op, args, kwargs, meta) 97*523fa7a6SAndroid Build Coastguard Worker 98*523fa7a6SAndroid Build Coastguard Worker def __init__( 99*523fa7a6SAndroid Build Coastguard Worker self, 100*523fa7a6SAndroid Build Coastguard Worker compiler_opset_version: Optional[Dict[str, int]] = None, 101*523fa7a6SAndroid Build Coastguard Worker model_opset_version: Optional[Dict[str, int]] = None, 102*523fa7a6SAndroid Build Coastguard Worker op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, 103*523fa7a6SAndroid Build Coastguard Worker ): 104*523fa7a6SAndroid Build Coastguard Worker self.op_upgraders: Dict[str, Tuple[str, str]] = ( 105*523fa7a6SAndroid Build Coastguard Worker get_upgraders() if not op_upgraders else op_upgraders 106*523fa7a6SAndroid Build Coastguard Worker ) 107*523fa7a6SAndroid Build Coastguard Worker self.compiler_opset_version = ( 108*523fa7a6SAndroid Build Coastguard Worker compiler_opset_version if compiler_opset_version else {} 109*523fa7a6SAndroid Build Coastguard Worker ) 110*523fa7a6SAndroid Build Coastguard Worker self.model_opset_version = model_opset_version if model_opset_version else {} 111*523fa7a6SAndroid Build Coastguard Worker self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = ( 112*523fa7a6SAndroid Build Coastguard Worker GraphModuleOpUpgrader._populate_passes( 113*523fa7a6SAndroid Build Coastguard Worker self._parse_upgraders(self.op_upgraders) 114*523fa7a6SAndroid Build Coastguard Worker ) 115*523fa7a6SAndroid Build Coastguard Worker ) 116*523fa7a6SAndroid Build Coastguard Worker 117*523fa7a6SAndroid Build Coastguard Worker def _parse_upgraders( 118*523fa7a6SAndroid Build Coastguard Worker self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None 119*523fa7a6SAndroid Build Coastguard Worker ) -> List[Tuple[str, str]]: 120*523fa7a6SAndroid Build Coastguard Worker """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well 121*523fa7a6SAndroid Build Coastguard Worker as the upgrader function string literal.""" 122*523fa7a6SAndroid Build Coastguard Worker # TODO(larryliu0820): Add support for custom ops 123*523fa7a6SAndroid Build Coastguard Worker op_namespace = "aten" 124*523fa7a6SAndroid Build Coastguard Worker if ( 125*523fa7a6SAndroid Build Coastguard Worker not op_upgraders 126*523fa7a6SAndroid Build Coastguard Worker or op_namespace not in self.model_opset_version 127*523fa7a6SAndroid Build Coastguard Worker or op_namespace not in self.compiler_opset_version 128*523fa7a6SAndroid Build Coastguard Worker ): 129*523fa7a6SAndroid Build Coastguard Worker return [] 130*523fa7a6SAndroid Build Coastguard Worker model_ver = self.model_opset_version[op_namespace] 131*523fa7a6SAndroid Build Coastguard Worker curr_ver = self.compiler_opset_version[op_namespace] 132*523fa7a6SAndroid Build Coastguard Worker 133*523fa7a6SAndroid Build Coastguard Worker # key is the target version. div__Scalar_0_3 should have a key of 4. 134*523fa7a6SAndroid Build Coastguard Worker versioned_upgraders: Dict[int, Tuple[str, str]] = { 135*523fa7a6SAndroid Build Coastguard Worker get_target_version(name): v for name, v in op_upgraders.items() 136*523fa7a6SAndroid Build Coastguard Worker } 137*523fa7a6SAndroid Build Coastguard Worker target_upgraders: List[Tuple[str, str]] = [] 138*523fa7a6SAndroid Build Coastguard Worker # we need all upgraders from model_ver + 1 to curr_ver, inclusively 139*523fa7a6SAndroid Build Coastguard Worker for ver in range(model_ver + 1, curr_ver + 1): 140*523fa7a6SAndroid Build Coastguard Worker if ver in versioned_upgraders: 141*523fa7a6SAndroid Build Coastguard Worker target_upgraders.append(versioned_upgraders[ver]) 142*523fa7a6SAndroid Build Coastguard Worker else: 143*523fa7a6SAndroid Build Coastguard Worker # we may be able to get away with missing upgraders, if that operator is missing from given graph 144*523fa7a6SAndroid Build Coastguard Worker # module. 145*523fa7a6SAndroid Build Coastguard Worker log.warning( 146*523fa7a6SAndroid Build Coastguard Worker "Missing an upgrader to upgrade to version {ver}.", 147*523fa7a6SAndroid Build Coastguard Worker extra={"ver": ver}, 148*523fa7a6SAndroid Build Coastguard Worker ) 149*523fa7a6SAndroid Build Coastguard Worker 150*523fa7a6SAndroid Build Coastguard Worker return target_upgraders 151*523fa7a6SAndroid Build Coastguard Worker 152*523fa7a6SAndroid Build Coastguard Worker @staticmethod 153*523fa7a6SAndroid Build Coastguard Worker def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: 154*523fa7a6SAndroid Build Coastguard Worker """Given a list of upgraders, loop through it from lower version to higher version and create passes for all 155*523fa7a6SAndroid Build Coastguard Worker upgraders. se torch.Library API to register old ops. Op name will be 156*523fa7a6SAndroid Build Coastguard Worker <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example: 157*523fa7a6SAndroid Build Coastguard Worker 158*523fa7a6SAndroid Build Coastguard Worker lib = Library("aten", "FRAGMENT") 159*523fa7a6SAndroid Build Coastguard Worker lib.define(old_schema) 160*523fa7a6SAndroid Build Coastguard Worker 161*523fa7a6SAndroid Build Coastguard Worker impl_lib = Library("aten", "IMPL") 162*523fa7a6SAndroid Build Coastguard Worker impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") 163*523fa7a6SAndroid Build Coastguard Worker 164*523fa7a6SAndroid Build Coastguard Worker @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the 165*523fa7a6SAndroid Build Coastguard Worker upgrader function literal text. 166*523fa7a6SAndroid Build Coastguard Worker @:return upgrader passes, order matters 167*523fa7a6SAndroid Build Coastguard Worker """ 168*523fa7a6SAndroid Build Coastguard Worker 169*523fa7a6SAndroid Build Coastguard Worker upgrader_passes = [] 170*523fa7a6SAndroid Build Coastguard Worker 171*523fa7a6SAndroid Build Coastguard Worker def register_old_op(name: str, schema: str, impl_str: str): 172*523fa7a6SAndroid Build Coastguard Worker """Registers an old version operator using impl_name as old op name.""" 173*523fa7a6SAndroid Build Coastguard Worker lib.define(schema) 174*523fa7a6SAndroid Build Coastguard Worker try: 175*523fa7a6SAndroid Build Coastguard Worker exec(impl_str) 176*523fa7a6SAndroid Build Coastguard Worker except Exception as e: 177*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e 178*523fa7a6SAndroid Build Coastguard Worker impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") 179*523fa7a6SAndroid Build Coastguard Worker 180*523fa7a6SAndroid Build Coastguard Worker for schema, upgrader_str in upgraders: 181*523fa7a6SAndroid Build Coastguard Worker upgrader_name = upgrader_str.split("(")[0].split(" ")[-1] 182*523fa7a6SAndroid Build Coastguard Worker op_name = schema.split("(")[0].split("::")[-1] 183*523fa7a6SAndroid Build Coastguard Worker schema = schema.replace(op_name, upgrader_name) 184*523fa7a6SAndroid Build Coastguard Worker try: 185*523fa7a6SAndroid Build Coastguard Worker register_old_op( 186*523fa7a6SAndroid Build Coastguard Worker name=upgrader_name, schema=schema, impl_str=upgrader_str 187*523fa7a6SAndroid Build Coastguard Worker ) 188*523fa7a6SAndroid Build Coastguard Worker except RuntimeError as e: 189*523fa7a6SAndroid Build Coastguard Worker if "with the same name and overload name multiple times" in str(e): 190*523fa7a6SAndroid Build Coastguard Worker print(f"Registering {upgrader_name} multiple times") 191*523fa7a6SAndroid Build Coastguard Worker else: 192*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError from e 193*523fa7a6SAndroid Build Coastguard Worker old_op_target = getattr(torch.ops.aten, upgrader_name).default 194*523fa7a6SAndroid Build Coastguard Worker # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the 195*523fa7a6SAndroid Build Coastguard Worker # "default" at the end. 196*523fa7a6SAndroid Build Coastguard Worker op_name, overload_name = ( 197*523fa7a6SAndroid Build Coastguard Worker (op_name, "default") 198*523fa7a6SAndroid Build Coastguard Worker if "." not in op_name 199*523fa7a6SAndroid Build Coastguard Worker else tuple(op_name.split(".")[:2]) 200*523fa7a6SAndroid Build Coastguard Worker ) 201*523fa7a6SAndroid Build Coastguard Worker new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) 202*523fa7a6SAndroid Build Coastguard Worker # Note that the graph will have op names in the graph, but actually they are of old versions. 203*523fa7a6SAndroid Build Coastguard Worker upgrader_passes.append( 204*523fa7a6SAndroid Build Coastguard Worker GraphModuleOpUpgrader.UpgraderPass( 205*523fa7a6SAndroid Build Coastguard Worker old_target=new_op_target, new_target=old_op_target 206*523fa7a6SAndroid Build Coastguard Worker ) 207*523fa7a6SAndroid Build Coastguard Worker ) 208*523fa7a6SAndroid Build Coastguard Worker 209*523fa7a6SAndroid Build Coastguard Worker return upgrader_passes 210*523fa7a6SAndroid Build Coastguard Worker 211*523fa7a6SAndroid Build Coastguard Worker def upgrade(self, exported_program): 212*523fa7a6SAndroid Build Coastguard Worker return exported_program 213