xref: /aosp_15_r20/external/pytorch/test/jit/test_warn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6import warnings
7from contextlib import redirect_stderr
8
9import torch
10from torch.testing import FileCheck
11
12
# Make the helper files in test/ importable. The directory appended to
# sys.path is the parent of the directory that holds this file, computed from
# the file's real (symlink-resolved) path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))  # this file's directory
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # one level up
sys.path.append(pytorch_test_dir)
16from torch.testing._internal.jit_utils import JitTestCase
17
18
# Running this module as a script is rejected outright; the error message
# points at test/test_jit.py as the intended entry point.
if __name__ == "__main__":
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
25
26
class TestWarn(JitTestCase):
    """Count the ``warnings.warn`` messages that scripted functions write to stderr.

    Every test scripts a small function with ``torch.jit.script``, calls it
    while ``sys.stderr`` is redirected into an in-memory buffer, and then uses
    ``FileCheck`` to assert the exact number of times the expected
    ``UserWarning`` text appears in the captured output.
    """

    @staticmethod
    def _expect_exact_counts(stderr_text, *expected):
        """Assert exact occurrence counts of messages in ``stderr_text``.

        Args:
            stderr_text: The captured stderr contents to inspect.
            *expected: ``(message, count)`` pairs. One ``check_count`` with
                ``exactly=True`` is chained onto a single ``FileCheck`` per
                pair, in the order given, before it is run on the text.
        """
        checker = FileCheck()
        for message, count in expected:
            checker = checker.check_count(str=message, count=count, exactly=True)
        checker.run(stderr_text)

    def test_warn(self):
        """One call to a function with a single warn: expect one warning."""

        @torch.jit.script
        def warn_once():
            warnings.warn("I am warning you")

        captured = io.StringIO()
        with redirect_stderr(captured):
            warn_once()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 1),
        )

    def test_warn_only_once(self):
        """A warn inside a 10-iteration loop: expect one warning."""

        @torch.jit.script
        def warn_in_loop():
            for _ in range(10):
                warnings.warn("I am warning you")

        captured = io.StringIO()
        with redirect_stderr(captured):
            warn_in_loop()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 1),
        )

    def test_warn_only_once_in_loop_func(self):
        """A warning helper called in a 10-iteration loop: expect one warning."""

        def emit():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_in_loop():
            for _ in range(10):
                emit()

        captured = io.StringIO()
        with redirect_stderr(captured):
            call_in_loop()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 1),
        )

    def test_warn_once_per_func(self):
        """Two distinct warning helpers called once each: expect two warnings."""

        def emit_first():
            warnings.warn("I am warning you")

        def emit_second():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_both():
            emit_first()
            emit_second()

        captured = io.StringIO()
        with redirect_stderr(captured):
            call_both()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 2),
        )

    def test_warn_once_per_func_in_loop(self):
        """Two distinct warning helpers called in a loop: expect two warnings."""

        def emit_first():
            warnings.warn("I am warning you")

        def emit_second():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_both_in_loop():
            for _ in range(10):
                emit_first()
                emit_second()

        captured = io.StringIO()
        with redirect_stderr(captured):
            call_both_in_loop()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 2),
        )

    def test_warn_multiple_calls_multiple_warnings(self):
        """The same scripted function called twice: expect two warnings."""

        @torch.jit.script
        def warn_once():
            warnings.warn("I am warning you")

        captured = io.StringIO()
        with redirect_stderr(captured):
            warn_once()
            warn_once()

        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you", 2),
        )

    def test_warn_multiple_calls_same_func_diff_stack(self):
        """One warning helper reached from two scripted callers.

        Expects the "foo" message exactly once, followed by the "bar" message
        exactly once.
        """

        def warn_from(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn_from("foo")

        @torch.jit.script
        def bar():
            warn_from("bar")

        captured = io.StringIO()
        with redirect_stderr(captured):
            foo()
            bar()

        # Both checks go onto one FileCheck so they are applied in sequence.
        self._expect_exact_counts(
            captured.getvalue(),
            ("UserWarning: I am warning you from foo", 1),
            ("UserWarning: I am warning you from bar", 1),
        )
151