xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/script_object.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import functools
4from typing import Dict
5
6import torch
7
8from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
9from .base import VariableTracker
10from .user_defined import UserDefinedObjectVariable
11
12
13def _raise_hard_error_if_graph_break(reason):
14    def deco(fn):
15        @functools.wraps(fn)
16        def graph_break_as_hard_error(*args, **kwargs):
17            try:
18                return fn(*args, **kwargs)
19            except Unsupported as e:
20                raise UnsafeScriptObjectError(e.msg) from e
21
22        return graph_break_as_hard_error
23
24    return deco
25
26
class TorchScriptObjectVariable(UserDefinedObjectVariable):
    """Dynamo variable tracker for a TorchScript script object.

    Only method calls are supported:

    * Attribute lookup (``var_getattr``) returns a
      ``TorchHigherOrderOperatorVariable`` that wraps ``call_torchbind``,
      bound to this object and the requested method name.
    * Direct ``call_method`` requests are always rejected.

    Both entry points are wrapped by ``_raise_hard_error_if_graph_break``.
    Any ``Unsupported`` they raise therefore surfaces as
    ``UnsafeScriptObjectError``.
    """

    # Class-level dict shared by all instances. None of the methods below
    # read or write it. Presumably it is used elsewhere; verify before
    # changing it.
    _fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """Return whether ``user_cls`` is ``torch.ScriptObject`` or a subclass of it."""
        return issubclass(user_cls, torch.ScriptObject)

    @staticmethod
    def create(proxy, value, **options):
        """Build a ``TorchScriptObjectVariable``.

        ``options`` is forwarded to ``__init__`` as keywords. It must
        therefore include ``source``, and may carry base-class kwargs.
        """
        tracker = TorchScriptObjectVariable(proxy, value, **options)
        return tracker

    def __init__(self, proxy, value, source, **kwargs) -> None:
        super().__init__(value, **kwargs)
        self.proxy = proxy
        # Record the tracked object on the proxy's FX node under
        # "example_value". The value is presumably a fake script object
        # during tracing; confirm against callers.
        node_meta = self.proxy.node.meta
        node_meta["example_value"] = value
        # ``source`` is not passed to the base initializer, so it is assigned
        # here after that initializer has run.
        self.source = source

    def as_proxy(self):
        """Return the FX proxy this variable was created with."""
        return self.proxy

    @_raise_hard_error_if_graph_break(
        "Dynamo cannot safely trace script object due to graph break."
    )
    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Resolve ``name`` on the script object as a bound torchbind method.

        Returns:
            A ``TorchHigherOrderOperatorVariable`` that wraps
            ``call_torchbind``. Its source is ``AttrSource(self.source, name)``
            and its ``script_obj_var`` is this tracker.

        Raises (via ``unimplemented``, escalated by the decorator):
            - when the attribute is missing or is ``None`` on ``self.value``;
            - when the attribute exists but is not callable. Plain attribute
              access is not supported.
        """
        from torch._higher_order_ops.torchbind import call_torchbind

        from ..source import AttrSource
        from .higher_order_ops import TorchHigherOrderOperatorVariable

        # A missing attribute and one whose value is None both come back as
        # None here.
        attr = getattr(self.value, name, None)
        if attr is None:
            unimplemented(
                f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
            )

        if not callable(attr):
            unimplemented(
                "Only method calls on TorchScript objects can be supported safely."
                " Please use method calls instead of attribute access."
            )

        method_source = AttrSource(self.source, name)
        return TorchHigherOrderOperatorVariable.make(
            call_torchbind,
            source=method_source,
            script_obj_var=self,
            method_name=name,
        )

    @_raise_hard_error_if_graph_break(
        "Dynamo cannot safely trace script object due to graph break."
    )
    def call_method(self, tx, name, args, kwargs):
        """Reject every direct method call on a script object.

        Supported usage goes through ``var_getattr`` followed by
        ``call_function`` on the returned variable. This hook can still be
        reached directly, e.g. for ``__setattr__``. The decorator turns the
        ``unimplemented`` call here into an ``UnsafeScriptObjectError``.
        """
        unimplemented(f"call method {name} on script object is not safe.")
84