# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import logging
from typing import Optional, Sequence, Union

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Argument, Target
from torch.utils import _pytree as pytree


class GraphBuilder(ExportPass):
    """Utility class for creating a graph module with user-specified ops.

    This class allows us to create test graph modules with any ops we want
    directly, rather than relying on decomposition or passes.

    Usage:
        builder = GraphBuilder()
        # To insert placeholders, use builder.placeholder.
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        # To insert an op, use builder.call_operator.
        op = builder.call_operator(
            some_op,
            (x, other_args, ...),
        )
        # Insert outputs as a list of ProxyValues using builder.output.
        builder.output([op])
        # Get the GraphModule from the builder.
        gm = builder.get_graph_module()
    """

    def __init__(self) -> None:
        self.exporter = ExportPass()
        self.tracer: ExportPass.ExportTracer = self.ExportTracer(
            self, torch.fx.graph.CodeGen()
        )
        self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        self.tracer.fake_tensor_mode = self.fake_tensor_mode

        # This will be called to create nodes in the tracer.
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )

    # pyre-ignore[14]: Inconsistent override.
    def placeholder(
        self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
    ) -> ProxyValue:
        if not isinstance(fake_tensor, FakeTensor):
            fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
        logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
        placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
        return placeholder

    # pyre-ignore[14]: Inconsistent override.
    def output(self, results: list[ProxyValue]) -> ProxyValue:
        logging.info(f"Creating outputs {results}")
        return super().output(results, NodeMetadata({}))

    def get_graph_module(self) -> torch.fx.GraphModule:
        return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: tuple[Argument, ...],
        kwargs: Optional[dict[str, Argument]] = None,
        meta: Optional[NodeMetadata] = None,
    ) -> ProxyValue:
        if meta is None:
            meta = NodeMetadata({})
        if kwargs is None:
            kwargs = {}
        return super().call_operator(op, args, kwargs, meta)


def single_op_builder(
    placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
    op: Target,
    args: Sequence[Argument],
    kwargs: Optional[dict[str, Argument]] = None,
) -> torch.fx.GraphModule:
    """Create a graph module with a single op.

    Args:
        placeholders: Placeholders to be used as inputs to the GraphModule.
        op: The op to be inserted.
        args: The args to be passed to the op.
        kwargs: The kwargs to be passed to the op.

    Returns:
        A graph module with a single op.
    """
    builder = GraphBuilder()
    # Map each input tensor to its placeholder proxy, keyed by tensor identity.
    op_to_placeholder_dict = {
        p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
    }
    # Replace every tensor in args/kwargs with its corresponding placeholder.
    proxy_args, proxy_kwargs = pytree.tree_map_only(
        (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
    )
    node = builder.call_operator(op, proxy_args, proxy_kwargs)
    builder.output([node])
    return builder.get_graph_module()
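

# Illustrative usage sketch: build a graph module whose single node calls
# torch.ops.aten.add.Tensor on two placeholder tensors. The choice of
# aten.add.Tensor is an arbitrary example, not something this module requires;
# any op whose tensor arguments also appear in `placeholders` works the same
# way, since single_op_builder swaps those tensors for placeholder proxies by
# identity.
if __name__ == "__main__":
    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    gm = single_op_builder(
        placeholders=(x, y),
        op=torch.ops.aten.add.Tensor,
        args=(x, y),
    )
    # Print the resulting single-op graph for inspection.
    print(gm.graph)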