# Owner(s): ["oncall: jit"] import os import sys from typing import List, Tuple import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestHash(JitTestCase): def test_hash_tuple(self): def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool: return hash(t1) == hash(t2) self.checkScript(fn, ((1, 2), (1, 2))) self.checkScript(fn, ((1, 2), (3, 4))) self.checkScript(fn, ((1, 2), (2, 1))) def test_hash_tuple_nested_unhashable_type(self): # Tuples may contain unhashable types like `list`, check that we error # properly in that case. @torch.jit.script def fn_unhashable(t1: Tuple[int, List[int]]): return hash(t1) with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"): fn_unhashable((1, [1])) def test_hash_tensor(self): """Tensors should hash by identity""" def fn(t1, t2): return hash(t1) == hash(t2) tensor1 = torch.tensor(1) tensor1_clone = torch.tensor(1) tensor2 = torch.tensor(2) self.checkScript(fn, (tensor1, tensor1)) self.checkScript(fn, (tensor1, tensor1_clone)) self.checkScript(fn, (tensor1, tensor2)) def test_hash_none(self): def fn(): n1 = None n2 = None return hash(n1) == hash(n2) self.checkScript(fn, ()) def test_hash_bool(self): def fn(b1: bool, b2: bool): return hash(b1) == hash(b2) self.checkScript(fn, (True, False)) self.checkScript(fn, (True, True)) self.checkScript(fn, (False, True)) self.checkScript(fn, (False, False)) def test_hash_float(self): def fn(f1: float, f2: float): return hash(f1) == hash(f2) self.checkScript(fn, (1.2345, 1.2345)) self.checkScript(fn, (1.2345, 6.789)) self.checkScript(fn, (1.2345, float("inf"))) self.checkScript(fn, (float("inf"), float("inf"))) self.checkScript(fn, (1.2345, float("nan"))) if sys.version_info < (3, 10): # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html : # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity. self.checkScript(fn, (float("nan"), float("nan"))) self.checkScript(fn, (float("nan"), float("inf"))) def test_hash_int(self): def fn(i1: int, i2: int): return hash(i1) == hash(i2) self.checkScript(fn, (123, 456)) self.checkScript(fn, (123, 123)) self.checkScript(fn, (123, -123)) self.checkScript(fn, (-123, -123)) self.checkScript(fn, (123, 0)) def test_hash_string(self): def fn(s1: str, s2: str): return hash(s1) == hash(s2) self.checkScript(fn, ("foo", "foo")) self.checkScript(fn, ("foo", "bar")) self.checkScript(fn, ("foo", "")) def test_hash_device(self): def fn(d1: torch.device, d2: torch.device): return hash(d1) == hash(d2) gpu0 = torch.device("cuda:0") gpu1 = torch.device("cuda:1") cpu = torch.device("cpu") self.checkScript(fn, (gpu0, gpu0)) self.checkScript(fn, (gpu0, gpu1)) self.checkScript(fn, (gpu0, cpu)) self.checkScript(fn, (cpu, cpu))