# mypy: allow-untyped-defs
import copy
import dataclasses
import itertools
import os
from typing import Any, Callable, Dict, List

import torch
import torch._lazy as lazy
import torch._lazy.metrics as metrics
from torch import fx
from torch._lazy import computation, debug as lazy_debug
from torch._lazy.tensor_factory_functions import tensor_factory_functions


debug = os.environ.get("debug_extract_compiled_graph") is not None


@dataclasses.dataclass
class GraphInputMatcher:
    """
    The GraphInputMatcher class sets up the graph inputs for future calls after lazy tracing.
    Specifically, those graph inputs corresponding to method parameters should be replaced with the
    arguments for the current call.

    tensor_id_to_arg_idx maps the tensor id to the parameter index.
    graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the
    TS/XLA graph inputs.
    """

    tensor_id_to_arg_idx: Dict[int, int]
    graph_input_tensor_ids: List[int]
    # There are two categories of graph input tensors:
    # Category 1: those whose ids are not found in tensor_id_to_arg_idx. These are
    #   most likely const tensors, and we can get their content from graph_input_ivalues.
    # Category 2: those whose ids are found in tensor_id_to_arg_idx. We should get
    #   the tensors from the method arguments.
    graph_input_ivalues: List[Any]

    # get the real graph input tensors
    def __call__(self, args):
        real_input = []
        for tensor_id, traced_ivalue in zip(
            self.graph_input_tensor_ids, self.graph_input_ivalues
        ):
            arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
            if arg_idx is None:
                inp = traced_ivalue
            else:
                inp = args[arg_idx]
            real_input.append(inp)
        return real_input


class ReturnValueHandler:
    r"""
    When ltc_sync_multi is called on multiple tensors, the compiled graph
    will contain output only for unique tensors - if a tensor appears multiple
    times in the input to _ltc_sync_multi, only the first occurrence matters.

    However, at the Python level we still expect multiple tensors to be returned,
    with duplication, even if the TS graph dedups the output. E.g., for the method:

        def forward(self, a):
            return a, a

    the TS graph captured by LTC will return a single tensor, but the Python method expects 2.

    This class dedups the lazy tensors first to get the indices that will be used
    to duplicate the eager tensors later.
    """

    def __init__(self, lazy_out_list):
        self.index: List[List[int]] = []
        self.total_count = len(lazy_out_list)

        tensor_id_to_idx: Dict[int, int] = {}
        for dup_idx, lazy_tensor in enumerate(lazy_out_list):
            uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
            if uniq_idx is not None:
                self.index[uniq_idx].append(dup_idx)
            else:
                uniq_idx = len(self.index)
                self.index.append([dup_idx])
                tensor_id_to_idx[id(lazy_tensor)] = uniq_idx

    def duplicate_eager_tensors(self, eager_tensor_list):
        duplicated_list = [None] * self.total_count
        assert len(eager_tensor_list) == len(self.index)

        for uniq_idx, eager_tensor in enumerate(eager_tensor_list):
            for dup_idx in self.index[uniq_idx]:
                duplicated_list[dup_idx] = eager_tensor
        return duplicated_list


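# Illustrative sketch (not from the upstream module): ReturnValueHandler dedups
# by object identity (id()), so plain Python objects suffice to demonstrate the
# bookkeeping. For a traced output list [t0, t1, t0], `index` becomes
# [[0, 2], [1]], and the deduped eager results are fanned back out to the
# original arity:
#
#     t0, t1 = object(), object()
#     handler = ReturnValueHandler([t0, t1, t0])
#     assert handler.index == [[0, 2], [1]]
#     assert handler.duplicate_eager_tensors(["e0", "e1"]) == ["e0", "e1", "e0"]

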
102 """ 103 104 def tolazydevice(dev): 105 if isinstance(dev, torch.device): 106 return torch.device("lazy", index=dev.index) 107 return dev 108 109 def hasDeviceArg(args, kwargs): 110 return any( 111 isinstance(arg, torch.device) 112 for arg in itertools.chain(args, kwargs.values()) 113 ) 114 115 for nd in model.graph.nodes: 116 nd.args = tuple(tolazydevice(arg) for arg in nd.args) 117 nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} 118 119 # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return 120 # eager tensors on the default device 121 # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, 122 # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). 123 # To force those tensors on the lazy device, we can not simply override 124 # the device argument since there is no explicit device argument. 125 # What we are doing here is, for the list of covered tensor factory methods 126 # we add a lazy device argument explicity. 127 # 128 # TODO: This solution is no ideal since we may miss some factory methods. In future 129 # when we support lazy mode, this method can be replaced by that. 130 if nd.target in tensor_factory_functions and not hasDeviceArg( 131 nd.args, nd.kwargs 132 ): 133 kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. 134 kwargs["device"] = torch.device("lazy") 135 nd.kwargs = kwargs 136 137 model.recompile() 138 139 140def get_fallback_ops(): 141 fallback_ops = [] 142 for opname in metrics.counter_names(): 143 if "aten::" not in opname: 144 continue 145 val = int(metrics.counter_value(opname)) 146 if val > 0: 147 fallback_ops.append(f"{opname}={val}") 148 149 return fallback_ops 150 151 152def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: 153 """ 154 Optimize an eager model with LTC and returns a wrapper to execute the 155 compiled graph directly without retracing. It depends on other mechanisms 156 like TorchDynamo guards to guarantee the returned wrapper is only called 157 when it's safe. 
158 """ 159 lazy_args = [arg.to(device="lazy") for arg in example_inputs] 160 args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] 161 tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} 162 lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) 163 force_lazy_device(lazy_model) 164 165 # This line executes lazy tracing and enable us extracting compiled graph later 166 metrics.reset() 167 lazy_out = lazy_model(*lazy_args) 168 fallback_ops = get_fallback_ops() 169 metrics.reset() 170 171 if len(fallback_ops) > 0: 172 raise RuntimeError( 173 f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" 174 ) 175 176 if not isinstance(lazy_out, (tuple, list)): 177 lazy_out = (lazy_out,) 178 179 args_and_out = tuple(lazy_args) + tuple(lazy_out) 180 return_value_handler = ReturnValueHandler(args_and_out) 181 if debug: 182 print("Fx code:\n", model.code) 183 print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) 184 185 # TODO: this part is TS backend specific for now and will be generalized to 186 # support XLA 187 ( 188 graph_input_tensor_ids, 189 graph_input_ivalues, 190 ) = computation.get_tensors_ts_device_data_node(args_and_out) 191 assert len(graph_input_tensor_ids) == len(graph_input_ivalues) 192 graph_input_matcher = GraphInputMatcher( 193 tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues 194 ) 195 196 graph_hash = computation.get_graph_hash(args_and_out) 197 198 if debug: 199 print("graph_hash", graph_hash) 200 print(f"args_tensor_ids {args_tensor_ids}") 201 print("tensor ids from device data:", graph_input_tensor_ids) 202 203 # sync the list of output tensors so the computation graph for these 204 # tensors will be cached. Those computation graphs can be retrieved 205 # by graph hash later. 206 lazy.sync_multi(args_and_out, []) 207 208 def optimized_mod(*args): 209 if len(args_and_out) == 0: 210 return () 211 graph_input = graph_input_matcher(args) 212 res = return_value_handler.duplicate_eager_tensors( 213 computation.run_cached_graph(graph_hash, graph_input) 214 ) 215 216 assert len(res) == len(args_and_out) 217 for i, arg in enumerate(args): 218 # only copy those tensors that get inplace updated 219 if arg is not res[i]: 220 arg.copy_(res[i]) 221 222 # skip the args 223 return res[len(args) :] 224 225 return optimized_mod 226