# mypy: allow-untyped-defs
from torch.distributions import constraints
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform


__all__ = ["LogNormal"]


class LogNormal(TransformedDistribution):
    r"""
    Creates a log-normal distribution parameterized by
    :attr:`loc` and :attr:`scale` where::

        X ~ Normal(loc, scale)
        Y = exp(X) ~ LogNormal(loc, scale)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # log-normal distributed with mean=0 and stddev=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of log of distribution
        scale (float or Tensor): standard deviation of log of the distribution
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Realize Y = exp(X) by pushing a Normal(loc, scale) base
        # distribution through an ExpTransform.
        normal = Normal(loc, scale, validate_args=validate_args)
        super().__init__(normal, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Validate/allocate the target instance, then delegate the actual
        # batch-shape expansion to TransformedDistribution.
        instance = self._get_checked_instance(LogNormal, _instance)
        return super().expand(batch_shape, _instance=instance)

    @property
    def loc(self):
        # Mean of the underlying normal, i.e. the mean of log(Y).
        return self.base_dist.loc

    @property
    def scale(self):
        # Standard deviation of the underlying normal, i.e. of log(Y).
        return self.base_dist.scale

    @property
    def mean(self):
        # E[Y] = exp(mu + sigma^2 / 2)
        half_log_var = self.scale.pow(2) / 2
        return (self.loc + half_log_var).exp()

    @property
    def mode(self):
        # argmax of the density: exp(mu - sigma^2)
        return (self.loc - self.scale.square()).exp()

    @property
    def variance(self):
        # Var[Y] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
        log_var = self.scale.pow(2)
        return log_var.expm1() * (2 * self.loc + log_var).exp()

    def entropy(self):
        # H[Y] = H[X] + E[log Y] = H[Normal(mu, sigma)] + mu
        return self.base_dist.entropy() + self.loc