xref: /aosp_15_r20/external/executorch/backends/cadence/aot/graph_builder.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict
4
5import logging
6from typing import Optional, Sequence, Union
7
8import torch
9from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
10from torch._subclasses import FakeTensor, FakeTensorMode
11from torch.fx.node import Argument, Target
12from torch.utils import _pytree as pytree
13
14
class GraphBuilder(ExportPass):
    """Utility class for creating a graph module with user-specified ops.

    This class allows us to create test graph modules with any ops we want
    directly, rather than relying on decomposition or passes.

    Usage:
        builder = GraphBuilder()
        # To insert placeholders, use builder.placeholder.
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        # To insert an op, use builder.call_operator.
        op = builder.call_operator(
            some_op,
            (x, other_args, ...),
        )
        # Insert outputs as a list of ProxyValues using builder.output.
        builder.output([op])
        # Get GraphModule from builder.
        gm = builder.get_graph_module()
    """

    def __init__(self) -> None:
        # NOTE(review): ExportPass.__init__ is not called; this constructor
        # sets up the tracer, fake tensor mode and interpreter itself.
        # `exporter` is not used anywhere in this class -- presumably kept for
        # external callers; confirm before removing.
        self.exporter = ExportPass()
        self.tracer: ExportPass.ExportTracer = self.ExportTracer(
            self, torch.fx.graph.CodeGen()
        )
        # Real tensors handed to `placeholder` are converted into this mode.
        # Fallback kernels are disabled, and the tracer shares the same mode.
        self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        self.tracer.fake_tensor_mode = self.fake_tensor_mode

        # This will be called to create nodes in tracer.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )

    # pyre-ignore[14]: Inconsistent override.
    def placeholder(
        self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
    ) -> ProxyValue:
        """Add a graph input named ``target``.

        Args:
            target: Name of the placeholder.
            fake_tensor: Example value for the input. A tensor that is not a
                FakeTensor is first converted with this builder's
                ``fake_tensor_mode``. A FakeTensor is used as-is. NOTE(review):
                a FakeTensor from a different FakeTensorMode is not converted --
                confirm callers never mix modes.

        Returns:
            The ProxyValue produced for the new placeholder.
        """
        if not isinstance(fake_tensor, FakeTensor):
            fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logging.info("Creating placeholder %s => %s", target, fake_tensor.shape)
        return super().placeholder(target, fake_tensor, NodeMetadata({}))

    # pyre-ignore[14]: Inconsistent override.
    def output(self, results: list[ProxyValue]) -> ProxyValue:
        """Add the graph's output node, returning ``results``.

        Args:
            results: ProxyValues (e.g. from ``call_operator``) to be returned
                by the graph.

        Returns:
            The ProxyValue produced for the output node.
        """
        logging.info("Creating outputs %s", results)
        return super().output(results, NodeMetadata({}))

    def get_graph_module(self) -> torch.fx.GraphModule:
        """Return a new GraphModule wrapping the tracer's root and current graph."""
        return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: Optional[dict[str, Argument]] = None,
        meta: Optional[NodeMetadata] = None,
    ) -> ProxyValue:
        """Insert a call to ``op`` into the graph.

        Args:
            op: The operator to call.
            args: Positional arguments; may contain ProxyValues from earlier
                ``placeholder``/``call_operator`` calls.
            kwargs: Keyword arguments; ``None`` is treated as ``{}``.
            meta: Node metadata; ``None`` is treated as an empty NodeMetadata.

        Returns:
            The ProxyValue produced for the op's result.
        """
        if meta is None:
            meta = NodeMetadata({})
        if kwargs is None:
            kwargs = {}
        return super().call_operator(op, args, kwargs, meta)
79
80
def single_op_builder(
    placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
    op: Target,
    args: Sequence[Argument],
    kwargs: Optional[dict[str, Argument]] = None,
) -> torch.fx.GraphModule:
    """Create a graph module with a single op.

    Each tensor in ``placeholders`` becomes a graph input named ``p_<index>``.
    Tensors inside ``args``/``kwargs`` are replaced by the placeholder created
    for the same tensor object. This also applies inside nested lists, tuples
    or dicts. The op's result is the graph's only output.

    Args:
        placeholders: Placeholders to be used as inputs to the GraphModule.
        op: The op to be inserted.
        args: The positional args to be passed to the op. Any sequence is
            accepted; it is converted to a tuple before the op is inserted.
        kwargs: The kwargs to be passed to the op.

    Returns:
        A graph module with a single op.

    Raises:
        KeyError: If ``args``/``kwargs`` contain a tensor object that is not
            one of ``placeholders``.
    """
    builder = GraphBuilder()
    # torch.Tensor hashes by identity, so only the exact tensor objects given
    # in `placeholders` are matched below.
    tensor_to_placeholder = {
        p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
    }

    def _to_placeholder(tensor: torch.Tensor) -> ProxyValue:
        # Swap a tensor argument for its placeholder, with a descriptive error
        # instead of a bare KeyError carrying the tensor's repr.
        try:
            return tensor_to_placeholder[tensor]
        except KeyError:
            raise KeyError(
                "single_op_builder: a tensor in args/kwargs "
                f"(shape={tuple(tensor.shape)}) is not one of the provided "
                "placeholders; pass it in `placeholders` to make it a graph input."
            ) from None

    proxy_args, proxy_kwargs = pytree.tree_map_only(
        (torch.Tensor, FakeTensor), _to_placeholder, (args, kwargs)
    )
    # `args` is typed as a Sequence, but the FX tracer only accepts tuple args
    # (a list would fail its isinstance(args, tuple) assertion).
    node = builder.call_operator(op, tuple(proxy_args), proxy_kwargs)
    builder.output([node])
    return builder.get_graph_module()
108