1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: distributions"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport pytest 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import biject_to, constraints, transform_to 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard WorkerEXAMPLES = [ 12*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, False, [[2.0, 0], [2.0, 2]]), 13*da0073e9SAndroid Build Coastguard Worker (constraints.positive_semidefinite, False, [[2.0, 0], [2.0, 2]]), 14*da0073e9SAndroid Build Coastguard Worker (constraints.positive_definite, False, [[2.0, 0], [2.0, 2]]), 15*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, True, [[3.0, -5], [-5.0, 3]]), 16*da0073e9SAndroid Build Coastguard Worker (constraints.positive_semidefinite, False, [[3.0, -5], [-5.0, 3]]), 17*da0073e9SAndroid Build Coastguard Worker (constraints.positive_definite, False, [[3.0, -5], [-5.0, 3]]), 18*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, True, [[1.0, 2], [2.0, 4]]), 19*da0073e9SAndroid Build Coastguard Worker (constraints.positive_semidefinite, True, [[1.0, 2], [2.0, 4]]), 20*da0073e9SAndroid Build Coastguard Worker (constraints.positive_definite, False, [[1.0, 2], [2.0, 4]]), 21*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, True, [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]]), 22*da0073e9SAndroid Build Coastguard Worker ( 23*da0073e9SAndroid Build Coastguard Worker constraints.positive_semidefinite, 24*da0073e9SAndroid Build Coastguard Worker False, 25*da0073e9SAndroid Build Coastguard Worker [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], 26*da0073e9SAndroid Build Coastguard Worker ), 27*da0073e9SAndroid Build Coastguard Worker ( 28*da0073e9SAndroid Build Coastguard Worker constraints.positive_definite, 29*da0073e9SAndroid Build Coastguard Worker False, 30*da0073e9SAndroid Build Coastguard Worker [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], 31*da0073e9SAndroid Build Coastguard Worker ), 32*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, True, [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]]), 33*da0073e9SAndroid Build Coastguard Worker ( 34*da0073e9SAndroid Build Coastguard Worker constraints.positive_semidefinite, 35*da0073e9SAndroid Build Coastguard Worker True, 36*da0073e9SAndroid Build Coastguard Worker [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], 37*da0073e9SAndroid Build Coastguard Worker ), 38*da0073e9SAndroid Build Coastguard Worker ( 39*da0073e9SAndroid Build Coastguard Worker constraints.positive_definite, 40*da0073e9SAndroid Build Coastguard Worker False, 41*da0073e9SAndroid Build Coastguard Worker [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], 42*da0073e9SAndroid Build Coastguard Worker ), 43*da0073e9SAndroid Build Coastguard Worker (constraints.symmetric, True, [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]]), 44*da0073e9SAndroid Build Coastguard Worker ( 45*da0073e9SAndroid Build Coastguard Worker constraints.positive_semidefinite, 46*da0073e9SAndroid Build Coastguard Worker True, 47*da0073e9SAndroid Build Coastguard Worker [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], 48*da0073e9SAndroid Build Coastguard Worker ), 49*da0073e9SAndroid Build Coastguard Worker ( 50*da0073e9SAndroid Build Coastguard Worker constraints.positive_definite, 51*da0073e9SAndroid Build Coastguard Worker True, 52*da0073e9SAndroid Build Coastguard Worker [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], 53*da0073e9SAndroid Build Coastguard Worker ), 54*da0073e9SAndroid Build Coastguard Worker] 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard WorkerCONSTRAINTS = [ 57*da0073e9SAndroid Build Coastguard Worker (constraints.real,), 58*da0073e9SAndroid Build Coastguard Worker (constraints.real_vector,), 59*da0073e9SAndroid Build Coastguard Worker (constraints.positive,), 60*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than, [-10.0, -2, 0, 2, 10]), 61*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than, 0), 62*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than, 2), 63*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than, -2), 64*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than_eq, 0), 65*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than_eq, 2), 66*da0073e9SAndroid Build Coastguard Worker (constraints.greater_than_eq, -2), 67*da0073e9SAndroid Build Coastguard Worker (constraints.less_than, [-10.0, -2, 0, 2, 10]), 68*da0073e9SAndroid Build Coastguard Worker (constraints.less_than, 0), 69*da0073e9SAndroid Build Coastguard Worker (constraints.less_than, 2), 70*da0073e9SAndroid Build Coastguard Worker (constraints.less_than, -2), 71*da0073e9SAndroid Build Coastguard Worker (constraints.unit_interval,), 72*da0073e9SAndroid Build Coastguard Worker (constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]), 73*da0073e9SAndroid Build Coastguard Worker (constraints.interval, -2, -1), 74*da0073e9SAndroid Build Coastguard Worker (constraints.interval, 1, 2), 75*da0073e9SAndroid Build Coastguard Worker (constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]), 76*da0073e9SAndroid Build Coastguard Worker (constraints.half_open_interval, -2, -1), 77*da0073e9SAndroid Build Coastguard Worker (constraints.half_open_interval, 1, 2), 78*da0073e9SAndroid Build Coastguard Worker (constraints.simplex,), 79*da0073e9SAndroid Build Coastguard Worker (constraints.corr_cholesky,), 80*da0073e9SAndroid Build Coastguard Worker (constraints.lower_cholesky,), 81*da0073e9SAndroid Build Coastguard Worker (constraints.positive_definite,), 82*da0073e9SAndroid Build Coastguard Worker] 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Workerdef build_constraint(constraint_fn, args, is_cuda=False): 86*da0073e9SAndroid Build Coastguard Worker if not args: 87*da0073e9SAndroid Build Coastguard Worker return constraint_fn 88*da0073e9SAndroid Build Coastguard Worker t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor 89*da0073e9SAndroid Build Coastguard Worker return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args)) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES) 93*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 94*da0073e9SAndroid Build Coastguard Worker "is_cuda", 95*da0073e9SAndroid Build Coastguard Worker [ 96*da0073e9SAndroid Build Coastguard Worker False, 97*da0073e9SAndroid Build Coastguard Worker pytest.param( 98*da0073e9SAndroid Build Coastguard Worker True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 99*da0073e9SAndroid Build Coastguard Worker ), 100*da0073e9SAndroid Build Coastguard Worker ], 101*da0073e9SAndroid Build Coastguard Worker) 102*da0073e9SAndroid Build Coastguard Workerdef test_constraint(constraint_fn, result, value, is_cuda): 103*da0073e9SAndroid Build Coastguard Worker t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor 104*da0073e9SAndroid Build Coastguard Worker assert constraint_fn.check(t(value)).all() == result 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 108*da0073e9SAndroid Build Coastguard Worker ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS] 109*da0073e9SAndroid Build Coastguard Worker) 110*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 111*da0073e9SAndroid Build Coastguard Worker "is_cuda", 112*da0073e9SAndroid Build Coastguard Worker [ 113*da0073e9SAndroid Build Coastguard Worker False, 114*da0073e9SAndroid Build Coastguard Worker pytest.param( 115*da0073e9SAndroid Build Coastguard Worker True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 116*da0073e9SAndroid Build Coastguard Worker ), 117*da0073e9SAndroid Build Coastguard Worker ], 118*da0073e9SAndroid Build Coastguard Worker) 119*da0073e9SAndroid Build Coastguard Workerdef test_biject_to(constraint_fn, args, is_cuda): 120*da0073e9SAndroid Build Coastguard Worker constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) 121*da0073e9SAndroid Build Coastguard Worker try: 122*da0073e9SAndroid Build Coastguard Worker t = biject_to(constraint) 123*da0073e9SAndroid Build Coastguard Worker except NotImplementedError: 124*da0073e9SAndroid Build Coastguard Worker pytest.skip("`biject_to` not implemented.") 125*da0073e9SAndroid Build Coastguard Worker assert t.bijective, f"biject_to({constraint}) is not bijective" 126*da0073e9SAndroid Build Coastguard Worker if constraint_fn is constraints.corr_cholesky: 127*da0073e9SAndroid Build Coastguard Worker # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) 128*da0073e9SAndroid Build Coastguard Worker x = torch.randn(6, 6, dtype=torch.double) 129*da0073e9SAndroid Build Coastguard Worker else: 130*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.double) 131*da0073e9SAndroid Build Coastguard Worker if is_cuda: 132*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 133*da0073e9SAndroid Build Coastguard Worker y = t(x) 134*da0073e9SAndroid Build Coastguard Worker assert constraint.check(y).all(), "\n".join( 135*da0073e9SAndroid Build Coastguard Worker [ 136*da0073e9SAndroid Build Coastguard Worker f"Failed to biject_to({constraint})", 137*da0073e9SAndroid Build Coastguard Worker f"x = {x}", 138*da0073e9SAndroid Build Coastguard Worker f"biject_to(...)(x) = {y}", 139*da0073e9SAndroid Build Coastguard Worker ] 140*da0073e9SAndroid Build Coastguard Worker ) 141*da0073e9SAndroid Build Coastguard Worker x2 = t.inv(y) 142*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse" 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker j = t.log_abs_det_jacobian(x, y) 145*da0073e9SAndroid Build Coastguard Worker assert j.shape == x.shape[: x.dim() - t.domain.event_dim] 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 149*da0073e9SAndroid Build Coastguard Worker ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS] 150*da0073e9SAndroid Build Coastguard Worker) 151*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 152*da0073e9SAndroid Build Coastguard Worker "is_cuda", 153*da0073e9SAndroid Build Coastguard Worker [ 154*da0073e9SAndroid Build Coastguard Worker False, 155*da0073e9SAndroid Build Coastguard Worker pytest.param( 156*da0073e9SAndroid Build Coastguard Worker True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 157*da0073e9SAndroid Build Coastguard Worker ), 158*da0073e9SAndroid Build Coastguard Worker ], 159*da0073e9SAndroid Build Coastguard Worker) 160*da0073e9SAndroid Build Coastguard Workerdef test_transform_to(constraint_fn, args, is_cuda): 161*da0073e9SAndroid Build Coastguard Worker constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) 162*da0073e9SAndroid Build Coastguard Worker t = transform_to(constraint) 163*da0073e9SAndroid Build Coastguard Worker if constraint_fn is constraints.corr_cholesky: 164*da0073e9SAndroid Build Coastguard Worker # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) 165*da0073e9SAndroid Build Coastguard Worker x = torch.randn(6, 6, dtype=torch.double) 166*da0073e9SAndroid Build Coastguard Worker else: 167*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.double) 168*da0073e9SAndroid Build Coastguard Worker if is_cuda: 169*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 170*da0073e9SAndroid Build Coastguard Worker y = t(x) 171*da0073e9SAndroid Build Coastguard Worker assert constraint.check(y).all(), f"Failed to transform_to({constraint})" 172*da0073e9SAndroid Build Coastguard Worker x2 = t.inv(y) 173*da0073e9SAndroid Build Coastguard Worker y2 = t(x2) 174*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(y, y2), f"Error in transform_to({constraint}) pseudoinverse" 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 178*da0073e9SAndroid Build Coastguard Worker run_tests() 179