1# mypy: allow-untyped-defs 2from contextlib import contextmanager 3 4import torch 5import torch._custom_ops 6from torch._C import DispatchKey 7from torch._higher_order_ops.strict_mode import strict_mode 8from torch._higher_order_ops.utils import autograd_not_implemented 9from torch._ops import HigherOrderOperator 10from torch._subclasses.fake_tensor import FakeTensorMode 11from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree 12from torch.utils import _pytree as pytree 13 14 15class ExportTracepoint(HigherOrderOperator): 16 def __init__(self): 17 super().__init__("_export_tracepoint") 18 19 def __call__(self, *args, **kwargs): 20 return super().__call__(*args, **kwargs) 21 22 23_export_tracepoint = ExportTracepoint() 24 25 26@_export_tracepoint.py_impl(ProxyTorchDispatchMode) 27def export_tracepoint_dispatch_mode(mode, *args, **kwargs): 28 p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) 29 proxy = mode.tracer.create_proxy( 30 "call_function", _export_tracepoint, p_args, p_kwargs 31 ) 32 return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) 33 34 35@_export_tracepoint.py_impl(FakeTensorMode) 36def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): 37 with mode: 38 return args 39 40 41@_export_tracepoint.py_functionalize_impl 42def export_tracepoint_functional(ctx, *args, **kwargs): 43 unwrapped_args = ctx.unwrap_tensors(args) 44 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 45 46 with ctx.redispatch_to_next(): 47 out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) 48 return ctx.wrap_tensors(out) 49 50 51_export_tracepoint.py_impl(DispatchKey.Autograd)( 52 autograd_not_implemented(_export_tracepoint, deferred_error=True) 53) 54 55 56@_export_tracepoint.py_impl(DispatchKey.CPU) 57def export_tracepoint_cpu(*args, **kwargs): 58 return args 59 60 61def _wrap_submodule(mod, path, module_call_specs): 62 assert isinstance(mod, torch.nn.Module) 63 assert path != "" 64 submodule = mod 65 for name in path.split("."): 66 if not hasattr(submodule, name): 67 raise RuntimeError(f"Couldn't find submodule at path {path}") 68 submodule = getattr(submodule, name) 69 70 def update_module_call_signatures(path, in_spec, out_spec): 71 if path in module_call_specs: 72 assert module_call_specs[path]["in_spec"] == in_spec 73 assert module_call_specs[path]["out_spec"] == out_spec 74 module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} 75 76 def check_flattened(flat_args): 77 for a in flat_args: 78 if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): 79 raise AssertionError( 80 f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" 81 ) 82 83 def pre_hook(module, args, kwargs): 84 flat_args, in_spec = pytree.tree_flatten((args, kwargs)) 85 check_flattened(flat_args) 86 flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) 87 args, kwargs = pytree.tree_unflatten(flat_args, in_spec) 88 return args, kwargs 89 90 def post_hook(module, args, kwargs, res): 91 _, in_spec = pytree.tree_flatten((args, kwargs)) 92 flat_res, out_spec = pytree.tree_flatten(res) 93 check_flattened(flat_res) 94 flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) 95 update_module_call_signatures(path, in_spec, out_spec) 96 return pytree.tree_unflatten(flat_res, out_spec) 97 98 pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) 99 post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) 100 return pre_handle, post_handle 101 102 103@contextmanager 104def _wrap_submodules(f, preserve_signature, module_call_signatures): 105 handles = [] 106 107 try: 108 for path in preserve_signature: 109 handles.extend(_wrap_submodule(f, path, module_call_signatures)) 110 yield 111 finally: 112 for handle in handles: 113 handle.remove() 114 115 116def _mark_strict_experimental(cls): 117 def call(self, *args): 118 return strict_mode(self, args) 119 120 cls.__call__ = call 121 return cls 122