1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8import re 9from collections import defaultdict 10from typing import Dict, List, Optional, Tuple 11 12import torch 13from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse 14from torch._export.pass_infra.node_metadata import NodeMetadata 15from torch._export.pass_infra.proxy_value import ProxyValue 16from torch.fx.node import Argument, Target 17from torch.library import Library 18 19lib = Library("aten", "FRAGMENT") 20impl_lib = Library("aten", "IMPL") 21 22log = logging.getLogger(__name__) 23 24 25def get_target_version(versioned_upgrader_name: str) -> int: 26 """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is 27 upgrading to version 4.""" 28 if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): 29 raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") 30 31 return int(versioned_upgrader_name.split("_")[-1]) + 1 32 33 34def get_upgraders() -> Dict[str, Tuple[str, str]]: 35 """Getting upgraders entry map and operator version map and merge them into one dict.""" 36 upgraders = torch._C._get_upgraders_entry_map() 37 op_version_map = torch._C._get_operator_version_map() 38 output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] 39 for opname, entry_list in op_version_map.items(): 40 if not entry_list: 41 raise RuntimeError(f"Op version map has an empty entry for opname {opname}") 42 entry = entry_list[0] 43 old_schema = entry.old_schema 44 upgrader_name = entry.upgrader_name 45 upgrader_str = upgraders.get(upgrader_name, None) 46 if not upgrader_str: 47 raise RuntimeError( 48 f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}" 49 ) 50 output[upgrader_name] = (old_schema, upgrader_str) 51 return output 52 53 54class GraphModuleOpUpgrader: 55 """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. 56 To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In 57 __init__() it does the following: 58 1. parse the upgrader list and reorder for upgrading purpose. 59 2. register old versions of operators as custom ops. 60 3. prepare upgrader passes. 61 62 In `upgrade()` API run these upgrader passes. 63 64 An example of op_upgraders input: 65 { 66 "aten::div__Scalar_0_3": ( # versioned op name 67 "div._Scalar(self: Tensor, other: Scalar)", # old schema 68 ''' 69 def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string 70 if (self.is_floating_point() or isinstance(other, float)): 71 return self.true_divide_(other) 72 return self.divide_(other, rounding_mode='trunc') 73 ''', 74 ), 75 }, 76 77 Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the 78 original TorchScript upgrader). 79 """ 80 81 class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): 82 def __init__(self, old_target: Target, new_target: Target): 83 super().__init__() 84 self.old_target = old_target 85 self.new_target = new_target 86 87 def call_operator( 88 self, 89 op, 90 args: Tuple[Argument, ...], 91 kwargs: Dict[str, Argument], 92 meta: NodeMetadata, 93 ) -> ProxyValue: 94 if op == self.old_target: 95 return super().call_operator(self.new_target, args, kwargs, meta) 96 return super().call_operator(op, args, kwargs, meta) 97 98 def __init__( 99 self, 100 compiler_opset_version: Optional[Dict[str, int]] = None, 101 model_opset_version: Optional[Dict[str, int]] = None, 102 op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, 103 ): 104 self.op_upgraders: Dict[str, Tuple[str, str]] = ( 105 get_upgraders() if not op_upgraders else op_upgraders 106 ) 107 self.compiler_opset_version = ( 108 compiler_opset_version if compiler_opset_version else {} 109 ) 110 self.model_opset_version = model_opset_version if model_opset_version else {} 111 self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = ( 112 GraphModuleOpUpgrader._populate_passes( 113 self._parse_upgraders(self.op_upgraders) 114 ) 115 ) 116 117 def _parse_upgraders( 118 self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None 119 ) -> List[Tuple[str, str]]: 120 """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well 121 as the upgrader function string literal.""" 122 # TODO(larryliu0820): Add support for custom ops 123 op_namespace = "aten" 124 if ( 125 not op_upgraders 126 or op_namespace not in self.model_opset_version 127 or op_namespace not in self.compiler_opset_version 128 ): 129 return [] 130 model_ver = self.model_opset_version[op_namespace] 131 curr_ver = self.compiler_opset_version[op_namespace] 132 133 # key is the target version. div__Scalar_0_3 should have a key of 4. 134 versioned_upgraders: Dict[int, Tuple[str, str]] = { 135 get_target_version(name): v for name, v in op_upgraders.items() 136 } 137 target_upgraders: List[Tuple[str, str]] = [] 138 # we need all upgraders from model_ver + 1 to curr_ver, inclusively 139 for ver in range(model_ver + 1, curr_ver + 1): 140 if ver in versioned_upgraders: 141 target_upgraders.append(versioned_upgraders[ver]) 142 else: 143 # we may be able to get away with missing upgraders, if that operator is missing from given graph 144 # module. 145 log.warning( 146 "Missing an upgrader to upgrade to version {ver}.", 147 extra={"ver": ver}, 148 ) 149 150 return target_upgraders 151 152 @staticmethod 153 def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: 154 """Given a list of upgraders, loop through it from lower version to higher version and create passes for all 155 upgraders. se torch.Library API to register old ops. Op name will be 156 <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example: 157 158 lib = Library("aten", "FRAGMENT") 159 lib.define(old_schema) 160 161 impl_lib = Library("aten", "IMPL") 162 impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") 163 164 @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the 165 upgrader function literal text. 166 @:return upgrader passes, order matters 167 """ 168 169 upgrader_passes = [] 170 171 def register_old_op(name: str, schema: str, impl_str: str): 172 """Registers an old version operator using impl_name as old op name.""" 173 lib.define(schema) 174 try: 175 exec(impl_str) 176 except Exception as e: 177 raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e 178 impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") 179 180 for schema, upgrader_str in upgraders: 181 upgrader_name = upgrader_str.split("(")[0].split(" ")[-1] 182 op_name = schema.split("(")[0].split("::")[-1] 183 schema = schema.replace(op_name, upgrader_name) 184 try: 185 register_old_op( 186 name=upgrader_name, schema=schema, impl_str=upgrader_str 187 ) 188 except RuntimeError as e: 189 if "with the same name and overload name multiple times" in str(e): 190 print(f"Registering {upgrader_name} multiple times") 191 else: 192 raise RuntimeError from e 193 old_op_target = getattr(torch.ops.aten, upgrader_name).default 194 # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the 195 # "default" at the end. 196 op_name, overload_name = ( 197 (op_name, "default") 198 if "." not in op_name 199 else tuple(op_name.split(".")[:2]) 200 ) 201 new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) 202 # Note that the graph will have op names in the graph, but actually they are of old versions. 203 upgrader_passes.append( 204 GraphModuleOpUpgrader.UpgraderPass( 205 old_target=new_op_target, new_target=old_op_target 206 ) 207 ) 208 209 return upgrader_passes 210 211 def upgrade(self, exported_program): 212 return exported_program 213