xref: /aosp_15_r20/external/pytorch/test/jit/test_hash.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Tuple
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
# Put the test/ root directory (two levels above this file's resolved
# location, i.e. the parent of test/jit/) on sys.path so that helper files
# kept in test/ can be imported.
pytorch_test_dir = os.path.split(os.path.split(os.path.realpath(__file__))[0])[0]
sys.path.append(pytorch_test_dir)
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
# This module only defines test cases; executing it as a script is an error.
# The cases are meant to be selected and run through the test/test_jit.py
# driver instead.
if __name__ == "__main__":
    raise RuntimeError(
        "".join(
            [
                "This test file is not meant to be run directly, use:\n\n",
                "\tpython test/test_jit.py TESTNAME\n\n",
                "instead.",
            ]
        )
    )
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerclass TestHash(JitTestCase):
25*da0073e9SAndroid Build Coastguard Worker    def test_hash_tuple(self):
26*da0073e9SAndroid Build Coastguard Worker        def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
27*da0073e9SAndroid Build Coastguard Worker            return hash(t1) == hash(t2)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ((1, 2), (1, 2)))
30*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ((1, 2), (3, 4)))
31*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ((1, 2), (2, 1)))
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_hash_tuple_nested_unhashable_type(self):
34*da0073e9SAndroid Build Coastguard Worker        # Tuples may contain unhashable types like `list`, check that we error
35*da0073e9SAndroid Build Coastguard Worker        # properly in that case.
36*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
37*da0073e9SAndroid Build Coastguard Worker        def fn_unhashable(t1: Tuple[int, List[int]]):
38*da0073e9SAndroid Build Coastguard Worker            return hash(t1)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
41*da0073e9SAndroid Build Coastguard Worker            fn_unhashable((1, [1]))
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker    def test_hash_tensor(self):
44*da0073e9SAndroid Build Coastguard Worker        """Tensors should hash by identity"""
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        def fn(t1, t2):
47*da0073e9SAndroid Build Coastguard Worker            return hash(t1) == hash(t2)
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        tensor1 = torch.tensor(1)
50*da0073e9SAndroid Build Coastguard Worker        tensor1_clone = torch.tensor(1)
51*da0073e9SAndroid Build Coastguard Worker        tensor2 = torch.tensor(2)
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (tensor1, tensor1))
54*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (tensor1, tensor1_clone))
55*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (tensor1, tensor2))
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker    def test_hash_none(self):
58*da0073e9SAndroid Build Coastguard Worker        def fn():
59*da0073e9SAndroid Build Coastguard Worker            n1 = None
60*da0073e9SAndroid Build Coastguard Worker            n2 = None
61*da0073e9SAndroid Build Coastguard Worker            return hash(n1) == hash(n2)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker    def test_hash_bool(self):
66*da0073e9SAndroid Build Coastguard Worker        def fn(b1: bool, b2: bool):
67*da0073e9SAndroid Build Coastguard Worker            return hash(b1) == hash(b2)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (True, False))
70*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (True, True))
71*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (False, True))
72*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (False, False))
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def test_hash_float(self):
75*da0073e9SAndroid Build Coastguard Worker        def fn(f1: float, f2: float):
76*da0073e9SAndroid Build Coastguard Worker            return hash(f1) == hash(f2)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1.2345, 1.2345))
79*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1.2345, 6.789))
80*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1.2345, float("inf")))
81*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (float("inf"), float("inf")))
82*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1.2345, float("nan")))
83*da0073e9SAndroid Build Coastguard Worker        if sys.version_info < (3, 10):
84*da0073e9SAndroid Build Coastguard Worker            # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
85*da0073e9SAndroid Build Coastguard Worker            # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
86*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn, (float("nan"), float("nan")))
87*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (float("nan"), float("inf")))
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker    def test_hash_int(self):
90*da0073e9SAndroid Build Coastguard Worker        def fn(i1: int, i2: int):
91*da0073e9SAndroid Build Coastguard Worker            return hash(i1) == hash(i2)
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (123, 456))
94*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (123, 123))
95*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (123, -123))
96*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (-123, -123))
97*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (123, 0))
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker    def test_hash_string(self):
100*da0073e9SAndroid Build Coastguard Worker        def fn(s1: str, s2: str):
101*da0073e9SAndroid Build Coastguard Worker            return hash(s1) == hash(s2)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("foo", "foo"))
104*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("foo", "bar"))
105*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("foo", ""))
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker    def test_hash_device(self):
108*da0073e9SAndroid Build Coastguard Worker        def fn(d1: torch.device, d2: torch.device):
109*da0073e9SAndroid Build Coastguard Worker            return hash(d1) == hash(d2)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker        gpu0 = torch.device("cuda:0")
112*da0073e9SAndroid Build Coastguard Worker        gpu1 = torch.device("cuda:1")
113*da0073e9SAndroid Build Coastguard Worker        cpu = torch.device("cpu")
114*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (gpu0, gpu0))
115*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (gpu0, gpu1))
116*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (gpu0, cpu))
117*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (cpu, cpu))
118