xref: /aosp_15_r20/external/pytorch/test/distributions/test_constraints.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: distributions"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport pytest
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import biject_to, constraints, transform_to
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA
8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
# Cases for test_constraint, each a (constraint, expected, value) tuple:
# test_constraint asserts that ``constraint.check(value).all()`` equals
# ``expected``. The cases are generated per matrix (or batch of matrices),
# always in the order symmetric, positive_semidefinite, positive_definite.
EXAMPLES = [
    (constraint, expected, value)
    for value, expected_flags in (
        # Not symmetric, so none of the three constraints hold.
        ([[2.0, 0], [2.0, 2]], (False, False, False)),
        # Symmetric with eigenvalues 8 and -2: indefinite.
        ([[3.0, -5], [-5.0, 3]], (True, False, False)),
        # Symmetric and singular (eigenvalues 0 and 5): semidefinite only.
        ([[1.0, 2], [2.0, 4]], (True, True, False)),
        # Batch of symmetric matrices, each with eigenvalue -1.
        ([[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], (True, False, False)),
        # Batch of singular symmetric matrices (smallest eigenvalue 0).
        ([[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], (True, True, False)),
        # Batch of strictly positive-definite matrices.
        ([[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], (True, True, True)),
    )
    for constraint, expected in zip(
        (
            constraints.symmetric,
            constraints.positive_semidefinite,
            constraints.positive_definite,
        ),
        expected_flags,
    )
]
55*da0073e9SAndroid Build Coastguard Worker
# Constraint specs for test_biject_to / test_transform_to. Element 0 is the
# constraint (or constraint factory); any remaining elements are arguments
# passed through build_constraint, which turns list arguments into tensors.
CONSTRAINTS = [
    # Parameter-free constraints are listed as 1-tuples.
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    # One-sided bounds: a batched (list) bound first, then scalar bounds.
    (constraints.greater_than, [-10.0, -2, 0, 2, 10]),
    *((constraints.greater_than, bound) for bound in (0, 2, -2)),
    *((constraints.greater_than_eq, bound) for bound in (0, 2, -2)),
    (constraints.less_than, [-10.0, -2, 0, 2, 10]),
    *((constraints.less_than, bound) for bound in (0, 2, -2)),
    (constraints.unit_interval,),
    # Two-sided bounds: batched (list) bounds first, then scalar pairs.
    (constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    *((constraints.interval, low, high) for low, high in ((-2, -1), (1, 2))),
    (constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    *(
        (constraints.half_open_interval, low, high)
        for low, high in ((-2, -1), (1, 2))
    ),
    # Structured (simplex / matrix) constraints.
    (constraints.simplex,),
    (constraints.corr_cholesky,),
    (constraints.lower_cholesky,),
    (constraints.positive_definite,),
]
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
def build_constraint(constraint_fn, args, is_cuda=False):
    """Instantiate a constraint from a CONSTRAINTS-style spec.

    Args:
        constraint_fn: Returned unchanged when ``args`` is empty; otherwise
            called with the converted ``args``.
        args: Positional arguments for ``constraint_fn``. Each ``list``
            argument is converted to a float64 tensor; every other argument
            (e.g. a Python scalar bound) is passed through as-is.
        is_cuda: If True, converted tensors are placed on the current CUDA
            device instead of the CPU.

    Returns:
        ``constraint_fn`` itself, or the result of calling it.
    """
    if not args:
        return constraint_fn
    # torch.tensor(..., dtype=, device=) replaces the legacy
    # torch.DoubleTensor / torch.cuda.DoubleTensor constructors, which newer
    # PyTorch releases discourage.
    device = "cuda" if is_cuda else "cpu"
    return constraint_fn(
        *(
            torch.tensor(arg, dtype=torch.double, device=device)
            if isinstance(arg, list)
            else arg
            for arg in args
        )
    )
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_constraint(constraint_fn, result, value, is_cuda):
    """Check that ``constraint_fn.check`` accepts or rejects ``value``.

    ``value`` is converted to a float64 tensor (on CUDA when ``is_cuda``), and
    the check output is reduced with ``.all()`` before comparison with the
    expected boolean ``result``.
    """
    # torch.tensor(..., dtype=, device=) replaces the legacy
    # torch.DoubleTensor / torch.cuda.DoubleTensor constructors.
    device = "cuda" if is_cuda else "cpu"
    tensor = torch.tensor(value, dtype=torch.double, device=device)
    # bool() gives a plain True/False in assertion reports instead of a
    # 0-dim tensor.
    assert bool(constraint_fn.check(tensor).all()) == result
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_biject_to(constraint_fn, args, is_cuda):
    """Round-trip random inputs through ``biject_to(constraint)``.

    Asserts that the transform is flagged bijective, that its outputs satisfy
    the constraint, that ``inv`` recovers the input, and that
    ``log_abs_det_jacobian`` drops exactly ``domain.event_dim`` trailing
    dimensions. Skipped when ``biject_to`` raises NotImplementedError.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        bijection = biject_to(constraint)
    except NotImplementedError:
        pytest.skip("`biject_to` not implemented.")
    assert bijection.bijective, f"biject_to({constraint}) is not bijective"

    # corr_cholesky's unconstrained last dim is D * (D - 1) / 2; 6 matches
    # D = 4. Every other constraint gets a 5x5 input.
    side = 6 if constraint_fn is constraints.corr_cholesky else 5
    unconstrained = torch.randn(side, side, dtype=torch.double)
    if is_cuda:
        unconstrained = unconstrained.cuda()

    constrained = bijection(unconstrained)
    # The message is only formatted if the assertion fails.
    assert constraint.check(constrained).all(), "\n".join(
        [
            f"Failed to biject_to({constraint})",
            f"x = {unconstrained}",
            f"biject_to(...)(x) = {constrained}",
        ]
    )

    recovered = bijection.inv(constrained)
    assert torch.allclose(
        unconstrained, recovered
    ), f"Error in biject_to({constraint}) inverse"

    log_det = bijection.log_abs_det_jacobian(unconstrained, constrained)
    kept_dims = unconstrained.dim() - bijection.domain.event_dim
    assert log_det.shape == unconstrained.shape[:kept_dims]
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_transform_to(constraint_fn, args, is_cuda):
    """Check that ``transform_to(constraint)`` maps into the constraint.

    Unlike test_biject_to, bijectivity is not asserted. Instead, the
    pseudo-inverse round trip ``t(t.inv(y))`` must reproduce ``y``.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    transform = transform_to(constraint)

    # corr_cholesky's unconstrained last dim is D * (D - 1) / 2; 6 matches
    # D = 4. Every other constraint gets a 5x5 input.
    side = 6 if constraint_fn is constraints.corr_cholesky else 5
    x = torch.randn(side, side, dtype=torch.double)
    if is_cuda:
        x = x.cuda()

    y = transform(x)
    assert constraint.check(y).all(), f"Failed to transform_to({constraint})"

    y_roundtrip = transform(transform.inv(y))
    assert torch.allclose(
        y, y_roundtrip
    ), f"Error in transform_to({constraint}) pseudoinverse"
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: hands off to run_tests() from
    # torch.testing._internal.common_utils.
    # NOTE(review): the tests in this module are pytest-style functions;
    # confirm that run_tests() collects them when the file is executed
    # directly rather than through pytest.
    run_tests()
179