# Owner(s): ["module: distributions"] import pytest import torch from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix from torch.testing._internal.common_utils import run_tests @pytest.mark.parametrize( "shape", [ (2, 2), (3, 3), (2, 4, 4), (2, 2, 4, 4), ], ) def test_tril_matrix_to_vec(shape): mat = torch.randn(shape) n = mat.shape[-1] for diag in range(-n, n): actual = mat.tril(diag) vec = tril_matrix_to_vec(actual, diag) tril_mat = vec_to_tril_matrix(vec, diag) assert torch.allclose(tril_mat, actual) if __name__ == "__main__": run_tests()