1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fx.passes"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 4*da0073e9SAndroid Build Coastguard Workerimport operator 5*da0073e9SAndroid Build Coastguard Workerimport logging 6*da0073e9SAndroid Build Coastguard Workerimport sys 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._symbolic_trace import symbolic_trace 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner 12*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.operator_support import OperatorSupport 13*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.fuser_utils import fuse_by_partitions 14*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.utils.matcher_utils import SubgraphMatcher 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests 17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerlogging.basicConfig(level=logging.WARNING) 20*da0073e9SAndroid Build Coastguard Workerlogger = logging.getLogger(__name__) 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerclass TestModule(torch.nn.Module): 23*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 24*da0073e9SAndroid Build Coastguard Worker super().__init__() 25*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 26*da0073e9SAndroid Build Coastguard Worker self.linear2 = torch.nn.Linear(4, 4) 27*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 4)) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, c): 30*da0073e9SAndroid Build Coastguard Worker add = a + b 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add) 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker add_1 = add + c 35*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + self.param 36*da0073e9SAndroid Build Coastguard Worker add_3 = add_1 + linear_1 37*da0073e9SAndroid Build Coastguard Worker add_4 = add_2 + add_3 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker linear_2 = self.linear2(add_4) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker add_5 = linear_2 + add_4 42*da0073e9SAndroid Build Coastguard Worker add_6 = add_5 + a 43*da0073e9SAndroid Build Coastguard Worker relu = add_6.relu() 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker return add_4, add_6, relu 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerclass TestDeepModule(torch.nn.Module): 48*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 49*da0073e9SAndroid Build Coastguard Worker super().__init__() 50*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, c): 53*da0073e9SAndroid Build Coastguard Worker o = a + b 54*da0073e9SAndroid Build Coastguard Worker o = o + 1.0 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker # testing to avoid DFS uses in passes. Since Python has max recursion depth. 57*da0073e9SAndroid Build Coastguard Worker for _ in range(sys.getrecursionlimit() + 1): 58*da0073e9SAndroid Build Coastguard Worker o = o - c 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker return o 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Workerclass TestPartitionFunctions: 64*da0073e9SAndroid Build Coastguard Worker @staticmethod 65*da0073e9SAndroid Build Coastguard Worker def forward1(a, b, c): 66*da0073e9SAndroid Build Coastguard Worker add = a + b 67*da0073e9SAndroid Build Coastguard Worker add_1 = add + b 68*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + c 69*da0073e9SAndroid Build Coastguard Worker relu_1 = add_2.relu() 70*da0073e9SAndroid Build Coastguard Worker add_3 = add_1 + add_2 71*da0073e9SAndroid Build Coastguard Worker add_4 = add_1 + relu_1 + add_3 72*da0073e9SAndroid Build Coastguard Worker relu_2 = add_4.relu() 73*da0073e9SAndroid Build Coastguard Worker add_5 = relu_2 + add_4 74*da0073e9SAndroid Build Coastguard Worker add_6 = add_5 + add_4 75*da0073e9SAndroid Build Coastguard Worker return add_4, add_6 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker @staticmethod 78*da0073e9SAndroid Build Coastguard Worker def forward2(a, b, _): 79*da0073e9SAndroid Build Coastguard Worker add = a + b 80*da0073e9SAndroid Build Coastguard Worker add_1 = add + b 81*da0073e9SAndroid Build Coastguard Worker relu_1 = add_1.relu() # blocked by this 82*da0073e9SAndroid Build Coastguard Worker add_3 = add_1 + relu_1 83*da0073e9SAndroid Build Coastguard Worker add_4 = add_1 + add_3 84*da0073e9SAndroid Build Coastguard Worker return add_4, add_1 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker @staticmethod 87*da0073e9SAndroid Build Coastguard Worker def forward3(a, b, c): 88*da0073e9SAndroid Build Coastguard Worker add = a + b 89*da0073e9SAndroid Build Coastguard Worker add_1 = a + c 90*da0073e9SAndroid Build Coastguard Worker add_2 = b + c 91*da0073e9SAndroid Build Coastguard Worker return add, add_1, add_2 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker @staticmethod 94*da0073e9SAndroid Build Coastguard Worker def forward4(a, b, c): 95*da0073e9SAndroid Build Coastguard Worker add = a + b 96*da0073e9SAndroid Build Coastguard Worker add_1 = a + c 97*da0073e9SAndroid Build Coastguard Worker add_2 = b + c 98*da0073e9SAndroid Build Coastguard Worker return torch.where(add > 0, add_1, add_2) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker @staticmethod 101*da0073e9SAndroid Build Coastguard Worker def forward5(a, b, c): 102*da0073e9SAndroid Build Coastguard Worker # add should be fused right branch, as left branch is not supported 103*da0073e9SAndroid Build Coastguard Worker add = a + 1 104*da0073e9SAndroid Build Coastguard Worker # left branch 105*da0073e9SAndroid Build Coastguard Worker relu = add.relu() 106*da0073e9SAndroid Build Coastguard Worker # right branch 107*da0073e9SAndroid Build Coastguard Worker add_1 = add + 2 108*da0073e9SAndroid Build Coastguard Worker return relu, add_1 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker @staticmethod 111*da0073e9SAndroid Build Coastguard Worker def forward6(a, b, c): 112*da0073e9SAndroid Build Coastguard Worker # add should have its own partition, as neither branchs are supported 113*da0073e9SAndroid Build Coastguard Worker add = a + 1 114*da0073e9SAndroid Build Coastguard Worker # left branch 115*da0073e9SAndroid Build Coastguard Worker relu = add.relu() 116*da0073e9SAndroid Build Coastguard Worker # right branch 117*da0073e9SAndroid Build Coastguard Worker relu_1 = add.relu() 118*da0073e9SAndroid Build Coastguard Worker return relu, relu_1 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker @staticmethod 121*da0073e9SAndroid Build Coastguard Worker def forward7(a, b, c): 122*da0073e9SAndroid Build Coastguard Worker # both branches are supported, all adds should be fused together 123*da0073e9SAndroid Build Coastguard Worker add = a + 1 124*da0073e9SAndroid Build Coastguard Worker # left branch 125*da0073e9SAndroid Build Coastguard Worker add_1 = add + 2 126*da0073e9SAndroid Build Coastguard Worker # right branch is larger 127*da0073e9SAndroid Build Coastguard Worker add_2 = add + 1 128*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + 1 129*da0073e9SAndroid Build Coastguard Worker return add_3, add_1 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker @staticmethod 132*da0073e9SAndroid Build Coastguard Worker def forward8(a, b, c): 133*da0073e9SAndroid Build Coastguard Worker # both branches are in the same partition, add should join the same partition 134*da0073e9SAndroid Build Coastguard Worker add = a + 1 135*da0073e9SAndroid Build Coastguard Worker # left branch 136*da0073e9SAndroid Build Coastguard Worker add_1 = add + 2 137*da0073e9SAndroid Build Coastguard Worker # right branch 138*da0073e9SAndroid Build Coastguard Worker add_2 = add + 1 139*da0073e9SAndroid Build Coastguard Worker # left and right branch merges 140*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + add_1 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker return add_3 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @staticmethod 145*da0073e9SAndroid Build Coastguard Worker def forward9(a, b, c): 146*da0073e9SAndroid Build Coastguard Worker add = a + 1 147*da0073e9SAndroid Build Coastguard Worker # branch 1 148*da0073e9SAndroid Build Coastguard Worker add_1 = add + 1 149*da0073e9SAndroid Build Coastguard Worker # branch 2 150*da0073e9SAndroid Build Coastguard Worker add_2 = add + 1 151*da0073e9SAndroid Build Coastguard Worker # branch_3 152*da0073e9SAndroid Build Coastguard Worker add_3 = add + 1 153*da0073e9SAndroid Build Coastguard Worker out = torch.stack([add_1, add_2, add_3]) 154*da0073e9SAndroid Build Coastguard Worker return out 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker @staticmethod 157*da0073e9SAndroid Build Coastguard Worker def forward10(a, b, c): 158*da0073e9SAndroid Build Coastguard Worker add = a + 1 159*da0073e9SAndroid Build Coastguard Worker # branch 1 160*da0073e9SAndroid Build Coastguard Worker add_1 = add + 1 161*da0073e9SAndroid Build Coastguard Worker # branch 2 162*da0073e9SAndroid Build Coastguard Worker add_2 = add + 1 163*da0073e9SAndroid Build Coastguard Worker # branch 3: depends on branch 2 164*da0073e9SAndroid Build Coastguard Worker add_3 = add + add_2 165*da0073e9SAndroid Build Coastguard Worker out = torch.stack([add_1, add_2, add_3]) 166*da0073e9SAndroid Build Coastguard Worker return out 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker @staticmethod 169*da0073e9SAndroid Build Coastguard Worker def forward11(a, b, c): 170*da0073e9SAndroid Build Coastguard Worker add = a + 1 171*da0073e9SAndroid Build Coastguard Worker # branch 1 172*da0073e9SAndroid Build Coastguard Worker add_1 = add.relu() 173*da0073e9SAndroid Build Coastguard Worker # branch 2 depends on branch 1 174*da0073e9SAndroid Build Coastguard Worker add_2 = add + add_1 175*da0073e9SAndroid Build Coastguard Worker # branch 3 176*da0073e9SAndroid Build Coastguard Worker add_3 = add.relu() 177*da0073e9SAndroid Build Coastguard Worker out = torch.stack([add_1, add_2, add_3]) 178*da0073e9SAndroid Build Coastguard Worker return out 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker @staticmethod 181*da0073e9SAndroid Build Coastguard Worker def forward12(a, b, c): 182*da0073e9SAndroid Build Coastguard Worker b0 = a + 1.0 183*da0073e9SAndroid Build Coastguard Worker c0 = a + 1.5 184*da0073e9SAndroid Build Coastguard Worker x0 = b0.relu() 185*da0073e9SAndroid Build Coastguard Worker x1 = c0.relu() 186*da0073e9SAndroid Build Coastguard Worker b1 = b0 + x1 187*da0073e9SAndroid Build Coastguard Worker c1 = c0 + 1.2 188*da0073e9SAndroid Build Coastguard Worker # c2 has dependency on x0 & b0, when we merge {c0, c1, c2} 189*da0073e9SAndroid Build Coastguard Worker # this dependency should be updated to the fusion group and reflected 190*da0073e9SAndroid Build Coastguard Worker # on the decision to not fuse b0 & b1, which forms a cyclic dependency in 191*da0073e9SAndroid Build Coastguard Worker # the new graph 192*da0073e9SAndroid Build Coastguard Worker c2 = x0 + c0 193*da0073e9SAndroid Build Coastguard Worker return b1, c2 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker @staticmethod 196*da0073e9SAndroid Build Coastguard Worker def forward13(a, b, c): 197*da0073e9SAndroid Build Coastguard Worker a0, a1, a2, a3 = a.split(1, 0) 198*da0073e9SAndroid Build Coastguard Worker b1 = a0 + b 199*da0073e9SAndroid Build Coastguard Worker c1 = a1 + c 200*da0073e9SAndroid Build Coastguard Worker return b1 + c1 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker @staticmethod 203*da0073e9SAndroid Build Coastguard Worker def forward14(a, b, c): 204*da0073e9SAndroid Build Coastguard Worker a0, a1 = torch.ops.aten.std_mean(a) 205*da0073e9SAndroid Build Coastguard Worker out = a0 + 1.0 206*da0073e9SAndroid Build Coastguard Worker return out 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker @staticmethod 209*da0073e9SAndroid Build Coastguard Worker def forward15(a, b, c): 210*da0073e9SAndroid Build Coastguard Worker a0 = torch.ops.aten.view(a, [2, 2]) 211*da0073e9SAndroid Build Coastguard Worker a1 = torch.ops.aten.permute(a0, [1, 0]) 212*da0073e9SAndroid Build Coastguard Worker a2 = a1 + 1.0 213*da0073e9SAndroid Build Coastguard Worker a3 = torch.ops.aten.permute(a2, [1, 0]) 214*da0073e9SAndroid Build Coastguard Worker a4 = a3 + 1.0 215*da0073e9SAndroid Build Coastguard Worker a5 = torch.ops.aten.permute(a4, [1, 0]) 216*da0073e9SAndroid Build Coastguard Worker return torch.ops.aten.permute(a5, [1, 0]) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker @staticmethod 219*da0073e9SAndroid Build Coastguard Worker def forward16(a, b, c): 220*da0073e9SAndroid Build Coastguard Worker a0 = a - 1.0 221*da0073e9SAndroid Build Coastguard Worker a1 = torch.ops.aten.view(a0, [2, 2]) 222*da0073e9SAndroid Build Coastguard Worker a2 = torch.ops.aten.permute(a1, [1, 0]) 223*da0073e9SAndroid Build Coastguard Worker a3 = a2 + 1.0 224*da0073e9SAndroid Build Coastguard Worker a4 = torch.ops.aten.permute(a3, [1, 0]) 225*da0073e9SAndroid Build Coastguard Worker a5 = a4 + 1.0 226*da0073e9SAndroid Build Coastguard Worker a6 = torch.ops.aten.permute(a5, [1, 0]) 227*da0073e9SAndroid Build Coastguard Worker a7 = torch.ops.aten.permute(a6, [1, 0]) 228*da0073e9SAndroid Build Coastguard Worker return a7 - 1.0 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker @staticmethod 231*da0073e9SAndroid Build Coastguard Worker def forward17(a, b, c, d, e, f): 232*da0073e9SAndroid Build Coastguard Worker a0 = a + b 233*da0073e9SAndroid Build Coastguard Worker a1 = c + d 234*da0073e9SAndroid Build Coastguard Worker a2 = e + f 235*da0073e9SAndroid Build Coastguard Worker return a0, a1, a2 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker @staticmethod 238*da0073e9SAndroid Build Coastguard Worker def forward18(a, b, c): 239*da0073e9SAndroid Build Coastguard Worker a0, a1 = torch.ops.aten.var_mean(a) 240*da0073e9SAndroid Build Coastguard Worker return a0 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker# A mock OperatorSupport class, where only operator.add is supported 243*da0073e9SAndroid Build Coastguard Workerclass MockOperatorSupport(OperatorSupport): 244*da0073e9SAndroid Build Coastguard Worker def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 245*da0073e9SAndroid Build Coastguard Worker return (node.op == "call_function" and 246*da0073e9SAndroid Build Coastguard Worker node.target in {operator.add, operator.getitem, 247*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.view, 248*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.permute, 249*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.std_mean}) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker@instantiate_parametrized_tests 252*da0073e9SAndroid Build Coastguard Workerclass TestFXGraphPasses(JitTestCase): 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker @parametrize("fn, expected_partition, bookend_non_compute_pass", [ 255*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False), 256*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False), 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker # 1 horizontal fusion with common producer 259*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False), 260*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False), 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker # 2 branches cases 263*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward5, [["add_1", "add"]], False), 264*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward6, [["add"]], False), 265*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False), 266*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False), 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker # 3 branch cases 269*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False), 270*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False), 271*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward11, [['add_1'], ['add']], False), 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker # 4 not necessarily the only partition, just to verify that there's no cyclic dependency after partition 274*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False), 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker # 5 getitem special case 277*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False), 278*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False), 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker # 6 bookend non_compute pass 281*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True), 282*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), 283*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True), 284*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), 285*da0073e9SAndroid Build Coastguard Worker # should be empty partition, not a partiton with empty nodes 286*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward18, [], False), 287*da0073e9SAndroid Build Coastguard Worker ]) 288*da0073e9SAndroid Build Coastguard Worker def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass): 289*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(fn) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker non_compute_ops = [] 292*da0073e9SAndroid Build Coastguard Worker if bookend_non_compute_pass: 293*da0073e9SAndroid Build Coastguard Worker non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"] 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker supported_ops = MockOperatorSupport() 296*da0073e9SAndroid Build Coastguard Worker partitioner = CapabilityBasedPartitioner(traced, 297*da0073e9SAndroid Build Coastguard Worker supported_ops, 298*da0073e9SAndroid Build Coastguard Worker allows_single_node_partition=True, 299*da0073e9SAndroid Build Coastguard Worker non_compute_ops=non_compute_ops) 300*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.propose_partitions() 301*da0073e9SAndroid Build Coastguard Worker if bookend_non_compute_pass: 302*da0073e9SAndroid Build Coastguard Worker partitioner.remove_bookend_non_compute_ops(partitions) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker partitions_name = [[node.name for node in partition.nodes] for partition in partitions] 305*da0073e9SAndroid Build Coastguard Worker assert len(partitions_name) == len(expected_partition) 306*da0073e9SAndroid Build Coastguard Worker for i in range(len(partitions_name)): 307*da0073e9SAndroid Build Coastguard Worker assert set(partitions_name[i]) == set(expected_partition[i]) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker fused_graph = partitioner.fuse_partitions(partitions) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker expected = fn(a, b, c) 314*da0073e9SAndroid Build Coastguard Worker result = fused_graph(a, b, c) 315*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(expected, result) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker @parametrize("fn, expected_partition", [ 318*da0073e9SAndroid Build Coastguard Worker (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]), 319*da0073e9SAndroid Build Coastguard Worker ]) 320*da0073e9SAndroid Build Coastguard Worker def test_partitioner_independent_output(self, fn, expected_partition): 321*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(fn) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker supported_ops = MockOperatorSupport() 324*da0073e9SAndroid Build Coastguard Worker partitioner = CapabilityBasedPartitioner(traced, 325*da0073e9SAndroid Build Coastguard Worker supported_ops, 326*da0073e9SAndroid Build Coastguard Worker allows_single_node_partition=True) 327*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.propose_partitions() 328*da0073e9SAndroid Build Coastguard Worker partitions_name = [[node.name for node in partition.nodes] for partition in partitions] 329*da0073e9SAndroid Build Coastguard Worker assert len(partitions_name) == len(expected_partition) 330*da0073e9SAndroid Build Coastguard Worker for i in range(len(partitions_name)): 331*da0073e9SAndroid Build Coastguard Worker assert set(partitions_name[i]) == set(expected_partition[i]) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker fused_graph = partitioner.fuse_partitions(partitions) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4) 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker expected = fn(a, b, c, d, e, f) 338*da0073e9SAndroid Build Coastguard Worker result = fused_graph(a, b, c, d, e, f) 339*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(expected, result) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker @parametrize("partition", [ 342*da0073e9SAndroid Build Coastguard Worker [['add', 'add_1'], ['add_5', 'add_6']], 343*da0073e9SAndroid Build Coastguard Worker [['add', 'add_1', 'add_2']], # vertical fusion 344*da0073e9SAndroid Build Coastguard Worker [['add_2', 'add_3']], # horizontal fusion 345*da0073e9SAndroid Build Coastguard Worker [['add_3', 'add_4']], 346*da0073e9SAndroid Build Coastguard Worker [['add_6', 'add_5']], # arbitray node order 347*da0073e9SAndroid Build Coastguard Worker [['add_4', 'add_1', 'add_3', 'add_2']], # arbitray node order 348*da0073e9SAndroid Build Coastguard Worker [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitray partition order 349*da0073e9SAndroid Build Coastguard Worker [['add_5', 'linear2']], # includes call_function + call_module node 350*da0073e9SAndroid Build Coastguard Worker [['add_6', 'relu']], # includes call_function + call_module node 351*da0073e9SAndroid Build Coastguard Worker [['param', 'add_2']], # includes get_attr + call_module nodes 352*da0073e9SAndroid Build Coastguard Worker [['param', 'add_1', 'linear']], # includes get_attr + call_function + call_module nodes 353*da0073e9SAndroid Build Coastguard Worker [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]], # full graph 354*da0073e9SAndroid Build Coastguard Worker ]) 355*da0073e9SAndroid Build Coastguard Worker def test_fuser_util(self, partition): 356*da0073e9SAndroid Build Coastguard Worker m = TestModule() 357*da0073e9SAndroid Build Coastguard Worker gm = symbolic_trace(m) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker nodes_by_name = {node.name : node for node in gm.graph.nodes} 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker partitions = [] 362*da0073e9SAndroid Build Coastguard Worker for node_names in partition: 363*da0073e9SAndroid Build Coastguard Worker partitions.append([nodes_by_name[name] for name in node_names]) 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker fused_graph = fuse_by_partitions(gm, partitions) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker expected = m(a, b, c) 370*da0073e9SAndroid Build Coastguard Worker result = fused_graph(a, b, c) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(expected, result) 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker @parametrize("partition", [ 375*da0073e9SAndroid Build Coastguard Worker [['add', 'add_1'], ['add_1', 'add_5', 'add_6']], # add_1 exists in multiple partitions 376*da0073e9SAndroid Build Coastguard Worker [['add', 'add_1', 'add_3']], # invalid partition: circular dependency 377*da0073e9SAndroid Build Coastguard Worker [['add_4', 'add_5']], # invalid partition: circular dependency 378*da0073e9SAndroid Build Coastguard Worker [['relu', 'add_5']], # invalid partition: circular dependency 379*da0073e9SAndroid Build Coastguard Worker ]) 380*da0073e9SAndroid Build Coastguard Worker def test_fuser_util_xfail(self, partition): 381*da0073e9SAndroid Build Coastguard Worker m = TestModule() 382*da0073e9SAndroid Build Coastguard Worker gm = symbolic_trace(m) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker nodes_by_name = {node.name : node for node in gm.graph.nodes} 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker partitions = [] 387*da0073e9SAndroid Build Coastguard Worker for node_names in partition: 388*da0073e9SAndroid Build Coastguard Worker partitions.append([nodes_by_name[name] for name in node_names]) 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 391*da0073e9SAndroid Build Coastguard Worker fuse_by_partitions(gm, partitions) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker def test_fuser_pass_deep_model(self): 394*da0073e9SAndroid Build Coastguard Worker m = TestDeepModule() 395*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker supported_ops = MockOperatorSupport() 398*da0073e9SAndroid Build Coastguard Worker partitioner = CapabilityBasedPartitioner(traced, 399*da0073e9SAndroid Build Coastguard Worker supported_ops, 400*da0073e9SAndroid Build Coastguard Worker allows_single_node_partition=True) 401*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.propose_partitions() 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker@dataclass 404*da0073e9SAndroid Build Coastguard Workerclass TestCase: 405*da0073e9SAndroid Build Coastguard Worker match_output: bool 406*da0073e9SAndroid Build Coastguard Worker match_placeholder: bool 407*da0073e9SAndroid Build Coastguard Worker num_matches: int 408*da0073e9SAndroid Build Coastguard Worker remove_overlapping_matches: bool = True 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Workerclass SingleNodePattern: 411*da0073e9SAndroid Build Coastguard Worker @staticmethod 412*da0073e9SAndroid Build Coastguard Worker def forward(x): 413*da0073e9SAndroid Build Coastguard Worker val = torch.neg(x) 414*da0073e9SAndroid Build Coastguard Worker return torch.add(val, val) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker @staticmethod 417*da0073e9SAndroid Build Coastguard Worker def pattern(a): 418*da0073e9SAndroid Build Coastguard Worker return torch.neg(a) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker test_cases = [ 421*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 422*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 423*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 424*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 1), 425*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 426*da0073e9SAndroid Build Coastguard Worker ] 427*da0073e9SAndroid Build Coastguard Workerclass SimplePattern: 428*da0073e9SAndroid Build Coastguard Worker @staticmethod 429*da0073e9SAndroid Build Coastguard Worker def forward(x, w1, w2): 430*da0073e9SAndroid Build Coastguard Worker m1 = torch.cat([w1, w2]).sum() 431*da0073e9SAndroid Build Coastguard Worker m2 = torch.cat([w2, w1]).sum() 432*da0073e9SAndroid Build Coastguard Worker m3 = torch.cat([m1, m2]).sum() 433*da0073e9SAndroid Build Coastguard Worker return x + torch.max(m1) + torch.max(m2) + m3 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker @staticmethod 436*da0073e9SAndroid Build Coastguard Worker def pattern(a, b): 437*da0073e9SAndroid Build Coastguard Worker return torch.cat([a, b]).sum() 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker test_cases = [ 440*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 441*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 3), 442*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 443*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 2), 444*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 445*da0073e9SAndroid Build Coastguard Worker ] 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Workerclass SimpleFullGraphMatching: 448*da0073e9SAndroid Build Coastguard Worker @staticmethod 449*da0073e9SAndroid Build Coastguard Worker def forward(x): 450*da0073e9SAndroid Build Coastguard Worker a = torch.neg(x) 451*da0073e9SAndroid Build Coastguard Worker return torch.add(a, a) 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker @staticmethod 454*da0073e9SAndroid Build Coastguard Worker def pattern(x): 455*da0073e9SAndroid Build Coastguard Worker a = torch.neg(x) 456*da0073e9SAndroid Build Coastguard Worker return torch.add(a, a) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker test_cases = [ 459*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 460*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 461*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 462*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 1), 463*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 1) 464*da0073e9SAndroid Build Coastguard Worker ] 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Workerclass DiamondShapePatternTestCase: 467*da0073e9SAndroid Build Coastguard Worker @staticmethod 468*da0073e9SAndroid Build Coastguard Worker def forward(x): 469*da0073e9SAndroid Build Coastguard Worker a = torch.neg(x) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker a = a.relu() 472*da0073e9SAndroid Build Coastguard Worker left = a.sigmoid() 473*da0073e9SAndroid Build Coastguard Worker right = a.relu() 474*da0073e9SAndroid Build Coastguard Worker out = left + right 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker return out 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker @staticmethod 479*da0073e9SAndroid Build Coastguard Worker def pattern(a): 480*da0073e9SAndroid Build Coastguard Worker a = a.relu() 481*da0073e9SAndroid Build Coastguard Worker left = a.sigmoid() 482*da0073e9SAndroid Build Coastguard Worker right = a.relu() 483*da0073e9SAndroid Build Coastguard Worker out = left + right 484*da0073e9SAndroid Build Coastguard Worker return out 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Worker test_cases = [ 487*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 488*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 489*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 490*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 491*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 492*da0073e9SAndroid Build Coastguard Worker ] 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Workerclass NonFullyContainedMatches: 495*da0073e9SAndroid Build Coastguard Worker @staticmethod 496*da0073e9SAndroid Build Coastguard Worker def forward(x, w1, w2, b1, b2): 497*da0073e9SAndroid Build Coastguard Worker # fully contained matched subgraph 498*da0073e9SAndroid Build Coastguard Worker m1 = torch.cat([w1, w2]) 499*da0073e9SAndroid Build Coastguard Worker m2 = torch.cat([x, b2]) 500*da0073e9SAndroid Build Coastguard Worker t0 = torch.addmm(b1, m1, m2.t()) 501*da0073e9SAndroid Build Coastguard Worker t0_sum = torch.sum(t0) # use of t0 is not leaking 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker # leaking matched subgraph, m3 is leaked 504*da0073e9SAndroid Build Coastguard Worker m3 = torch.cat([w1, w2]) 505*da0073e9SAndroid Build Coastguard Worker m4 = torch.cat([x, b2]) 506*da0073e9SAndroid Build Coastguard Worker t1 = torch.addmm(b1, m3, m4.t()) 507*da0073e9SAndroid Build Coastguard Worker m3_sum = torch.sum(m3) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker return t0_sum, m3_sum 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker @staticmethod 512*da0073e9SAndroid Build Coastguard Worker def pattern(x, w1, w2, b1, b2): 513*da0073e9SAndroid Build Coastguard Worker m1 = torch.cat([w1, w2]) 514*da0073e9SAndroid Build Coastguard Worker m2 = torch.cat([x, b2]) 515*da0073e9SAndroid Build Coastguard Worker return torch.addmm(b1, m1, m2.t()) 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker test_cases = [ 518*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 519*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 1), # leaked used of placeholder is not leaking 524*da0073e9SAndroid Build Coastguard Worker ] 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Workerclass ChainRepeatedPattern: 527*da0073e9SAndroid Build Coastguard Worker @staticmethod 528*da0073e9SAndroid Build Coastguard Worker def forward(x): 529*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 530*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 531*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 532*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker @staticmethod 535*da0073e9SAndroid Build Coastguard Worker def pattern(x): 536*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(torch.sigmoid(x)) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker test_cases = [ 539*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 540*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 3, remove_overlapping_matches=False), 541*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 2, remove_overlapping_matches=True), 542*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 543*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 1), 544*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 545*da0073e9SAndroid Build Coastguard Worker ] 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Workerclass QuantizationModel: 548*da0073e9SAndroid Build Coastguard Worker @staticmethod 549*da0073e9SAndroid Build Coastguard Worker def forward(x): 550*da0073e9SAndroid Build Coastguard Worker x += 3 551*da0073e9SAndroid Build Coastguard Worker x = x.dequantize() 552*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 553*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.float16) 554*da0073e9SAndroid Build Coastguard Worker return x 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker @staticmethod 557*da0073e9SAndroid Build Coastguard Worker def pattern(x): 558*da0073e9SAndroid Build Coastguard Worker x = x.dequantize() 559*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(x) 560*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.float16) 561*da0073e9SAndroid Build Coastguard Worker return x 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker test_cases = [ 564*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 565*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 566*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 567*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 568*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 569*da0073e9SAndroid Build Coastguard Worker ] 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsWithDependency: 572*da0073e9SAndroid Build Coastguard Worker @staticmethod 573*da0073e9SAndroid Build Coastguard Worker def forward(x): 574*da0073e9SAndroid Build Coastguard Worker y = x.relu() 575*da0073e9SAndroid Build Coastguard Worker z = y.sigmoid() 576*da0073e9SAndroid Build Coastguard Worker return z, y 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker @staticmethod 579*da0073e9SAndroid Build Coastguard Worker def pattern(a): 580*da0073e9SAndroid Build Coastguard Worker b = a.relu() 581*da0073e9SAndroid Build Coastguard Worker c = b.sigmoid() 582*da0073e9SAndroid Build Coastguard Worker return b, c # outputs have data dependency 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker test_cases = [ 585*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 586*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 587*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 588*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 1), 589*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 590*da0073e9SAndroid Build Coastguard Worker ] 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsWithoutDependency: 593*da0073e9SAndroid Build Coastguard Worker @staticmethod 594*da0073e9SAndroid Build Coastguard Worker def forward(x): 595*da0073e9SAndroid Build Coastguard Worker x = x + 1 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker # target subgraph to match 598*da0073e9SAndroid Build Coastguard Worker x = x.relu() 599*da0073e9SAndroid Build Coastguard Worker z = x.sum() 600*da0073e9SAndroid Build Coastguard Worker y = x.sigmoid() 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker out = y.sigmoid() + z.sum() 603*da0073e9SAndroid Build Coastguard Worker return out 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker @staticmethod 606*da0073e9SAndroid Build Coastguard Worker def pattern(a): 607*da0073e9SAndroid Build Coastguard Worker a = a.relu() 608*da0073e9SAndroid Build Coastguard Worker b = a.sigmoid() 609*da0073e9SAndroid Build Coastguard Worker c = a.sum() 610*da0073e9SAndroid Build Coastguard Worker return b, c 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker test_cases = [ 613*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 614*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 615*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 616*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 617*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 618*da0073e9SAndroid Build Coastguard Worker ] 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsMultipleOverlappingMatches: 621*da0073e9SAndroid Build Coastguard Worker @staticmethod 622*da0073e9SAndroid Build Coastguard Worker def forward(x): 623*da0073e9SAndroid Build Coastguard Worker x = x + 1 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker # target subgraph to match 626*da0073e9SAndroid Build Coastguard Worker x = x.relu() 627*da0073e9SAndroid Build Coastguard Worker z = x.sum() 628*da0073e9SAndroid Build Coastguard Worker z1 = x.sum() 629*da0073e9SAndroid Build Coastguard Worker y = x.sigmoid() 630*da0073e9SAndroid Build Coastguard Worker y1 = x.sigmoid() 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker return z + z1 + y + y1 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker @staticmethod 635*da0073e9SAndroid Build Coastguard Worker def pattern(a): 636*da0073e9SAndroid Build Coastguard Worker a = a.relu() 637*da0073e9SAndroid Build Coastguard Worker b = a.sigmoid() 638*da0073e9SAndroid Build Coastguard Worker c = a.sum() 639*da0073e9SAndroid Build Coastguard Worker return a, b, c 640*da0073e9SAndroid Build Coastguard Worker 641*da0073e9SAndroid Build Coastguard Worker test_cases = [ 642*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 643*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 4, remove_overlapping_matches=False), 644*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1, remove_overlapping_matches=True), 645*da0073e9SAndroid Build Coastguard Worker ] 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsMultipleNonOverlappingMatches: 648*da0073e9SAndroid Build Coastguard Worker @staticmethod 649*da0073e9SAndroid Build Coastguard Worker def forward(x): 650*da0073e9SAndroid Build Coastguard Worker x = x + 1 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker # target subgraph to match 653*da0073e9SAndroid Build Coastguard Worker x = x.relu() 654*da0073e9SAndroid Build Coastguard Worker z = x.sum() 655*da0073e9SAndroid Build Coastguard Worker y = x.sigmoid() 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker x = x.relu() 658*da0073e9SAndroid Build Coastguard Worker z1 = x.sum() 659*da0073e9SAndroid Build Coastguard Worker y1 = x.sigmoid() 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker return z + z1 + y + y1 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker @staticmethod 664*da0073e9SAndroid Build Coastguard Worker def pattern(a): 665*da0073e9SAndroid Build Coastguard Worker a = a.relu() 666*da0073e9SAndroid Build Coastguard Worker b = a.sigmoid() 667*da0073e9SAndroid Build Coastguard Worker c = a.sum() 668*da0073e9SAndroid Build Coastguard Worker return b, c 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker test_cases = [ 671*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 672*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 673*da0073e9SAndroid Build Coastguard Worker ] 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsIdenticalAnchor: 676*da0073e9SAndroid Build Coastguard Worker @staticmethod 677*da0073e9SAndroid Build Coastguard Worker def forward(x): 678*da0073e9SAndroid Build Coastguard Worker x = x + 1 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker # target subgraph to match 681*da0073e9SAndroid Build Coastguard Worker x = x.relu() 682*da0073e9SAndroid Build Coastguard Worker y = x.sigmoid() 683*da0073e9SAndroid Build Coastguard Worker y1 = x.sigmoid() 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker return y, y1 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker @staticmethod 688*da0073e9SAndroid Build Coastguard Worker def pattern(a): 689*da0073e9SAndroid Build Coastguard Worker a = a.relu() 690*da0073e9SAndroid Build Coastguard Worker b = a.sigmoid() 691*da0073e9SAndroid Build Coastguard Worker b1 = a.sigmoid() 692*da0073e9SAndroid Build Coastguard Worker return b, b1 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker test_cases = [ 695*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 696*da0073e9SAndroid Build Coastguard Worker # (False, False, 2), # FIXME: currently still matches to 2, should fix to 1 697*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 698*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 699*da0073e9SAndroid Build Coastguard Worker ] 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Workerclass MultipleOutputsHorizontalPattern: 703*da0073e9SAndroid Build Coastguard Worker @staticmethod 704*da0073e9SAndroid Build Coastguard Worker def forward(x): 705*da0073e9SAndroid Build Coastguard Worker x = x + 1 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker # target subgraph to match 708*da0073e9SAndroid Build Coastguard Worker y1 = x.relu() 709*da0073e9SAndroid Build Coastguard Worker y2 = x.sigmoid() 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker return y1, y2 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker @staticmethod 714*da0073e9SAndroid Build Coastguard Worker def pattern(a): 715*da0073e9SAndroid Build Coastguard Worker b1 = a.relu() 716*da0073e9SAndroid Build Coastguard Worker b2 = a.sigmoid() 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build Coastguard Worker return b1, b2 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker test_cases = [ 721*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 722*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 723*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 1), 724*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 725*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 726*da0073e9SAndroid Build Coastguard Worker ] 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Workerclass MultiOutputWithWithInvalidMatches: 729*da0073e9SAndroid Build Coastguard Worker @staticmethod 730*da0073e9SAndroid Build Coastguard Worker def forward(x): 731*da0073e9SAndroid Build Coastguard Worker res0 = torch.nn.functional.linear(x, torch.rand(3, 3)) 732*da0073e9SAndroid Build Coastguard Worker res1 = torch.sigmoid(res0) 733*da0073e9SAndroid Build Coastguard Worker res2 = res0 * res1 734*da0073e9SAndroid Build Coastguard Worker res3 = torch.sum(res2, dim=1) 735*da0073e9SAndroid Build Coastguard Worker return res3 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker @staticmethod 738*da0073e9SAndroid Build Coastguard Worker def pattern(a, b, c): 739*da0073e9SAndroid Build Coastguard Worker lin_res = torch.nn.functional.linear(a, b) 740*da0073e9SAndroid Build Coastguard Worker mul_res = lin_res * c 741*da0073e9SAndroid Build Coastguard Worker return lin_res, mul_res 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker test_cases = [ 744*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 745*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 0), 746*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 747*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 748*da0073e9SAndroid Build Coastguard Worker ] 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Workerclass QuantizationFp8Pattern: 751*da0073e9SAndroid Build Coastguard Worker @classmethod 752*da0073e9SAndroid Build Coastguard Worker def setup(cls): 753*da0073e9SAndroid Build Coastguard Worker cls.quantization = torch.library.Library("fp8_quantization", "DEF") # noqa: TOR901 754*da0073e9SAndroid Build Coastguard Worker cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor") 755*da0073e9SAndroid Build Coastguard Worker cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor") 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker @classmethod 758*da0073e9SAndroid Build Coastguard Worker def tearDown(cls): 759*da0073e9SAndroid Build Coastguard Worker del cls.quantization 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker @staticmethod 762*da0073e9SAndroid Build Coastguard Worker def forward(self, arg0_1, arg1_1): 763*da0073e9SAndroid Build Coastguard Worker qt = torch.ops.fp8_quantization 764*da0073e9SAndroid Build Coastguard Worker _scale_0 = self._scale_0 765*da0073e9SAndroid Build Coastguard Worker quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0) 766*da0073e9SAndroid Build Coastguard Worker dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0) 767*da0073e9SAndroid Build Coastguard Worker _scale_1 = self._scale_0 768*da0073e9SAndroid Build Coastguard Worker quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1) 769*da0073e9SAndroid Build Coastguard Worker dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1) 770*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1) 771*da0073e9SAndroid Build Coastguard Worker _scale_2 = self._scale_0 772*da0073e9SAndroid Build Coastguard Worker quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2) 773*da0073e9SAndroid Build Coastguard Worker dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2) 774*da0073e9SAndroid Build Coastguard Worker return dequantize_per_tensor_affine_fp8_2 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker @staticmethod 777*da0073e9SAndroid Build Coastguard Worker def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale): 778*da0073e9SAndroid Build Coastguard Worker qt = torch.ops.fp8_quantization 779*da0073e9SAndroid Build Coastguard Worker a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale) 780*da0073e9SAndroid Build Coastguard Worker b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale) 781*da0073e9SAndroid Build Coastguard Worker output = torch.ops.aten.add.Tensor(a, b) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker qt.dequantize_per_tensor_affine_fp8 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale) 786*da0073e9SAndroid Build Coastguard Worker return output 787*da0073e9SAndroid Build Coastguard Worker 788*da0073e9SAndroid Build Coastguard Worker test_cases = [ 789*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 790*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 1), 791*da0073e9SAndroid Build Coastguard Worker ] 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Workerclass NoAnchorFound: 794*da0073e9SAndroid Build Coastguard Worker # This test case is for pattern where no matching anchor is found in the target graph 795*da0073e9SAndroid Build Coastguard Worker # `anchor` is the starting point of the pattern matching, it's usually the boundary returning nodes 796*da0073e9SAndroid Build Coastguard Worker @staticmethod 797*da0073e9SAndroid Build Coastguard Worker def forward(x): 798*da0073e9SAndroid Build Coastguard Worker x = x + 1 799*da0073e9SAndroid Build Coastguard Worker return x 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker @staticmethod 802*da0073e9SAndroid Build Coastguard Worker def pattern(a): 803*da0073e9SAndroid Build Coastguard Worker b1 = a.relu() 804*da0073e9SAndroid Build Coastguard Worker return b1 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker test_cases = [ 807*da0073e9SAndroid Build Coastguard Worker # match_output, match_placeholder, num_matches 808*da0073e9SAndroid Build Coastguard Worker TestCase(False, False, 0), 809*da0073e9SAndroid Build Coastguard Worker TestCase(True, False, 0), 810*da0073e9SAndroid Build Coastguard Worker TestCase(False, True, 0), 811*da0073e9SAndroid Build Coastguard Worker TestCase(True, True, 0) 812*da0073e9SAndroid Build Coastguard Worker ] 813*da0073e9SAndroid Build Coastguard Worker 814*da0073e9SAndroid Build Coastguard Worker@instantiate_parametrized_tests 815*da0073e9SAndroid Build Coastguard Workerclass TestFXMatcherUtils(JitTestCase): 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker @parametrize("test_model", [ 818*da0073e9SAndroid Build Coastguard Worker SingleNodePattern, 819*da0073e9SAndroid Build Coastguard Worker SimplePattern, 820*da0073e9SAndroid Build Coastguard Worker SimpleFullGraphMatching, 821*da0073e9SAndroid Build Coastguard Worker DiamondShapePatternTestCase, 822*da0073e9SAndroid Build Coastguard Worker NonFullyContainedMatches, 823*da0073e9SAndroid Build Coastguard Worker ChainRepeatedPattern, 824*da0073e9SAndroid Build Coastguard Worker QuantizationModel, 825*da0073e9SAndroid Build Coastguard Worker MultipleOutputsWithDependency, 826*da0073e9SAndroid Build Coastguard Worker MultipleOutputsWithoutDependency, 827*da0073e9SAndroid Build Coastguard Worker MultipleOutputsMultipleOverlappingMatches, 828*da0073e9SAndroid Build Coastguard Worker MultipleOutputsMultipleNonOverlappingMatches, 829*da0073e9SAndroid Build Coastguard Worker MultipleOutputsIdenticalAnchor, 830*da0073e9SAndroid Build Coastguard Worker MultipleOutputsHorizontalPattern, 831*da0073e9SAndroid Build Coastguard Worker MultiOutputWithWithInvalidMatches, 832*da0073e9SAndroid Build Coastguard Worker QuantizationFp8Pattern, 833*da0073e9SAndroid Build Coastguard Worker NoAnchorFound, 834*da0073e9SAndroid Build Coastguard Worker ]) 835*da0073e9SAndroid Build Coastguard Worker def test_subgraph_matcher(self, test_model): 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker setup = getattr(test_model, "setup", None) 838*da0073e9SAndroid Build Coastguard Worker if callable(setup): 839*da0073e9SAndroid Build Coastguard Worker setup() 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(test_model.forward) 842*da0073e9SAndroid Build Coastguard Worker pattern_traced = symbolic_trace(test_model.pattern) 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker for test_case in test_model.test_cases: 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker matcher = SubgraphMatcher(pattern_traced.graph, 847*da0073e9SAndroid Build Coastguard Worker match_output=test_case.match_output, 848*da0073e9SAndroid Build Coastguard Worker match_placeholder=test_case.match_placeholder, 849*da0073e9SAndroid Build Coastguard Worker remove_overlapping_matches=test_case.remove_overlapping_matches) 850*da0073e9SAndroid Build Coastguard Worker matches = matcher.match(traced.graph) 851*da0073e9SAndroid Build Coastguard Worker 852*da0073e9SAndroid Build Coastguard Worker assert len(matches) == test_case.num_matches 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker for match in matches: 855*da0073e9SAndroid Build Coastguard Worker for node in pattern_traced.graph.nodes: 856*da0073e9SAndroid Build Coastguard Worker if not test_case.match_placeholder and node.op == "placeholder": 857*da0073e9SAndroid Build Coastguard Worker continue 858*da0073e9SAndroid Build Coastguard Worker if not test_case.match_output and node.op == "output": 859*da0073e9SAndroid Build Coastguard Worker continue 860*da0073e9SAndroid Build Coastguard Worker assert node in match.nodes_map 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker tearDown = getattr(test_model, "tearDown", None) 863*da0073e9SAndroid Build Coastguard Worker if callable(setup): 864*da0073e9SAndroid Build Coastguard Worker tearDown() 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 868*da0073e9SAndroid Build Coastguard Worker run_tests() 869