# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import unittest

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.compiler import (
    export_to_edge,
    quantize_and_export_to_cadence,
)
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.pass_utils import (
    count_node,
    get_compute_nodes_in_gm,
    nodes_not_adjacent_in_gm,
    nodes_not_connected_in_gm,
)
from executorch.backends.cadence.aot.reorder_ops import (
    AdvanceQuantizeOpAboveDefInBranchPass,
    PostponeDequantizeOpBelowUseChainPass,
    PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)
from executorch.exir.dialects._ops import ops as exir_ops


class TestReorderPasses(unittest.TestCase):
    def test_sink_dequantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x, y):
                x1 = self.linear(x)
                y1 = self.linear(y)
                x2 = torch.ops.aten.abs(x1)
                return torch.ops.aten.cat((x2, y1))

        inputs = (torch.randn(32, 6), torch.randn(32, 6))
        graph_module = (
            quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module
        )
        # Expect the SinkDequant pass to move dequant(y) from above the abs to
        # just below it
        self.assertTrue(
            nodes_not_adjacent_in_gm(
                graph_module,
                exir_ops.edge.aten.abs.default,
                exir_ops.edge.aten.cat.default,
            ),
        )
        self.assertTrue(
            nodes_not_adjacent_in_gm(
                graph_module,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
            ),
        )

    def test_advance_branched_quantize(self):
        class ReorderOpsBranch(torch.nn.Module):
            def forward(self, x):
                x = x.view((32, 6))
                x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1)
                x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x1, 0.1, 10, 0, 255, torch.uint8
                )
                x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1)
                x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x2, 0.1, 10, 0, 255, torch.uint8
                )
                x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1)
                x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x3, 0.1, 10, 0, 255, torch.uint8
                )
                x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1)
                x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x4, 0.2, 4, 0, 255, torch.uint8
                )
                return (x1, x2, x3, x4)

        model = ReorderOpsBranch()
        X = torch.randn(64, 3)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
        graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()
        nodes = get_compute_nodes_in_gm(graph_module)
        # The quantize op should be hoisted above the slices so that it
        # dominates all four branches
        self.assertTrue(
            nodes[0] == exir_ops.edge.quantized_decomposed.quantize_per_tensor
        )
        # There should be 5 quantize ops: the 4 originally present in the model,
        # and the one that was hoisted above the slices
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            ),
            5,
        )
        # Ensure none of the slice nodes were erroneously removed
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.aten.slice_copy.Tensor,
            ),
            4,
        )
        # Each of the 4 original quant ops should now be paired with a dequant op
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            ),
            4,
        )
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        # We expect 3 dequant/quant pairs to be removed because they have
        # matching params, leaving a single dequant/quant pair that is then
        # merged into a requantize op
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.cadence.requantize.default,
            ),
            1,
        )

    @torch.no_grad()
    def test_advance_quantize(self):
        class ReorderOpsChain(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = x.permute([1, 0, 3, 2])
                x = self.linear(x)
                return x

        model = ReorderOpsChain()
        X = torch.randn(16, 1, 6, 32)

        graph_module = (
            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
        )
        # Assert that the quant node is no longer the successor of the permute node
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                exir_ops.edge.cadence.quantize_per_tensor.default,
            ),
        )
        # Assert that the permute node is now the successor of the quant node
        self.assertFalse(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.cadence.quantize_per_tensor.default,
                exir_ops.edge.aten.permute_copy.default,
            ),
        )

    def test_postpone_dequantize(self):
        class ReorderOpsChain(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = self.linear(x)
                x = x.permute([1, 0, 3, 2])
                return x

        model = ReorderOpsChain()
        X = torch.randn(1, 16, 32, 6)

        graph_module = (
            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
        )
        # Assert that the dequant node is no longer the predecessor of the permute node
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
                exir_ops.edge.aten.permute_copy.default,
            ),
        )
        # Assert that the dequant node is now the successor of the permute node
        self.assertFalse(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
            ),
        )

    def test_postpone_dequantize_branched(self):
        class ReorderOpsBranch(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 12, bias=False)

            def forward(self, x):
                x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor(
                    x, 0.1, 10, 0, 255, torch.uint8
                )
                x0 = torch.squeeze(x0, 0)
                x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1)
                x1 = self.linear(x1)

                x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1)
                x2 = self.linear(x2)

                x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1)
                x3 = self.linear(x3)

                return (x1, x2, x3)

        model = ReorderOpsBranch()
        X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
        graph_module = PostponeDequantizeOpBelowUseChainPass()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert that the dequant node was split into 3, one per branch
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            ),
            3,
        )

        # Assert that the dequant node is no longer the predecessor of the squeeze node
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                exir_ops.edge.aten.squeeze_copy.dims,
            ),
        )
        # Assert that the dequant node is not a predecessor of slice
        # (it should have been moved below the slices)
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                exir_ops.edge.aten.slice_copy.Tensor,
            ),
        )

    # 3d -> permute -> 3d -> view -> 4d
    def test_permute3_view4_chains(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((3, 12, 64))
                # x is [3, 12, 64]
                x = x.permute([1, 0, 2])
                # x is [12, 3, 64]
                x = x.view((1, 12, 3, 64))
                # x is [1, 12, 3, 64]
                x = x.permute([0, 1, 3, 2])
                # x is [1, 12, 64, 3]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Performing transform
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order becomes view, view, permute, permute
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)

    # 4d -> permute -> 4d -> view -> 3d
    def test_permute4_view3_chains(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((1, 3, 12, 64))
                # x is [1, 3, 12, 64]
                x = x.permute([3, 1, 0, 2])
                # x is [64, 3, 1, 12]
                x = x.view((64, 3, 12))
                # x is [64, 3, 12]
                x = x.permute([2, 1, 0])
                # x is [12, 3, 64]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Performing transform
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order becomes view, view, permute, permute
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)

    # Negative test case where the transform should not happen:
    # permute -> 4d -> view -> 3d, where the view not only removes the
    # dimension whose size is 1 (this is ok), but also changes the sizes of
    # the other dimensions (not ok).
    def test_permute_view_chains_neg(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((1, 3, 12, 64))
                # x is [1, 3, 12, 64]
                x = x.permute([3, 1, 0, 2])
                # x is [64, 3, 1, 12]
                x = x.view((64, 6, 6))
                # x is [64, 6, 6]
                x = x.permute([2, 1, 0])
                # x is [6, 6, 64]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Performing transform (nothing should happen)
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order is still view, permute, view, permute
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)
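

if __name__ == "__main__":
    # Standard unittest entry point so this file can also be run directly;
    # test runners that discover TestCase classes work without it.
    unittest.main()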