1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6import warnings 7from contextlib import redirect_stderr 8 9import torch 10from torch.testing import FileCheck 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16from torch.testing._internal.jit_utils import JitTestCase 17 18 19if __name__ == "__main__": 20 raise RuntimeError( 21 "This test file is not meant to be run directly, use:\n\n" 22 "\tpython test/test_jit.py TESTNAME\n\n" 23 "instead." 24 ) 25 26 27class TestWarn(JitTestCase): 28 def test_warn(self): 29 @torch.jit.script 30 def fn(): 31 warnings.warn("I am warning you") 32 33 f = io.StringIO() 34 with redirect_stderr(f): 35 fn() 36 37 FileCheck().check_count( 38 str="UserWarning: I am warning you", count=1, exactly=True 39 ).run(f.getvalue()) 40 41 def test_warn_only_once(self): 42 @torch.jit.script 43 def fn(): 44 for _ in range(10): 45 warnings.warn("I am warning you") 46 47 f = io.StringIO() 48 with redirect_stderr(f): 49 fn() 50 51 FileCheck().check_count( 52 str="UserWarning: I am warning you", count=1, exactly=True 53 ).run(f.getvalue()) 54 55 def test_warn_only_once_in_loop_func(self): 56 def w(): 57 warnings.warn("I am warning you") 58 59 @torch.jit.script 60 def fn(): 61 for _ in range(10): 62 w() 63 64 f = io.StringIO() 65 with redirect_stderr(f): 66 fn() 67 68 FileCheck().check_count( 69 str="UserWarning: I am warning you", count=1, exactly=True 70 ).run(f.getvalue()) 71 72 def test_warn_once_per_func(self): 73 def w1(): 74 warnings.warn("I am warning you") 75 76 def w2(): 77 warnings.warn("I am warning you") 78 79 @torch.jit.script 80 def fn(): 81 w1() 82 w2() 83 84 f = io.StringIO() 85 with redirect_stderr(f): 86 fn() 87 88 FileCheck().check_count( 89 str="UserWarning: I am warning you", count=2, exactly=True 90 ).run(f.getvalue()) 91 92 def test_warn_once_per_func_in_loop(self): 93 def w1(): 94 warnings.warn("I am warning you") 95 96 def w2(): 97 warnings.warn("I am warning you") 98 99 @torch.jit.script 100 def fn(): 101 for _ in range(10): 102 w1() 103 w2() 104 105 f = io.StringIO() 106 with redirect_stderr(f): 107 fn() 108 109 FileCheck().check_count( 110 str="UserWarning: I am warning you", count=2, exactly=True 111 ).run(f.getvalue()) 112 113 def test_warn_multiple_calls_multiple_warnings(self): 114 @torch.jit.script 115 def fn(): 116 warnings.warn("I am warning you") 117 118 f = io.StringIO() 119 with redirect_stderr(f): 120 fn() 121 fn() 122 123 FileCheck().check_count( 124 str="UserWarning: I am warning you", count=2, exactly=True 125 ).run(f.getvalue()) 126 127 def test_warn_multiple_calls_same_func_diff_stack(self): 128 def warn(caller: str): 129 warnings.warn("I am warning you from " + caller) 130 131 @torch.jit.script 132 def foo(): 133 warn("foo") 134 135 @torch.jit.script 136 def bar(): 137 warn("bar") 138 139 f = io.StringIO() 140 with redirect_stderr(f): 141 foo() 142 bar() 143 144 FileCheck().check_count( 145 str="UserWarning: I am warning you from foo", count=1, exactly=True 146 ).check_count( 147 str="UserWarning: I am warning you from bar", count=1, exactly=True 148 ).run( 149 f.getvalue() 150 ) 151