# mypy: allow-untyped-defs
import functools
import itertools

import torch

from ..._dynamo.utils import counters
from ..pattern_matcher import Arg, CallFunction, KeywordArg
from .freezing_patterns import register_binary_folding_pattern


aten = torch.ops.aten
prims = torch.ops.prims


def mark_mixed_dtype_conv(conv):
    conv_dtype = conv.meta["val"].dtype
    if conv_dtype not in (torch.float16, torch.bfloat16):
        return

    if not len(conv.users) == 1:
        return

    conv_user = next(iter(conv.users.keys()))
    if not isinstance(conv_user.meta["val"], torch.Tensor):
        return

    if not conv_user.meta["val"].dtype == torch.float32:
        return

    while conv_user.target in _binary_ops:
        if not len(conv_user.users) == 1:
            return

        conv_user = next(iter(conv_user.users.keys()))

    if conv_user.target != prims.convert_element_type.default:
        return

    conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype


def mark_mixed_dtype_allowed_convs(gm):
    """
    Mark convolutions which we will binary fold even with mixed precision constants.
    We constant fold in the higher precision for better accuracy and then recover
    the original precision after.
    """
    for node in gm.graph.find_nodes(
        op="call_function", target=aten.convolution.default
    ):
        mark_mixed_dtype_conv(node)


def recover_original_precision_folded_convs(gm):
    """
    After binary folding conv weights and biases to a higher dtype, recover the
    original precision they were in.
    """
    graph = gm.graph
    for node in graph.find_nodes(op="call_function", target=aten.convolution.default):
        orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None)
        if orig_dtype is None:
            continue

        with graph.inserting_before(node):
            for idx in [1, 2]:
                old_input = node.args[idx]
                if old_input is None:
                    continue

                new_input = graph.create_node(
                    "call_function",
                    prims.convert_element_type.default,
                    (old_input, orig_dtype),
                )
                node.replace_input_with(old_input, new_input)


_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor]


@functools.lru_cache(None)
def binary_folding_init():
    _conv_args = [Arg() for _ in range(9)]
    _computation_ops = [aten.convolution.default]
    _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)]

    """
    In order to fuse add/sub/mul/div with conv, the dimensions of its
    constant tensor must satisfy the following:
    - with resizing, broadcast with the weight/bias tensor shape
    - broadcast to the conv output shape
    It needs to have a shape that can resize to the weight/bias
    tensor shape because we need to run the op with the conv
    weights/bias without changing their sizes.
    It needs to broadcast to the conv output shape so that we do not
    accidentally change the shape of the op output by pre-fusing it
    compared to eager.
    The only dimension value shared by weight/bias/conv output
    is that they all contain a dim with value == channels-out. In the
    conv output tensor, this is in the second dimension,
    so the pointwise op tensor may have a second dimension of
    value == channels-out, but all the other dimensions have to be 1.
    """
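
    # A worked example of the conditions above (shapes are illustrative): for a
    # conv weight of shape [32, 16, 3, 3] (channels-out == 32), a constant of
    # shape [1, 32, 1, 1] (or [32, 1, 1]) reshapes onto the weight/bias and
    # broadcasts against the conv output [N, 32, H, W], so it can be folded.
    # A constant of shape [1, 16, 1, 1] lines up with the conv *input* channels
    # instead and is rejected by _op_not_broadcasting_with_conv below.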

    def _op_not_broadcasting_with_conv(weight_tensor, other_tensor):
        # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp
        weight_shape = weight_tensor.shape
        other_shape = other_tensor.shape
        if len(weight_shape) < len(other_shape):
            return False
        if len(weight_shape) == len(other_shape) + 1:
            # weight shape is [o, i, *], other_shape is [o, 1...]
            for i in reversed(range(len(other_shape))):
                if i == 0 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        else:
            # weight shape is [o, i, *], other_shape is [1, o, 1...]
            for i in reversed(range(len(other_shape))):
                if i == 1 and weight_shape[0] == other_shape[i]:
                    continue
                if other_shape[i] != 1:
                    return False
        return True

    def _check_conv_and_broadcast_op(conv_node, other):
        # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp.
        # conv.weight
        if conv_node.args[1].op != "get_attr":
            return False
        # conv.bias
        if conv_node.args[2] is not None and conv_node.args[2].op != "get_attr":
            return False
        if (
            not isinstance(other, int)
            and not isinstance(other, float)
            and other.op != "get_attr"
        ):
            return False

        if not len(conv_node.args[1].users) == 1:
            return False

        weight_meta_value = conv_node.args[1].meta.get("val")
        if weight_meta_value is None:
            return False
        # Avoid fusing ops that cause type promotion;
        # restricting to float avoids int/float difficulties with the scalar overload
        if not weight_meta_value.is_floating_point():
            return False
        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
            other_meta_value = other.meta.get("val")
            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
                return False
            if (
                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
                != weight_meta_value.dtype
            ):
                if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
                    return False

                if (
                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
                ):
                    return False

            if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value):
                return False
        else:
            # TODO: support scalar case
            return False

        return True

    def _is_foldable_pattern(match):
        binary_node = match.output_node()
        computation_node = binary_node.args[0]
        other = binary_node.args[1]
        if binary_node.args[0].target not in _computation_ops:
            computation_node = binary_node.args[1]
            other = binary_node.args[0]
        if computation_node.target == aten.convolution.default:
            return _check_conv_and_broadcast_op(computation_node, other)

        return False

    def resize_scalar_or_tensor_to_shape(graph, other, shape):
        # TODO: support scalar case
        if other.meta.get("val").numel() == 1:
            # expand errors if the shape input has fewer dims than the tensor input
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, (1,)),
            )
            res = graph.create_node(
                "call_function",
                aten.expand.default,
                (res, shape),
            )
        else:
            res = graph.create_node(
                "call_function",
                aten.reshape.default,
                (other, shape),
            )
        return res
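
    # The folding below relies on a per-output-channel identity, shown here as
    # an eager-mode sketch (tensor names are illustrative, not part of this
    # module):
    #
    #   import torch
    #   import torch.nn.functional as F
    #
    #   x = torch.randn(1, 3, 8, 8)
    #   w = torch.randn(4, 3, 3, 3)
    #   b = torch.randn(4)
    #   s = torch.randn(4)  # per-output-channel constant
    #   # mul after conv == conv with scaled weight and bias
    #   lhs = F.conv2d(x, w, b) * s.view(1, 4, 1, 1)
    #   rhs = F.conv2d(x, w * s.view(4, 1, 1, 1), b * s)
    #   torch.testing.assert_close(lhs, rhs)
    #
    # div folds the same way into weight and bias; add/sub fold into the bias
    # only: conv(x, w, b) + c == conv(x, w, b + c).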

    def _create_new_conv_node(graph, conv_node, binary_node, other):
        assert conv_node.target == aten.convolution.default
        conv_args = list(conv_node.args)
        weight_meta_value = conv_node.args[1].meta.get("val")
        bias = conv_args[2]
        if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]:
            other_reshape = resize_scalar_or_tensor_to_shape(
                graph, other, (weight_meta_value.size(0),)
            )
            new_bias = graph.create_node(
                "call_function",
                binary_node.target,
                (0 if bias is None else bias, other_reshape),
            )
            conv_args[2] = new_bias
        else:
            assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor]
            weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))]
            weight_broadcast_shape[0] = weight_meta_value.size(0)
            other_reshape1 = resize_scalar_or_tensor_to_shape(
                graph, other, tuple(weight_broadcast_shape)
            )
            new_weight = graph.create_node(
                "call_function", binary_node.target, (conv_args[1], other_reshape1)
            )
            new_weight.meta.update(conv_args[1].meta)
            conv_args[1] = new_weight
            if bias is not None:
                other_reshape = resize_scalar_or_tensor_to_shape(
                    graph, other, (weight_meta_value.size(0),)
                )
                new_bias = graph.create_node(
                    "call_function", binary_node.target, (bias, other_reshape)
                )
                new_bias.meta.update(bias.meta)
                conv_args[2] = new_bias
        return graph.create_node("call_function", conv_node.target, tuple(conv_args))

    for _computation_call, binary_op in itertools.product(
        _computation_calls, _binary_ops
    ):

        @register_binary_folding_pattern(
            CallFunction(binary_op, _computation_call, KeywordArg("other")),
            extra_check=_is_foldable_pattern,
        )
        def folded_op(match, *args, **kwargs):
            counters["inductor"]["binary_folding"] += 1
            other = kwargs.get("other")
            binary_node = match.output_node()
            computation_node = (
                binary_node.args[0]
                if binary_node.args[0].target in _computation_ops
                else binary_node.args[1]
            )
            graph = match.graph
            with graph.inserting_before(binary_node):
                # TODO: support linear?
                assert computation_node.target == aten.convolution.default
                new_computation_node = _create_new_conv_node(
                    graph, computation_node, binary_node, other
                )
                binary_node.replace_all_uses_with(new_computation_node)
                new_computation_node.meta.update(computation_node.meta)
                graph.erase_node(binary_node)
                graph.erase_node(computation_node)
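

# A minimal end-to-end sketch of when these patterns fire (hypothetical driver
# code, not part of this module; the real entry point is the inductor freezing
# pass, which registers and applies the patterns above):
#
#   import torch
#   import torch._inductor.config as inductor_config
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = torch.nn.Conv2d(3, 4, 3)
#           self.scale = torch.nn.Parameter(torch.randn(1, 4, 1, 1))
#
#       def forward(self, x):
#           return self.conv(x) * self.scale
#
#   inductor_config.freezing = True
#   with torch.no_grad():
#       opt = torch.compile(M().eval())
#       opt(torch.randn(1, 3, 8, 8))  # the multiply folds into the conv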