# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    ret = output
    return ret


torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2


def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)


def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)


"""
    graph with multiple fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
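    # torch.jit.fork schedules the callable asynchronously and returns a Future;
    # torch.jit.wait blocks until the corresponding result is available.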
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))


"""
    graph with multi-level fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param num_forks: number of top level forks
    :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))


def add_tensor(input1, input2):
    return input1 + input2


def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)


def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class SubModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


class SubModule2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph;
    fork is called on a simple addition of the input tensors
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a fork/wait operation on a loop
    subgraph performing a mix of operations
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

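    # The *_async tests below use StaticModule.runAsync, which returns a
    # future-like handle: wait() blocks until completion and value() retrieves
    # the result.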
234 """ 235 Test Case: To test fork/wait operation on a loop 236 subgraph with StaticRuntime runAsync API returning future 237 """ 238 def test_fork_wait_2_async(self): 239 inp1 = torch.randn(5, 5) 240 inp2 = torch.randn(5, 5) 241 torch_graph = torch.jit.script(fork_wait_graph2) 242 output_ref = torch_graph(inp1, inp2) 243 static_runtime_module = StaticModule(torch_graph) 244 output_test = static_runtime_module.runAsync((inp1, inp2), {}) 245 output_test.wait() 246 torch.testing.assert_close(output_test.value(), output_ref) 247 248 """ 249 Test Case: To test fork/wait operation in a graph on 250 having multiple fork/wait operations 251 """ 252 def test_fork_wait_3(self): 253 input = torch.ones(3, 3) 254 num_forks = 10 255 torch_graph = torch.jit.script(fork_wait_graph3) 256 output_ref = torch_graph(input, num_forks) 257 static_runtime_module = StaticModule(torch_graph) 258 output_test = static_runtime_module(input, num_forks) 259 torch.testing.assert_close(output_test, output_ref) 260 261 """ 262 Test Case: To test fork/wait operation in a graph with 263 multiple fork/wait operations on runAsync API returning future 264 """ 265 def test_fork_wait_3_async(self): 266 input = torch.ones(3, 3) 267 num_forks = 10 268 torch_graph = torch.jit.script(fork_wait_graph3) 269 output_ref = torch_graph(input, num_forks) 270 static_runtime_module = StaticModule(torch_graph) 271 output_test = static_runtime_module.runAsync((input, num_forks), {}) 272 output_test.wait() 273 torch.testing.assert_close(output_test.value(), output_ref) 274 275 """ 276 Test Case: To test fork/wait operation in a graph on 277 multiple nested fork/wait operations 278 """ 279 @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782") 280 def test_fork_wait_4(self): 281 input = torch.ones(3, 3) 282 num_forks = 10 283 num_child_forks = 10 284 torch_graph = torch.jit.script(fork_wait_graph4) 285 static_runtime_module = StaticModule(torch_graph) 286 output_ref = torch_graph(input, num_forks, num_child_forks) 287 output_test = static_runtime_module(input, num_forks, num_child_forks) 288 torch.testing.assert_close(output_test, output_ref) 289 290 """ 291 Test Case: To test fork/wait operation in a graph with multiple 292 nested fork/wait operations on runAsync API returning future 293 """ 294 @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782") 295 def test_fork_wait_4_async(self): 296 input = torch.ones(3, 3) 297 num_forks = 10 298 num_child_forks = 10 299 torch_graph = torch.jit.script(fork_wait_graph4) 300 static_runtime_module = StaticModule(torch_graph) 301 output_ref = torch_graph(input, num_forks, num_child_forks) 302 output_test = static_runtime_module.runAsync( 303 (input, num_forks, num_child_forks), {}) 304 output_test.wait() 305 torch.testing.assert_close(output_test.value(), output_ref) 306 307 """ 308 Test Case: To test exception handling in fork/wait 309 operation. Add.Tensor op is called for tensors with 310 non-matching dims on the forked subgraph and the 311 exception raised by subgraph is set on future returned 312 by prim::fork to parent graph. 
    checked for the substring expected_error_msg declared below.
    """
    def test_fork_wait_exception(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    """
    Test Case: To test exception handling in fork/wait
    operations with the runAsync API. The add.Tensor op is called on
    tensors with non-matching dims in the forked subgraph,
    and the exception raised by the subgraph is set on the future
    returned by prim::fork to the parent graph. The returned exception
    is checked for the substring expected_error_msg declared below.
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
" 362 "Exception raised by forked runtime execution does " 363 f'not contain expected substring: "{expected_error_msg}"' 364 ) from error 365 366 def test_multihead_attention_layer(self): 367 HID_DIM = 256 368 QUERY_LEN = 8 369 BATCH_SIZE = 128 370 LAYERS = 3 371 HEADS = 8 372 DROPOUT = 0.1 373 device = torch.device("cpu") 374 attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) 375 with torch.no_grad(): 376 src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) 377 src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) 378 379 attention.eval() 380 attention = torch.jit.script(attention) 381 attention.eval() 382 o_ref = attention(src, src, src, src_mask) 383 384 attention_a = StaticModule(attention) 385 o_test = attention_a(src, src, src, src_mask) 386 o_test_kw = attention_a(src, src, value=src, mask=src_mask) 387 388 for a, b in zip(o_ref, o_test): 389 torch.testing.assert_close(a, b) 390 391 for a, b in zip(o_ref, o_test_kw): 392 torch.testing.assert_close(a, b) 393 394 def test_multihead_attention_layer_benchmark(self): 395 HID_DIM = 256 396 QUERY_LEN = 8 397 BATCH_SIZE = 128 398 LAYERS = 3 399 HEADS = 8 400 DROPOUT = 0.1 401 device = torch.device("cpu") 402 attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) 403 with torch.no_grad(): 404 src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) 405 src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) 406 407 attention.eval() 408 attention = torch.jit.script(attention) 409 attention_a = StaticModule(attention) 410 411 attention_a.benchmark([src, src, src, src_mask], {}, 2, 2) 412 metrics = attention_a.benchmark_individual_ops( 413 [src, src, src, src_mask], {}, 2, 2 414 ) 415 416 def test_mlp(self): 417 # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh 418 ln_bot = [512, 512, 64] 419 sigmoid_bot = -1 420 ln_top = [100, 1024, 1024, 1024, 1] 421 sigmoid_top = 3 422 bot_l = create_mlp(ln_bot, sigmoid_bot) 423 bot_l_acc = StaticModule(bot_l) 424 top_l = create_mlp(ln_top, sigmoid_top) 425 top_l_acc = StaticModule(top_l) 426 with torch.no_grad(): 427 bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) 428 top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) 429 ref_bot = bot_l(bot_inp) 430 acc_bot = bot_l_acc(bot_inp) 431 torch.testing.assert_close(acc_bot, ref_bot) 432 ref_top = top_l(top_inp) 433 acc_top = top_l_acc(top_inp) 434 torch.testing.assert_close(acc_top, ref_top) 435 for _ in range(5): 436 with torch.no_grad(): 437 bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) 438 top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) 439 ref_bot = bot_l(bot_inp) 440 acc_bot = bot_l_acc(bot_inp) 441 torch.testing.assert_close(acc_bot, ref_bot) 442 ref_top = top_l(top_inp) 443 acc_top = top_l_acc(top_inp) 444 torch.testing.assert_close(acc_top, ref_top) 445 446 def test_trivial_graph(self): 447 s = torch.full((2, 2), 2) 448 tg = torch.jit.script(trivial_graph) 449 o_ref = tg(s, s, s) 450 tg_a = StaticModule(tg) 451 o_test = tg_a(s, s, s) 452 torch.testing.assert_close(o_ref, o_test) 453 454 def test_leaky_relu(self): 455 s = torch.randn(5, 5) 456 tg = torch.jit.script(nn.LeakyReLU(0.1)) 457 o_ref = tg(s) 458 tg_a = StaticModule(tg) 459 o_test = tg_a(s) 460 torch.testing.assert_close(o_ref, o_test) 461 462 def test_attr(self): 463 """ 464 TorchScript IR of TestModule() after freezing: 465 graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule, 466 %x.1 : Tensor): 467 %18 : int = 
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

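        # Script and freeze the module before wrapping it in StaticModule;
        # freezing inlines module attributes into the graph as constants.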
        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1,))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    run_tests()