1# mypy: allow-untyped-defs 2from typing import Dict, Optional 3import torch 4from torch._ops import OpOverload, HigherOrderOperator 5from torch._export.error import InternalError 6from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse 7 8 9__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] 10 11 12_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = { 13 torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, 14} 15 16 17def is_view_op(schema: torch._C.FunctionSchema) -> bool: 18 if len(schema.arguments) == 0: 19 return False 20 alias_info = schema.arguments[0].alias_info 21 return (alias_info is not None) and (not alias_info.is_write) 22 23 24def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: 25 if is_view_op(schema) and schema.name.startswith("aten::"): 26 view_op_name = schema.name.split("::")[1] 27 view_op_overload = ( 28 schema.overload_name 29 if schema.overload_name != "" 30 else "default" 31 ) 32 view_copy_op_name = view_op_name + "_copy" 33 if not hasattr(torch.ops.aten, view_copy_op_name): 34 raise InternalError(f"{schema.name} is missing a view_copy variant") 35 36 view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name) 37 38 if not hasattr(view_copy_op_overload_packet, view_op_overload): 39 raise InternalError(f"{schema.name} is missing a view_copy variant") 40 41 return getattr(view_copy_op_overload_packet, view_op_overload) 42 43 return None 44 45 46class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse): 47 """ 48 Our backend expects pure functional operators. For efficiency 49 purposes, we keep view ops around while functionalizing the exported 50 program. This pass replaces view ops with view copy ops for backends that 51 need AOT memory planning. 52 """ 53 def call_operator(self, op, args, kwargs, meta): 54 if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: 55 return super().call_operator( 56 (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta 57 ) 58 59 if isinstance(op, HigherOrderOperator): 60 return super().call_operator(op, args, kwargs, meta) 61 62 if view_copy_op := get_view_copy_of_view_op(op._schema): 63 return super().call_operator(view_copy_op, args, kwargs, meta) 64 65 return super().call_operator(op, args, kwargs, meta) 66