1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch.testing import FileCheck 5from torch.testing._internal.jit_utils import JitTestCase 6 7 8if __name__ == "__main__": 9 raise RuntimeError( 10 "This test file is not meant to be run directly, use:\n\n" 11 "\tpython test/test_jit.py TESTNAME\n\n" 12 "instead." 13 ) 14 15 16class TestBatchMM(JitTestCase): 17 @staticmethod 18 def _get_test_tensors(n: int): 19 return [ 20 torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]]) 21 if x % 2 == 0 22 else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]]) 23 for x in range(n) 24 ] 25 26 def test_batch_mm_no_mutation(self): 27 def test_batch_mm( 28 T1: torch.Tensor, 29 T2: torch.Tensor, 30 T3: torch.Tensor, 31 T4: torch.Tensor, 32 T5: torch.Tensor, 33 T6: torch.Tensor, 34 T7: torch.Tensor, 35 T8: torch.Tensor, 36 ): 37 return ( 38 torch.mm(T1, T2) 39 + torch.mm(T3, T4) 40 + torch.mm(T5, T6) 41 + torch.mm(T7, T8) 42 ) 43 44 test_batch_mm_scripted = torch.jit.script(test_batch_mm) 45 46 tensors = TestBatchMM._get_test_tensors(8) 47 expected = test_batch_mm(*tensors) 48 49 FileCheck().check_count("aten::mm", 4, exactly=True).run( 50 test_batch_mm_scripted.graph 51 ) 52 self.run_pass("batch_mm", test_batch_mm_scripted.graph) 53 FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( 54 test_batch_mm_scripted.graph 55 ) 56 57 actual = test_batch_mm_scripted(*tensors) 58 self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) 59 60 def test_batch_mm_permitted_mutation(self): 61 def test_batch_mm( 62 T1: torch.Tensor, 63 T2: torch.Tensor, 64 T3: torch.Tensor, 65 T4: torch.Tensor, 66 T5: torch.Tensor, 67 T6: torch.Tensor, 68 T7: torch.Tensor, 69 T8: torch.Tensor, 70 ): 71 result = {} 72 result["product"] = ( 73 torch.mm(T1, T2) 74 + torch.mm(T3, T4) 75 + torch.mm(T5, T6) 76 + torch.mm(T7, T8) 77 ) 78 result["constant"] = torch.tensor([42.0]) 79 return result 80 81 test_batch_mm_scripted = torch.jit.script(test_batch_mm) 82 83 tensors = TestBatchMM._get_test_tensors(8) 84 expected = test_batch_mm(*tensors) 85 86 FileCheck().check_count("aten::mm", 4, exactly=True).run( 87 test_batch_mm_scripted.graph 88 ) 89 self.run_pass("batch_mm", test_batch_mm_scripted.graph) 90 FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( 91 test_batch_mm_scripted.graph 92 ) 93 94 actual = test_batch_mm_scripted(*tensors) 95 self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) 96 97 def test_batch_mm_prohibited_mutation(self): 98 @torch.jit.script 99 def test_batch_mm(n: int): 100 T1 = torch.zeros((n, n)) 101 T2 = torch.zeros((n, n)) 102 T3 = torch.zeros((n, n)) 103 T4 = torch.zeros((n, n)) 104 T5 = torch.zeros((n, n)) 105 T6 = torch.zeros((n, n)) 106 T7 = torch.zeros((n, n)) 107 T8 = torch.zeros((n, n)) 108 torch.relu_(T1) 109 result = ( 110 torch.mm(T1, T2) 111 + torch.mm(T3, T4) 112 + torch.mm(T5, T6) 113 + torch.mm(T7, T8) 114 ) 115 return result 116 117 FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph) 118 self.run_pass("batch_mm", test_batch_mm.graph) 119 FileCheck().check_count("aten::mm", 4, exactly=True).check_not( 120 "prim::MMTreeReduce" 121 ).run(test_batch_mm.graph) 122 123 def test_batch_mm_prohibited_mutation_multiple_adds(self): 124 @torch.jit.script 125 def test_batch_mm(n: int): 126 T1 = torch.zeros((n, n)) 127 T2 = torch.zeros((n, n)) 128 T3 = torch.zeros((n, n)) 129 T4 = torch.zeros((n, n)) 130 T5 = torch.zeros((n, n)) 131 T6 = torch.zeros((n, n)) 132 T7 = torch.zeros((n, n)) 133 T8 = torch.zeros((n, n)) 134 T9 = torch.zeros((n, n)) 135 T10 = torch.zeros((n, n)) 136 torch.relu_(T1) 137 result = {} 138 result["no_mutated_parameters"] = ( 139 torch.mm(T2, T3) 140 + torch.mm(T4, T5) 141 + torch.mm(T6, T7) 142 + torch.mm(T8, T9) 143 ) 144 result["all_parameters"] = ( 145 torch.mm(T1, T2) 146 + torch.mm(T3, T4) 147 + torch.mm(T5, T6) 148 + torch.mm(T7, T8) 149 + torch.mm(T9, T10) 150 ) 151 return result 152 153 self.run_pass("batch_mm", test_batch_mm.graph) 154 FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count( 155 "aten::mm", 5, exactly=True 156 ).run(test_batch_mm.graph) 157 158 def test_batch_mm_prohibited_mutation_if_node(self): 159 @torch.jit.script 160 def test_batch_mm(n: int, use_t1: bool): 161 T1 = torch.zeros((n, n)) 162 T2 = torch.zeros((n, n)) 163 T3 = torch.zeros((n, n)) 164 T4 = torch.zeros((n, n)) 165 T5 = torch.zeros((n, n)) 166 T6 = torch.zeros((n, n)) 167 T7 = torch.zeros((n, n)) 168 T8 = torch.zeros((n, n)) 169 T9 = torch.zeros((n, n)) 170 T10 = torch.zeros((n, n)) 171 if use_t1: 172 torch.relu_(T1) 173 return ( 174 torch.mm(T1, T2) 175 + torch.mm(T3, T4) 176 + torch.mm(T5, T6) 177 + torch.mm(T7, T8) 178 + torch.mm(T9, T10) 179 ) 180 else: 181 return ( 182 torch.mm(T2, T3) 183 + torch.mm(T4, T5) 184 + torch.mm(T6, T7) 185 + torch.mm(T8, T9) 186 ) 187 188 self.run_pass("batch_mm", test_batch_mm.graph) 189 FileCheck().check_count("aten::mm", 5, exactly=True).check_count( 190 "prim::MMTreeReduce", 1, exactly=True 191 ).run(test_batch_mm.graph) 192 193 def test_batch_mm_side_permitted_mutation(self): 194 @torch.jit.script 195 def test_batch_mm(n: int): 196 result = {} 197 A = torch.zeros((n, n)) 198 T1 = torch.zeros((n, n)) 199 T2 = torch.zeros((n, n)) 200 T3 = torch.zeros((n, n)) 201 T4 = torch.zeros((n, n)) 202 T5 = torch.zeros((n, n)) 203 T6 = torch.zeros((n, n)) 204 T7 = torch.zeros((n, n)) 205 T8 = torch.zeros((n, n)) 206 result["T1"] = torch.mm(A, T1) 207 result["T2"] = torch.mm(A, T2) 208 result["T3"] = torch.mm(A, T3) 209 result["T4"] = torch.mm(A, T4) 210 result["T5"] = torch.mm(A, T5) 211 result["T6"] = torch.mm(A, T6) 212 result["T7"] = torch.mm(A, T7) 213 result["T8"] = torch.mm(A, T8) 214 return result 215 216 FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph) 217 self.run_pass("batch_mm", test_batch_mm.graph) 218 FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not( 219 "aten::mm" 220 ).run(test_batch_mm.graph) 221 222 def test_batch_mm_side_prohibited_mutation_uncommon_side(self): 223 @torch.jit.script 224 def test_batch_mm(n: int): 225 A = torch.zeros((n, n)) 226 T1 = torch.zeros((n, n)) 227 T2 = torch.zeros((n, n)) 228 T3 = torch.zeros((n, n)) 229 T4 = torch.zeros((n, n)) 230 T5 = torch.zeros((n, n)) 231 T6 = torch.zeros((n, n)) 232 T7 = torch.zeros((n, n)) 233 T8 = torch.zeros((n, n)) 234 T9 = torch.zeros((n, n)) 235 T10 = torch.zeros((n, n)) 236 torch.relu_(T1) 237 result = {} 238 result["T1"] = torch.mm(A, T1) 239 result["T2"] = torch.mm(A, T2) 240 result["T3"] = torch.mm(A, T3) 241 result["T4"] = torch.mm(A, T4) 242 result["T5"] = torch.mm(A, T5) 243 result["T6"] = torch.mm(A, T6) 244 result["T7"] = torch.mm(A, T7) 245 result["T8"] = torch.mm(A, T8) 246 result["T9"] = torch.mm(A, T9) 247 result["T10"] = torch.mm(A, T10) 248 return result 249 250 FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) 251 self.run_pass("batch_mm", test_batch_mm.graph) 252 253 FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph) 254 FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run( 255 test_batch_mm.graph 256 ) 257 258 def test_batch_mm_side_prohibited_mutation_common_side(self): 259 @torch.jit.script 260 def test_batch_mm(n: int): 261 A = torch.zeros((n, n)) 262 T1 = torch.zeros((n, n)) 263 T2 = torch.zeros((n, n)) 264 T3 = torch.zeros((n, n)) 265 T4 = torch.zeros((n, n)) 266 T5 = torch.zeros((n, n)) 267 T6 = torch.zeros((n, n)) 268 T7 = torch.zeros((n, n)) 269 T8 = torch.zeros((n, n)) 270 T9 = torch.zeros((n, n)) 271 T10 = torch.zeros((n, n)) 272 torch.relu_(A) 273 result = {} 274 result["T1"] = torch.mm(A, T1) 275 result["T2"] = torch.mm(A, T2) 276 result["T3"] = torch.mm(A, T3) 277 result["T4"] = torch.mm(A, T4) 278 result["T5"] = torch.mm(A, T5) 279 result["T6"] = torch.mm(A, T6) 280 result["T7"] = torch.mm(A, T7) 281 result["T8"] = torch.mm(A, T8) 282 result["T9"] = torch.mm(A, T9) 283 result["T10"] = torch.mm(A, T10) 284 return result 285 286 FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) 287 self.run_pass("batch_mm", test_batch_mm.graph) 288 FileCheck().check_count("aten::mm", 10, exactly=True).check_not( 289 "prim::MMBatchSide" 290 ).run(test_batch_mm.graph) 291