1# mypy: allow-untyped-defs 2import math 3import warnings 4from numbers import Number 5from typing import Optional, Union 6 7import torch 8from torch import nan 9from torch.distributions import constraints 10from torch.distributions.exp_family import ExponentialFamily 11from torch.distributions.multivariate_normal import _precision_to_scale_tril 12from torch.distributions.utils import lazy_property 13from torch.types import _size 14 15 16__all__ = ["Wishart"] 17 18_log_2 = math.log(2) 19 20 21def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor: 22 assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function." 23 return torch.digamma( 24 x.unsqueeze(-1) 25 - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,)) 26 ).sum(-1) 27 28 29def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor: 30 # We assume positive input for this function 31 return x.clamp(min=torch.finfo(x.dtype).eps) 32 33 34class Wishart(ExponentialFamily): 35 r""" 36 Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`, 37 or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top` 38 39 Example: 40 >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional") 41 >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2)) 42 >>> m.sample() # Wishart distributed with mean=`df * I` and 43 >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j 44 45 Args: 46 df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1 47 covariance_matrix (Tensor): positive-definite covariance matrix 48 precision_matrix (Tensor): positive-definite precision matrix 49 scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal 50 Note: 51 Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or 52 :attr:`scale_tril` can be specified. 53 Using :attr:`scale_tril` will be more efficient: all computations internally 54 are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or 55 :attr:`precision_matrix` is passed instead, it is only used to compute 56 the corresponding lower triangular matrices using a Cholesky decomposition. 57 'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1] 58 59 **References** 60 61 [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`. 62 [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`. 63 [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`. 64 [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203. 65 [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`. 66 """ 67 arg_constraints = { 68 "covariance_matrix": constraints.positive_definite, 69 "precision_matrix": constraints.positive_definite, 70 "scale_tril": constraints.lower_cholesky, 71 "df": constraints.greater_than(0), 72 } 73 support = constraints.positive_definite 74 has_rsample = True 75 _mean_carrier_measure = 0 76 77 def __init__( 78 self, 79 df: Union[torch.Tensor, Number], 80 covariance_matrix: Optional[torch.Tensor] = None, 81 precision_matrix: Optional[torch.Tensor] = None, 82 scale_tril: Optional[torch.Tensor] = None, 83 validate_args=None, 84 ): 85 assert (covariance_matrix is not None) + (scale_tril is not None) + ( 86 precision_matrix is not None 87 ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." 88 89 param = next( 90 p 91 for p in (covariance_matrix, precision_matrix, scale_tril) 92 if p is not None 93 ) 94 95 if param.dim() < 2: 96 raise ValueError( 97 "scale_tril must be at least two-dimensional, with optional leading batch dimensions" 98 ) 99 100 if isinstance(df, Number): 101 batch_shape = torch.Size(param.shape[:-2]) 102 self.df = torch.tensor(df, dtype=param.dtype, device=param.device) 103 else: 104 batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape) 105 self.df = df.expand(batch_shape) 106 event_shape = param.shape[-2:] 107 108 if self.df.le(event_shape[-1] - 1).any(): 109 raise ValueError( 110 f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}." 111 ) 112 113 if scale_tril is not None: 114 self.scale_tril = param.expand(batch_shape + (-1, -1)) 115 elif covariance_matrix is not None: 116 self.covariance_matrix = param.expand(batch_shape + (-1, -1)) 117 elif precision_matrix is not None: 118 self.precision_matrix = param.expand(batch_shape + (-1, -1)) 119 120 self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1) 121 if self.df.lt(event_shape[-1]).any(): 122 warnings.warn( 123 "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim." 124 ) 125 126 super().__init__(batch_shape, event_shape, validate_args=validate_args) 127 self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))] 128 129 if scale_tril is not None: 130 self._unbroadcasted_scale_tril = scale_tril 131 elif covariance_matrix is not None: 132 self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) 133 else: # precision_matrix is not None 134 self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) 135 136 # Chi2 distribution is needed for Bartlett decomposition sampling 137 self._dist_chi2 = torch.distributions.chi2.Chi2( 138 df=( 139 self.df.unsqueeze(-1) 140 - torch.arange( 141 self._event_shape[-1], 142 dtype=self._unbroadcasted_scale_tril.dtype, 143 device=self._unbroadcasted_scale_tril.device, 144 ).expand(batch_shape + (-1,)) 145 ) 146 ) 147 148 def expand(self, batch_shape, _instance=None): 149 new = self._get_checked_instance(Wishart, _instance) 150 batch_shape = torch.Size(batch_shape) 151 cov_shape = batch_shape + self.event_shape 152 new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape) 153 new.df = self.df.expand(batch_shape) 154 155 new._batch_dims = [-(x + 1) for x in range(len(batch_shape))] 156 157 if "covariance_matrix" in self.__dict__: 158 new.covariance_matrix = self.covariance_matrix.expand(cov_shape) 159 if "scale_tril" in self.__dict__: 160 new.scale_tril = self.scale_tril.expand(cov_shape) 161 if "precision_matrix" in self.__dict__: 162 new.precision_matrix = self.precision_matrix.expand(cov_shape) 163 164 # Chi2 distribution is needed for Bartlett decomposition sampling 165 new._dist_chi2 = torch.distributions.chi2.Chi2( 166 df=( 167 new.df.unsqueeze(-1) 168 - torch.arange( 169 self.event_shape[-1], 170 dtype=new._unbroadcasted_scale_tril.dtype, 171 device=new._unbroadcasted_scale_tril.device, 172 ).expand(batch_shape + (-1,)) 173 ) 174 ) 175 176 super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False) 177 new._validate_args = self._validate_args 178 return new 179 180 @lazy_property 181 def scale_tril(self): 182 return self._unbroadcasted_scale_tril.expand( 183 self._batch_shape + self._event_shape 184 ) 185 186 @lazy_property 187 def covariance_matrix(self): 188 return ( 189 self._unbroadcasted_scale_tril 190 @ self._unbroadcasted_scale_tril.transpose(-2, -1) 191 ).expand(self._batch_shape + self._event_shape) 192 193 @lazy_property 194 def precision_matrix(self): 195 identity = torch.eye( 196 self._event_shape[-1], 197 device=self._unbroadcasted_scale_tril.device, 198 dtype=self._unbroadcasted_scale_tril.dtype, 199 ) 200 return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( 201 self._batch_shape + self._event_shape 202 ) 203 204 @property 205 def mean(self): 206 return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix 207 208 @property 209 def mode(self): 210 factor = self.df - self.covariance_matrix.shape[-1] - 1 211 factor[factor <= 0] = nan 212 return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix 213 214 @property 215 def variance(self): 216 V = self.covariance_matrix # has shape (batch_shape x event_shape) 217 diag_V = V.diagonal(dim1=-2, dim2=-1) 218 return self.df.view(self._batch_shape + (1, 1)) * ( 219 V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V) 220 ) 221 222 def _bartlett_sampling(self, sample_shape=torch.Size()): 223 p = self._event_shape[-1] # has singleton shape 224 225 # Implemented Sampling using Bartlett decomposition 226 noise = _clamp_above_eps( 227 self._dist_chi2.rsample(sample_shape).sqrt() 228 ).diag_embed(dim1=-2, dim2=-1) 229 230 i, j = torch.tril_indices(p, p, offset=-1) 231 noise[..., i, j] = torch.randn( 232 torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),), 233 dtype=noise.dtype, 234 device=noise.device, 235 ) 236 chol = self._unbroadcasted_scale_tril @ noise 237 return chol @ chol.transpose(-2, -1) 238 239 def rsample( 240 self, sample_shape: _size = torch.Size(), max_try_correction=None 241 ) -> torch.Tensor: 242 r""" 243 .. warning:: 244 In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples. 245 Several tries to correct singular samples are performed by default, but it may end up returning 246 singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`. 247 In those cases, the user should validate the samples and either fix the value of `df` 248 or adjust `max_try_correction` value for argument in `.rsample` accordingly. 249 """ 250 251 if max_try_correction is None: 252 max_try_correction = 3 if torch._C._get_tracing_state() else 10 253 254 sample_shape = torch.Size(sample_shape) 255 sample = self._bartlett_sampling(sample_shape) 256 257 # Below part is to improve numerical stability temporally and should be removed in the future 258 is_singular = self.support.check(sample) 259 if self._batch_shape: 260 is_singular = is_singular.amax(self._batch_dims) 261 262 if torch._C._get_tracing_state(): 263 # Less optimized version for JIT 264 for _ in range(max_try_correction): 265 sample_new = self._bartlett_sampling(sample_shape) 266 sample = torch.where(is_singular, sample_new, sample) 267 268 is_singular = ~self.support.check(sample) 269 if self._batch_shape: 270 is_singular = is_singular.amax(self._batch_dims) 271 272 else: 273 # More optimized version with data-dependent control flow. 274 if is_singular.any(): 275 warnings.warn("Singular sample detected.") 276 277 for _ in range(max_try_correction): 278 sample_new = self._bartlett_sampling(is_singular[is_singular].shape) 279 sample[is_singular] = sample_new 280 281 is_singular_new = ~self.support.check(sample_new) 282 if self._batch_shape: 283 is_singular_new = is_singular_new.amax(self._batch_dims) 284 is_singular[is_singular.clone()] = is_singular_new 285 286 if not is_singular.any(): 287 break 288 289 return sample 290 291 def log_prob(self, value): 292 if self._validate_args: 293 self._validate_sample(value) 294 nu = self.df # has shape (batch_shape) 295 p = self._event_shape[-1] # has singleton shape 296 return ( 297 -nu 298 * ( 299 p * _log_2 / 2 300 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) 301 .log() 302 .sum(-1) 303 ) 304 - torch.mvlgamma(nu / 2, p=p) 305 + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet 306 - torch.cholesky_solve(value, self._unbroadcasted_scale_tril) 307 .diagonal(dim1=-2, dim2=-1) 308 .sum(dim=-1) 309 / 2 310 ) 311 312 def entropy(self): 313 nu = self.df # has shape (batch_shape) 314 p = self._event_shape[-1] # has singleton shape 315 V = self.covariance_matrix # has shape (batch_shape x event_shape) 316 return ( 317 (p + 1) 318 * ( 319 p * _log_2 / 2 320 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) 321 .log() 322 .sum(-1) 323 ) 324 + torch.mvlgamma(nu / 2, p=p) 325 - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) 326 + nu * p / 2 327 ) 328 329 @property 330 def _natural_params(self): 331 nu = self.df # has shape (batch_shape) 332 p = self._event_shape[-1] # has singleton shape 333 return -self.precision_matrix / 2, (nu - p - 1) / 2 334 335 def _log_normalizer(self, x, y): 336 p = self._event_shape[-1] 337 return (y + (p + 1) / 2) * ( 338 -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p 339 ) + torch.mvlgamma(y + (p + 1) / 2, p=p) 340