# mypy: ignore-errors

import functools
from collections import defaultdict
from typing import Dict, List, Optional

import torch
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    format_default_skip_message,
    get_mutation_stack_trace,
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    get_first_incompatible_cudagraph_node,
    num_fw_fixed_arguments,
    output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend


def find_input_mutations(g):
    """Return the set of placeholder indices whose storages are written to by ops in ``g``."""

    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if not hasattr(n.target, "_schema"):
                continue

            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]

        # TODO: error on unrecognized nodes
    return mutated_inputs


def get_device_node_mapping(gm: torch.fx.GraphModule):
    """Map each device appearing in ``gm`` to the first node placed on it."""
    device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_mutation_ignore_cuda_graph_managed_tensor(
    aot_model: torch.fx.GraphModule, num_fixed
) -> Optional[str]:
    """Return a skip message (with stack trace) if any non-fixed input is mutated, else None."""
    mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
    if not mutation_indices:
        return None

    placeholders = get_placeholder_info(aot_model.graph)
    return get_mutation_stack_trace(placeholders, mutation_indices)


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a human-readable reason to skip cudagraphs, or None if it is safe."""
    if not config.cudagraph_backend_support_input_mutation:
        if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
            aot_model, num_fixed
        ):
            return mut_skip

    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if node := get_first_incompatible_cudagraph_node(aot_model):
        return format_default_skip_message(f"incompatible op ({node.name})")

    return None


def get_device_index(gm) -> int:
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm) -> List[Optional[str]]:
    output = output_node(gm)
    assert len(output.args) == 1
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in output.args[0]
    ]
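
# Illustrative sketch (not used by the backend): ``find_input_mutations`` can be
# exercised on any FX graph whose nodes carry tensor metadata in ``meta["val"]``,
# e.g. one produced by ``make_fx``. The helper below is hypothetical, exists only
# as documentation, and is not called anywhere in this module.
def _example_find_input_mutations():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        x.add_(1)  # in-place mutation of placeholder 0
        return x

    gm = make_fx(f)(torch.randn(2))
    # aten.add_.Tensor writes to its first argument, so index 0 is reported.
    return find_input_mutations(gm.graph)  # expected: {0}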

def cudagraphs(dynamo_model, dynamo_inputs):
    """Dynamo backend that wraps AOTAutograd-compiled graphs with CUDA graph trees."""
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            BoxedBool.disable(do_cudagraphs)
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    def backward_cudagraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )

            # See [Backward Generation Handling]
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        return cudagraphs(model, inputs)
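
# Illustrative usage sketch (hypothetical model and shapes; requires a CUDA
# device). Once registered below, the backend is selectable by name through
# torch.compile:
#
#   model = torch.nn.Linear(8, 8).cuda()
#   compiled = torch.compile(model, backend="cudagraphs")
#   out = compiled(torch.randn(4, 8, device="cuda"))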

# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
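
# Illustrative sketch of driving ``cudagraphs_inner`` the way the benchmark
# harnesses do (hypothetical model and shapes; requires a CUDA device). Not
# called anywhere in this module.
def _example_cudagraphs_inner():
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda()
    example_inputs = [torch.randn(8, 16, device="cuda")]
    run = cudagraphs_inner(model, example_inputs, copy_outputs=False)
    for _ in range(3):
        # Each call copies fresh data into the static input buffers and replays
        # the captured CUDA graph; with copy_outputs=False the returned tensors
        # alias the static output buffers and are overwritten on the next replay.
        (out,) = run(torch.randn(8, 16, device="cuda"))
    return out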