# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestMisc(JitTestCase):
    """Assorted TorchScript frontend / runtime regression tests."""

    def test_joined_str(self):
        """f-strings (ast.JoinedStr) must behave the same in eager and script."""

        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = "hi"
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4.0, requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            # Run the *scripted* function; calling `func` here would only
            # compare eager output against itself.
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
        ):

            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass

            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(
            RuntimeError, "missing value for argument 'n_tokens'"
        ):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, "hello")

        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))

    def test_tuple_subscripted_assign(self):
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):

            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):

            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]

    def test_subexpression_List_Future(self):
        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)

    def test_subexpression_Future_annotate(self):
        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        @torch.jit.script
        def fn(
            x: Tuple[int, int, torch.jit.Future[int]]
        ) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
            fn.graph
        )

    def test_subexpression_Dict_int_Future(self):
        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)

    def test_subexpression_Optional(self):
        @torch.jit.script
        def fn(
            x: Optional[Dict[int, torch.jit.Future[int]]]
        ) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different
        types early from each branch when the return
        type of the function is Any.
        """

        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))

    def test_hacked_twin(self):
        def gen_data():
            # Frozen RNG so that two calls produce identical tensors.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        (
            input,
            index,
            value,
        ) = gen_data()
        (
            input1,
            index1,
            value1,
        ) = gen_data()
        out1 = torch.ops.aten.index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        torch.ops.aten.index_put_.hacked_twin(input, [index], value, accumulate=False)
        torch.index_put_(input1, [index1], value1, accumulate=False)
        self.assertEqual(input, input1)

    def test_unsafe_hacked_twin(self):
        """The `_unsafe_*` hacked-twin overloads must match their safe
        counterparts, both when called eagerly and through TorchScript."""

        def gen_data():
            # Frozen RNG so that every call produces identical tensors.
            with freeze_rng_state():
                return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

        (
            input,
            index,
            value,
        ) = gen_data()
        (
            input1,
            index1,
            value1,
        ) = gen_data()
        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
            input, [index], value, accumulate=False
        )
        out2 = torch.index_put(input1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # Compare the indexing hacked twin against regular advanced indexing
        # (previously both results were discarded and only the untouched
        # inputs were compared).
        out3 = torch.ops.aten._unsafe_index.Tensor_hacked_twin(input, [index])
        out4 = input1[index1]
        self.assertEqual(out3, out4)

        def index_put_fn(input, index, value):
            return torch.ops.aten._unsafe_index_put(
                input, [index], value, accumulate=False
            )

        input2, index2, value2 = gen_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(input2.clone(), index2, value2)
        actual = script_index_put_fn(input2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        # Scripting this passes a List[Tensor] where the schema expects
        # Tensor?[], so it should resolve to `_unsafe_index.Tensor_hacked_twin`.
        def index_fn(input, index):
            return torch.ops.aten._unsafe_index(input, [index])

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(input2.clone(), index2)
        actual = script_index_fn(input2.clone(), index2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        # The interface must be globally resolvable for the `sub` annotation.
        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self) -> None:
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())
        self.assertTrue(
            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

        # Swapping the submodule must change the exported op set accordingly.
        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertTrue(
            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

    def test_math_inf(self):
        from math import inf

        def foo():
            return inf

        self.checkScript(foo, ())

    def test_list_literal_infer(self):
        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        @torch.jit.script
        def test_return():
            return []

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager = empty_overload()
        # Named `scripted` rather than `jit` to avoid shadowing `torch.jit`.
        scripted = torch.jit.script(empty_overload)()
        # Contents are uninitialized; overwrite before comparing.
        eager[:] = 1
        scripted[:] = 1
        self.assertEqual(eager, scripted)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(
            RuntimeError, "multiple positional arguments that were not all integers"
        ):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        # A scalar argument is broadcast to a 2-element list.
        self.assertEqual(torch.jit.script(sum_i)(4), 8)
        self.assertEqual(torch.jit.script(sum_f)(4.5), 9.0)

    def test_parse_ir_annotate(self):
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret, [])

    def test_parse_ir_single_element_tensor_positive(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_parse_ir_single_element_tensor_negative(self):
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_script_many_decorators(self):
        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p**gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")

        ref = fn(x)

        script_fn = torch.jit.script(fn)
        # Presumably repeated calls let the profiling executor specialize the
        # graph; only the final result is checked.
        for _ in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertNotIn("complex", str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        op, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [
            i for i, name in enumerate(override_names) if name == "complex"
        ]
        Scalar_indices = [
            i for i, name in enumerate(override_names) if name == "Scalar"
        ]

        self.assertGreater(len(complex_indices), 0)
        self.assertGreater(len(Scalar_indices), 0)
        self.assertGreater(complex_indices[0], Scalar_indices[0])