# mypy: allow-untyped-defs
from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree


class ExportTracepoint(HigherOrderOperator):
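    """Higher-order operator that marks a point in an exported graph.

    At runtime it is an identity on its inputs; during tracing it leaves a
    ``_export_tracepoint`` node in the graph, which is used below to preserve
    module call signatures.
    """
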
    def __init__(self):
        super().__init__("_export_tracepoint")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


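# Single module-level instance; all tracepoint calls dispatch through it.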
_export_tracepoint = ExportTracepoint()


@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
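    """Emit a ``_export_tracepoint`` call_function node into the traced graph."""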
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)


@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
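    """Fake-tensor implementation: the tracepoint passes its inputs through."""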
    with mode:
        return args


@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
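    """Functionalization: unwrap the tensors, redispatch, and rewrap the result."""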
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)


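# Backward is not supported for the tracepoint; with deferred_error=True the
# error is raised only if autograd actually reaches it.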
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)


@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
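    """Runtime (CPU) implementation: a plain identity on the inputs."""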
    return args


def _wrap_submodule(mod, path, module_call_specs):
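    """Wrap the submodule of ``mod`` at dotted ``path`` with tracepoint hooks.

    A forward pre-hook flattens each call's ``(args, kwargs)`` and routes the
    leaves through ``_export_tracepoint``; a forward hook does the same for the
    outputs and records the call's in/out pytree specs in ``module_call_specs``.
    Returns the two hook handles so the caller can remove them.
    """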
    assert isinstance(mod, torch.nn.Module)
    assert path != ""
    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

    def update_module_call_signatures(path, in_spec, out_spec):
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    def check_flattened(flat_args):
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )

    def pre_hook(module, args, kwargs):
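        """Mark the flattened module inputs in the traced graph."""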
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        return args, kwargs

    def post_hook(module, args, kwargs, res):
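        """Mark the flattened outputs and record this call's signature."""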
        _, in_spec = pytree.tree_flatten((args, kwargs))
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

    pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
    post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
    return pre_handle, post_handle


@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
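    """Install tracepoint hooks on every submodule of ``f`` listed in
    ``preserve_signature``; the hooks are removed on exit even if tracing fails.
    """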
    handles = []

    try:
        for path in preserve_signature:
            handles.extend(_wrap_submodule(f, path, module_call_signatures))
        yield
    finally:
        for handle in handles:
            handle.remove()


def _mark_strict_experimental(cls):
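    """Class decorator (experimental) that reroutes ``__call__`` through the
    ``strict_mode`` higher-order op.
    """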
    def call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = call
    return cls