1# mypy: allow-untyped-defs 2import logging 3from contextlib import contextmanager 4 5import torch 6from torch._C import DispatchKey # @manual 7from torch._functorch._aot_autograd.utils import KNOWN_TYPES 8from torch._higher_order_ops.utils import autograd_not_implemented 9from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject 10from torch._ops import HigherOrderOperator 11from torch._subclasses.fake_tensor import FakeTensorMode 12from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree 13from torch.fx.node import has_side_effect 14from torch.utils import _pytree as pytree 15 16 17log = logging.getLogger(__name__) 18 19 20# The call_torchbind operator represents a method invocation on a torchbind 21# object. The calling convention is: 22# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) 23# We do not expect users to write this operator directly. Instead it will be 24# emitted by Dynamo when tracing encounters a torchbind object. 25class CallTorchBind(HigherOrderOperator): 26 def __init__(self): 27 super().__init__("call_torchbind") 28 29 def __call__(self, obj, method, *args, **kwargs): 30 return super().__call__(obj, method, *args, **kwargs) 31 32 33call_torchbind = CallTorchBind() 34 35# Register this operator as side-effectful with FX. 36# TODO: this is not really sufficient. While passes (hopefully) check 37# Node.is_impure() and make good decisions, we also assume we can execute the 38# graph as many times as we want without changing behavior, which is NOT true of 39# ops that mutate torchbind object state. 
has_side_effect(call_torchbind)

# Handle on the unpatched ScriptMethod.__call__, captured at import time so it
# can be called directly and restored after tracing.
_orig_scriptmethod_call = torch.ScriptMethod.__call__


def torchbind_method_redispatch(self, *args, **kwargs):
    """Replacement for ``torch.ScriptMethod.__call__`` used while tracing.

    When the method's ``raw_owner`` is a ``torch.ScriptObject`` the call is
    routed through ``call_torchbind(raw_owner, name, *args, **kwargs)``;
    otherwise the original ``ScriptMethod.__call__`` is invoked.
    """
    if isinstance(self.raw_owner, torch.ScriptObject):
        return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
    return _orig_scriptmethod_call(self, *args, **kwargs)


@contextmanager
def enable_torchbind_tracing():
    """Context manager that acts as a feature flag to enable torchbind tracing
    behavior. Once torchbind tracing has been stabilized, we can remove this and
    turn it always on.

    On entry, ``torch.ScriptObject`` is appended to ``KNOWN_TYPES`` and
    ``torch.ScriptMethod.__call__`` is replaced by
    ``torchbind_method_redispatch``. On exit both changes are undone.

    Raises:
        AssertionError: if the entry popped from ``KNOWN_TYPES`` on exit is
            not ``torch.ScriptObject``.
    """
    try:
        KNOWN_TYPES.append(torch.ScriptObject)
        torch.ScriptMethod.__call__ = torchbind_method_redispatch  # type: ignore[method-assign]
        yield
    finally:
        # Restore the original __call__ first so the patch is always removed,
        # even when the consistency check below fails.
        torch.ScriptMethod.__call__ = _orig_scriptmethod_call  # type: ignore[method-assign]
        # Pop outside of the assert: under `python -O` asserts are stripped,
        # and the pop must still happen to undo the append above.
        popped = KNOWN_TYPES.pop()
        assert (
            popped is torch.ScriptObject
        ), "Someone else messed with KNOWN_TYPES during tracing, exploding."


@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
    """CompositeExplicitAutograd implementation of ``call_torchbind``.

    Args:
        obj: a ``torch.ScriptObject`` or a ``FakeScriptObject``.
        method: name of the method to look up on ``obj``.
        *args, **kwargs: forwarded to the method.

    Raises:
        RuntimeError: if ``obj`` is neither of the supported types.
    """
    if isinstance(obj, torch.ScriptObject):
        # Go through the saved original __call__ rather than calling the
        # method directly, since ScriptMethod.__call__ may currently be
        # patched to redispatch back into call_torchbind.
        return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
    elif isinstance(obj, FakeScriptObject):
        return getattr(obj.wrapped_obj, method)(*args, **kwargs)
    else:
        raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")


@call_torchbind.py_impl(ProxyTorchDispatchMode)
def inner(mode, *args, **kwargs):
    """ProxyTorchDispatchMode implementation of ``call_torchbind``.

    Records a ``call_function`` node for ``call_torchbind`` on
    ``mode.tracer``, runs the operator to obtain a concrete output, and
    associates that output with the new proxy via ``track_tensor_tree``.
    """
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        call_torchbind,
        proxy_args,
        proxy_kwargs,
    )
    out = call_torchbind(*args, **kwargs)

    obj, method, *_ = args
    if isinstance(obj, torch.ScriptObject):
        ns, class_name = _ns_and_class_name(
            obj._type().qualified_name()  # type: ignore[attr-defined]
        )
        log.warning(
            "Tracing torchbind method %s.%s with real ScriptObject. This may"
            " cause the original object being mutated. If this is not intended,"
            ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
            class_name,
            method,
            ns,
            class_name,
        )

    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    # If track_tensor_tree did not populate "val" on the node, the output is
    # expected to be a plain constant; record it directly as the node's value.
    if "val" not in out_proxy.node.meta:
        assert out is None or isinstance(
            out, (int, float, bool)
        ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
        out_proxy.node.meta["val"] = out
    return ret


# When tracing with fake script object, the call_torchbind op will return a fake tensor
# When tracing with real script object, the call_torchbind op may return a real tensor,
# we need to convert it to fake tensor manually. Dynamic shape is supported.
# NOTE(review): the conversion below passes static_shapes=True, which looks at
# odds with the "dynamic shape is supported" statement above; confirm intent.
@call_torchbind.py_impl(FakeTensorMode)
def call_torchbind_fake(mode, *args, **kwargs):
    """FakeTensorMode implementation of ``call_torchbind``.

    Runs ``call_torchbind_impl`` under ``mode`` and converts every tensor in
    the result that is not already a ``FakeTensor`` via ``mode.from_tensor``.
    """
    with mode:
        out = call_torchbind_impl(*args, **kwargs)
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: mode.from_tensor(x, static_shapes=True)
            if not isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            else x,
            out,
        )


call_torchbind.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(call_torchbind, deferred_error=True)
)


@call_torchbind.py_functionalize_impl
def call_torchbind_func(ctx, *args, **kwargs):
    """Functionalization implementation of ``call_torchbind``.

    Delegates to ``handle_effects`` with the token state held on ``ctx.mode``.
    """
    # Imported locally, as in the original code (presumably to avoid an
    # import cycle with the effects module; verify before hoisting).
    from torch._higher_order_ops.effects import handle_effects

    return handle_effects(
        ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs
    )