# Owner(s): ["module: unknown"]

import unittest
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output


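# Monkey-patch torch.nn.functional.linear so that scripted nn.Linear modules
# compile the decomposed matmul + add shim above instead of a fused linear op.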
torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
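        # query/key/value: [batch_size, seq_len, hid_dim]; mask is accepted for
        # API compatibility but unused here (the masked_fill step is commented out).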
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

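        # Xavier/Glorot-style normal init (as in the DLRM reference above):
        # weight std = sqrt(2 / (fan_in + fan_out)), bias std = sqrt(1 / fan_out).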
        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2

def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)

def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)

"""
   graph with multiple fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

"""
   graph with multi-level fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param num_forks: number of top level forks
   :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

def add_tensor(input1, input2):
    return input1 + input2

def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)

def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class SubModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


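# SubModule2 and TestModule below mutate an attribute inside forward(); together
# they exercise prim::SetAttr / prim::GetAttr handling in Static Runtime (see
# test_attr in TestStaticModule).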
class SubModule2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)


class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph.
    fork is called on a simple addition of the input tensors.
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API, which returns a future.
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
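        # runAsync returns a future-like handle: wait() blocks until execution
        # finishes and value() retrieves the result.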
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a fork/wait operation on a loop
    subgraph performing a mix of operations.
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop subgraph
    with the StaticRuntime runAsync API, which returns a future.
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph containing multiple
    fork/wait operations.
    """
    def test_fork_wait_3(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait operations
    via the runAsync API, which returns a future.
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph with multiple
    nested fork/wait operations.
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait
    operations via the runAsync API, which returns a future.
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in a fork/wait
    operation. add.Tensor is called on tensors with
    non-matching dims in the forked subgraph, and the
    exception raised by the subgraph is set on the future that
    prim::fork returns to the parent graph. The returned exception
    is checked for the substring expected_error_msg declared below.
    """
    def test_fork_wait_exception(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    """
    Test Case: To test exception handling in a fork/wait
    operation with the runAsync API. add.Tensor is called on
    tensors with non-matching dims in the forked subgraph,
    and the exception raised by the subgraph is set on the future
    that prim::fork returns to the parent graph. The returned
    exception is checked for the substring expected_error_msg
    declared below.
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # test fails if error does not contain expected substr
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

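        # The two trailing integers are assumed to be warmup and main run
        # counts; they are kept at 2 each so the benchmark smoke test stays fast.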
        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)
        torch.testing.assert_close(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)
        torch.testing.assert_close(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):
        """
        TorchScript IR of TestModule() after freezing:
        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
        m = TestModule()

        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:  # noqa: B903
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1,))
        expected = mod(y)

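        # torch.jit.freeze inlines module attributes/parameters into the graph,
        # giving Static Runtime a self-contained graph to execute.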
        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)

if __name__ == "__main__":
    run_tests()