# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Dict

import torch

from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable


# Shared explanation used by every script-object entry point that must not
# graph break.
_SCRIPT_OBJECT_GRAPH_BREAK_REASON = (
    "Dynamo cannot safely trace script object due to graph break."
)


def _raise_hard_error_if_graph_break(reason):
    """Return a decorator that turns ``Unsupported`` into ``UnsafeScriptObjectError``.

    Any ``Unsupported`` exception escaping the decorated function (i.e. what
    would otherwise be a graph break) is re-raised as
    ``UnsafeScriptObjectError``, chained to the original exception.

    Args:
        reason: Human-readable explanation prepended to the message of the
            re-raised ``UnsafeScriptObjectError``.
    """

    def deco(fn):
        @functools.wraps(fn)
        def graph_break_as_hard_error(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Unsupported as e:
                # Include the caller-supplied reason so the hard error says
                # why a graph break is fatal here; keep the original detail.
                raise UnsafeScriptObjectError(f"{reason} {e.msg}") from e

        return graph_break_as_hard_error

    return deco


class TorchScriptObjectVariable(UserDefinedObjectVariable):
    """Variable tracker for ``torch.ScriptObject`` instances.

    Attribute access resolves only to callable methods, which are wrapped in a
    ``call_torchbind`` higher-order-operator variable. Anything that would
    graph break is raised as ``UnsafeScriptObjectError`` instead.
    """

    # Class attribute, so it is shared by all instances. It is not referenced
    # within this class. NOTE(review): presumably populated/read elsewhere;
    # confirm its users.
    _fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """Return True if ``user_cls`` is a subclass of ``torch.ScriptObject``."""
        return issubclass(user_cls, torch.ScriptObject)

    @staticmethod
    def create(proxy, value, **options):
        """Construct a ``TorchScriptObjectVariable``; ``options`` must supply ``source``."""
        return TorchScriptObjectVariable(proxy, value, **options)

    def __init__(self, proxy, value, source, **kwargs) -> None:
        """Wrap ``value`` and link it to the FX ``proxy`` that represents it.

        Args:
            proxy: Graph proxy for the script object; its node's
                ``meta["example_value"]`` is set to ``value``.
            value: The script object being tracked.
            source: Source describing where ``value`` came from; used to build
                ``AttrSource`` entries for method lookups.
            **kwargs: Forwarded to ``UserDefinedObjectVariable.__init__``.
        """
        super().__init__(value, **kwargs)
        self.proxy = proxy
        # Attach the tracked object as the node's example value.
        self.proxy.node.meta["example_value"] = value
        self.source = source

    def as_proxy(self):
        """Return the FX proxy associated with this script object."""
        return self.proxy

    @_raise_hard_error_if_graph_break(_SCRIPT_OBJECT_GRAPH_BREAK_REASON)
    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Resolve ``obj.<name>`` to a traceable ``call_torchbind`` wrapper.

        Only callable attributes (methods) are supported. If ``name`` is absent
        from ``self.value``, or present but not callable, ``unimplemented`` is
        called. That is assumed to raise ``Unsupported``, which the decorator
        converts into ``UnsafeScriptObjectError``.

        Returns:
            A ``TorchHigherOrderOperatorVariable`` for ``call_torchbind``,
            bound to this script object and ``name``.
        """
        from torch._higher_order_ops.torchbind import call_torchbind

        from ..source import AttrSource
        from .higher_order_ops import TorchHigherOrderOperatorVariable

        # A private sentinel distinguishes "attribute missing" from
        # "attribute exists but is None" (the latter is a non-callable
        # attribute, not a missing method).
        missing = object()
        method = getattr(self.value, name, missing)
        if method is missing:
            unimplemented(
                f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
            )

        if not callable(method):
            unimplemented(
                "Only method calls on TorchScript objects can be supported safely."
                " Please use method calls instead of attribute access."
            )

        return TorchHigherOrderOperatorVariable.make(
            call_torchbind,
            source=AttrSource(self.source, name),
            script_obj_var=self,
            method_name=name,
        )

    # We only support method calls on script objects. Interpreting the bytecodes
    # should go through var_getattr then call_function instead of call_method.
    #
    # However, it's possible for call_method to be used directly e.g. for __setattr__.
    @_raise_hard_error_if_graph_break(_SCRIPT_OBJECT_GRAPH_BREAK_REASON)
    def call_method(self, tx, name, args, kwargs):
        """Reject every direct method call on the script object.

        ``unimplemented`` is assumed to raise ``Unsupported``; the decorator
        turns that into ``UnsafeScriptObjectError``.
        """
        unimplemented(f"call method {name} on script object is not safe.")