# Owner(s): ["oncall: jit"] import os import sys import warnings from typing import Any, Dict, List, Optional, Tuple import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) # Tests for torch.jit.isinstance class TestIsinstance(JitTestCase): def test_int(self): def int_test(x: Any): assert torch.jit.isinstance(x, int) assert not torch.jit.isinstance(x, float) x = 1 self.checkScript(int_test, (x,)) def test_float(self): def float_test(x: Any): assert torch.jit.isinstance(x, float) assert not torch.jit.isinstance(x, int) x = 1.0 self.checkScript(float_test, (x,)) def test_bool(self): def bool_test(x: Any): assert torch.jit.isinstance(x, bool) assert not torch.jit.isinstance(x, float) x = False self.checkScript(bool_test, (x,)) def test_list(self): def list_str_test(x: Any): assert torch.jit.isinstance(x, List[str]) assert not torch.jit.isinstance(x, List[int]) assert not torch.jit.isinstance(x, Tuple[int]) x = ["1", "2", "3"] self.checkScript(list_str_test, (x,)) def test_list_tensor(self): def list_tensor_test(x: Any): assert torch.jit.isinstance(x, List[torch.Tensor]) assert not torch.jit.isinstance(x, Tuple[int]) x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])] self.checkScript(list_tensor_test, (x,)) def test_dict(self): def dict_str_int_test(x: Any): assert torch.jit.isinstance(x, Dict[str, int]) assert not torch.jit.isinstance(x, Dict[int, str]) assert not torch.jit.isinstance(x, Dict[str, str]) x = {"a": 1, "b": 2} self.checkScript(dict_str_int_test, (x,)) def test_dict_tensor(self): def dict_int_tensor_test(x: Any): assert torch.jit.isinstance(x, Dict[int, torch.Tensor]) x = {2: torch.tensor([2])} self.checkScript(dict_int_tensor_test, (x,)) def test_tuple(self): def tuple_test(x: Any): assert torch.jit.isinstance(x, Tuple[str, int, str]) assert not torch.jit.isinstance(x, Tuple[int, str, str]) assert not torch.jit.isinstance(x, Tuple[str]) x = ("a", 1, "b") self.checkScript(tuple_test, (x,)) def test_tuple_tensor(self): def tuple_tensor_test(x: Any): assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]) x = (torch.tensor([1]), torch.tensor([[2], [3]])) self.checkScript(tuple_tensor_test, (x,)) def test_optional(self): def optional_test(x: Any): assert torch.jit.isinstance(x, Optional[torch.Tensor]) assert not torch.jit.isinstance(x, Optional[str]) x = torch.ones(3, 3) self.checkScript(optional_test, (x,)) def test_optional_none(self): def optional_test_none(x: Any): assert torch.jit.isinstance(x, Optional[torch.Tensor]) # assert torch.jit.isinstance(x, Optional[str]) # TODO: above line in eager will evaluate to True while in # the TS interpreter will evaluate to False as the # first torch.jit.isinstance refines the 'None' type x = None self.checkScript(optional_test_none, (x,)) def test_list_nested(self): def list_nested(x: Any): assert torch.jit.isinstance(x, List[Dict[str, int]]) assert not torch.jit.isinstance(x, List[List[str]]) x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] self.checkScript(list_nested, (x,)) def test_dict_nested(self): def dict_nested(x: Any): assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} self.checkScript(dict_nested, (x,)) def test_tuple_nested(self): def tuple_nested(x: Any): assert torch.jit.isinstance( x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] ) assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) assert not torch.jit.isinstance(x, Tuple[str]) assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]]) x = ( {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, [True, False, True], None, ) self.checkScript(tuple_nested, (x,)) def test_optional_nested(self): def optional_nested(x: Any): assert torch.jit.isinstance(x, Optional[List[str]]) x = ["a", "b", "c"] self.checkScript(optional_nested, (x,)) def test_list_tensor_type_true(self): def list_tensor_type_true(x: Any): assert torch.jit.isinstance(x, List[torch.Tensor]) x = [torch.rand(3, 3), torch.rand(4, 3)] self.checkScript(list_tensor_type_true, (x,)) def test_tensor_type_false(self): def list_tensor_type_false(x: Any): assert not torch.jit.isinstance(x, List[torch.Tensor]) x = [1, 2, 3] self.checkScript(list_tensor_type_false, (x,)) def test_in_if(self): def list_in_if(x: Any): if torch.jit.isinstance(x, List[int]): assert True if torch.jit.isinstance(x, List[str]): assert not True x = [1, 2, 3] self.checkScript(list_in_if, (x,)) def test_if_else(self): def list_in_if_else(x: Any): if torch.jit.isinstance(x, Tuple[str, str, str]): assert True else: assert not True x = ("a", "b", "c") self.checkScript(list_in_if_else, (x,)) def test_in_while_loop(self): def list_in_while_loop(x: Any): count = 0 while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: count = count + 1 assert count == 1 x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] self.checkScript(list_in_while_loop, (x,)) def test_type_refinement(self): def type_refinement(obj: Any): hit = False if torch.jit.isinstance(obj, List[torch.Tensor]): hit = not hit for el in obj: # perform some tensor operation y = el.clamp(0, 0.5) if torch.jit.isinstance(obj, Dict[str, str]): hit = not hit str_cat = "" for val in obj.values(): str_cat = str_cat + val assert "111222" == str_cat assert hit x = [torch.rand(3, 3), torch.rand(4, 3)] self.checkScript(type_refinement, (x,)) x = {"1": "111", "2": "222"} self.checkScript(type_refinement, (x,)) def test_list_no_contained_type(self): def list_no_contained_type(x: Any): assert torch.jit.isinstance(x, List) x = ["1", "2", "3"] err_msg = ( "Attempted to use List without a contained type. " r"Please add a contained type, e.g. List\[int\]" ) with self.assertRaisesRegex( RuntimeError, err_msg, ): torch.jit.script(list_no_contained_type) with self.assertRaisesRegex( RuntimeError, err_msg, ): list_no_contained_type(x) def test_tuple_no_contained_type(self): def tuple_no_contained_type(x: Any): assert torch.jit.isinstance(x, Tuple) x = ("1", "2", "3") err_msg = ( "Attempted to use Tuple without a contained type. " r"Please add a contained type, e.g. Tuple\[int\]" ) with self.assertRaisesRegex( RuntimeError, err_msg, ): torch.jit.script(tuple_no_contained_type) with self.assertRaisesRegex( RuntimeError, err_msg, ): tuple_no_contained_type(x) def test_optional_no_contained_type(self): def optional_no_contained_type(x: Any): assert torch.jit.isinstance(x, Optional) x = ("1", "2", "3") err_msg = ( "Attempted to use Optional without a contained type. " r"Please add a contained type, e.g. Optional\[int\]" ) with self.assertRaisesRegex( RuntimeError, err_msg, ): torch.jit.script(optional_no_contained_type) with self.assertRaisesRegex( RuntimeError, err_msg, ): optional_no_contained_type(x) def test_dict_no_contained_type(self): def dict_no_contained_type(x: Any): assert torch.jit.isinstance(x, Dict) x = {"a": "aa"} err_msg = ( "Attempted to use Dict without contained types. " r"Please add contained type, e.g. Dict\[int, int\]" ) with self.assertRaisesRegex( RuntimeError, err_msg, ): torch.jit.script(dict_no_contained_type) with self.assertRaisesRegex( RuntimeError, err_msg, ): dict_no_contained_type(x) def test_tuple_rhs(self): def fn(x: Any): assert torch.jit.isinstance(x, (int, List[str])) assert not torch.jit.isinstance(x, (List[float], Tuple[int, str])) assert not torch.jit.isinstance(x, (List[float], str)) self.checkScript(fn, (2,)) self.checkScript(fn, (["foo", "bar", "baz"],)) def test_nontuple_container_rhs_throws_in_eager(self): def fn1(x: Any): assert torch.jit.isinstance(x, [int, List[str]]) def fn2(x: Any): assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]}) err_highlight = "must be a type or a tuple of types" with self.assertRaisesRegex(RuntimeError, err_highlight): fn1(2) with self.assertRaisesRegex(RuntimeError, err_highlight): fn2(2) def test_empty_container_throws_warning_in_eager(self): def fn(x: Any): torch.jit.isinstance(x, List[int]) with warnings.catch_warnings(record=True) as w: x: List[int] = [] fn(x) self.assertEqual(len(w), 1) with warnings.catch_warnings(record=True) as w: x: int = 2 fn(x) self.assertEqual(len(w), 0) def test_empty_container_special_cases(self): # Should not throw "Boolean value of Tensor with no values is # ambiguous" error torch._jit_internal.check_empty_containers(torch.Tensor([])) # Should not throw "Boolean value of Tensor with more than # one value is ambiguous" error torch._jit_internal.check_empty_containers(torch.rand(2, 3))