xref: /aosp_15_r20/external/pytorch/torch/_prims/executor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Callable, Optional
3
4from torch._prims.context import TorchRefsMode
5from torch.fx import GraphModule
6from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
7
8
def execute(
    gm: GraphModule,
    *args,
    executor: str = "aten",
    executor_parameters: Optional[dict] = None,
):
    """
    Prototype ATen executor.

    Runs the already-traced graph ``gm`` on the positional ``*args``.
    Only the ``"aten"`` executor is currently supported; any other value
    raises ``ValueError``.  ``executor_parameters`` is accepted for
    interface compatibility but is not consulted by the aten path.
    """
    # Guard clause: reject unknown executors up front, then run the graph.
    if executor != "aten":
        msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
        raise ValueError(msg)

    return gm.forward(*args)
26
27
def make_traced(fn: Callable):
    """
    Wrap ``fn`` so that calling the wrapper traces its torch operations
    to prims and then executes those prims on the requested trace
    executor (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must
    be tensors.
    In the near future all these restrictions will be lifted.

    Example usage:

    def foo(a, b):
      return torch.add(a, b)

    traced_foo = make_traced(foo)

    a = torch.randn((1, 2, 3, 4, 5), device='cuda')
    b = torch.randn((1, 2, 3, 4, 5), device='cuda')
    result = traced_foo(a, b, executor='aten')
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        # Flatten (args, kwargs) into a single positional list so make_fx
        # can trace a function of one argument.
        wrapped_fn, flat_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        # Trace under TorchRefsMode so torch ops are rewritten to refs/prims.
        with TorchRefsMode():
            graph_module = make_fx(wrapped_fn)(flat_args)

        return execute(graph_module, flat_args, executor=executor)

    return _traced
61