1# Owner(s): ["module: distributions"] 2 3import pytest 4 5import torch 6from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix 7from torch.testing._internal.common_utils import run_tests 8 9 10@pytest.mark.parametrize( 11 "shape", 12 [ 13 (2, 2), 14 (3, 3), 15 (2, 4, 4), 16 (2, 2, 4, 4), 17 ], 18) 19def test_tril_matrix_to_vec(shape): 20 mat = torch.randn(shape) 21 n = mat.shape[-1] 22 for diag in range(-n, n): 23 actual = mat.tril(diag) 24 vec = tril_matrix_to_vec(actual, diag) 25 tril_mat = vec_to_tril_matrix(vec, diag) 26 assert torch.allclose(tril_mat, actual) 27 28 29if __name__ == "__main__": 30 run_tests() 31