1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Tuple 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 11*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 17*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 18*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 19*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 20*da0073e9SAndroid Build Coastguard Worker "instead." 21*da0073e9SAndroid Build Coastguard Worker ) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerclass TestHash(JitTestCase): 25*da0073e9SAndroid Build Coastguard Worker def test_hash_tuple(self): 26*da0073e9SAndroid Build Coastguard Worker def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool: 27*da0073e9SAndroid Build Coastguard Worker return hash(t1) == hash(t2) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ((1, 2), (1, 2))) 30*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ((1, 2), (3, 4))) 31*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ((1, 2), (2, 1))) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def test_hash_tuple_nested_unhashable_type(self): 34*da0073e9SAndroid Build Coastguard Worker # Tuples may contain unhashable types like `list`, check that we error 35*da0073e9SAndroid Build Coastguard Worker # properly in that case. 36*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 37*da0073e9SAndroid Build Coastguard Worker def fn_unhashable(t1: Tuple[int, List[int]]): 38*da0073e9SAndroid Build Coastguard Worker return hash(t1) 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"): 41*da0073e9SAndroid Build Coastguard Worker fn_unhashable((1, [1])) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker def test_hash_tensor(self): 44*da0073e9SAndroid Build Coastguard Worker """Tensors should hash by identity""" 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker def fn(t1, t2): 47*da0073e9SAndroid Build Coastguard Worker return hash(t1) == hash(t2) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker tensor1 = torch.tensor(1) 50*da0073e9SAndroid Build Coastguard Worker tensor1_clone = torch.tensor(1) 51*da0073e9SAndroid Build Coastguard Worker tensor2 = torch.tensor(2) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (tensor1, tensor1)) 54*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (tensor1, tensor1_clone)) 55*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (tensor1, tensor2)) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker def test_hash_none(self): 58*da0073e9SAndroid Build Coastguard Worker def fn(): 59*da0073e9SAndroid Build Coastguard Worker n1 = None 60*da0073e9SAndroid Build Coastguard Worker n2 = None 61*da0073e9SAndroid Build Coastguard Worker return hash(n1) == hash(n2) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_hash_bool(self): 66*da0073e9SAndroid Build Coastguard Worker def fn(b1: bool, b2: bool): 67*da0073e9SAndroid Build Coastguard Worker return hash(b1) == hash(b2) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (True, False)) 70*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (True, True)) 71*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (False, True)) 72*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (False, False)) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def test_hash_float(self): 75*da0073e9SAndroid Build Coastguard Worker def fn(f1: float, f2: float): 76*da0073e9SAndroid Build Coastguard Worker return hash(f1) == hash(f2) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1.2345, 1.2345)) 79*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1.2345, 6.789)) 80*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1.2345, float("inf"))) 81*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (float("inf"), float("inf"))) 82*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1.2345, float("nan"))) 83*da0073e9SAndroid Build Coastguard Worker if sys.version_info < (3, 10): 84*da0073e9SAndroid Build Coastguard Worker # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html : 85*da0073e9SAndroid Build Coastguard Worker # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity. 86*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (float("nan"), float("nan"))) 87*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (float("nan"), float("inf"))) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker def test_hash_int(self): 90*da0073e9SAndroid Build Coastguard Worker def fn(i1: int, i2: int): 91*da0073e9SAndroid Build Coastguard Worker return hash(i1) == hash(i2) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (123, 456)) 94*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (123, 123)) 95*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (123, -123)) 96*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (-123, -123)) 97*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (123, 0)) 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker def test_hash_string(self): 100*da0073e9SAndroid Build Coastguard Worker def fn(s1: str, s2: str): 101*da0073e9SAndroid Build Coastguard Worker return hash(s1) == hash(s2) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("foo", "foo")) 104*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("foo", "bar")) 105*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("foo", "")) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_hash_device(self): 108*da0073e9SAndroid Build Coastguard Worker def fn(d1: torch.device, d2: torch.device): 109*da0073e9SAndroid Build Coastguard Worker return hash(d1) == hash(d2) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker gpu0 = torch.device("cuda:0") 112*da0073e9SAndroid Build Coastguard Worker gpu1 = torch.device("cuda:1") 113*da0073e9SAndroid Build Coastguard Worker cpu = torch.device("cpu") 114*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (gpu0, gpu0)) 115*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (gpu0, gpu1)) 116*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (gpu0, cpu)) 117*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (cpu, cpu)) 118