# mypy: allow-untyped-defs
import math
from numbers import Number, Real

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import _standard_normal, broadcast_all
from torch.types import _size


__all__ = ["Normal"]


class Normal(ExponentialFamily):
    r"""
    Creates a normal (also called Gaussian) distribution parameterized by
    :attr:`loc` and :attr:`scale`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # normally distributed with loc=0 and scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution (often referred to as mu)
        scale (float or Tensor): standard deviation of the distribution
            (often referred to as sigma)
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True
    # Carrier-measure term for the exponential-family representation
    # (presumably consumed by ExponentialFamily; confirm against the base class).
    _mean_carrier_measure = 0

    @property
    def mean(self):
        """Mean of the distribution, i.e. ``loc``."""
        return self.loc

    @property
    def mode(self):
        """Mode of the distribution; for a normal it coincides with ``loc``."""
        return self.loc

    @property
    def stddev(self):
        """Standard deviation of the distribution, i.e. ``scale``."""
        return self.scale

    @property
    def variance(self):
        """Variance of the distribution, ``scale ** 2``."""
        return self.stddev.pow(2)

    def __init__(self, loc, scale, validate_args=None):
        """Build a normal distribution from ``loc`` and ``scale``.

        ``loc`` and ``scale`` are passed through ``broadcast_all`` so both end
        up as tensors of a common shape. When both inputs are plain Python
        numbers the batch shape is empty; otherwise it is the broadcast shape.
        """
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Normal`` whose parameters are expanded to ``batch_shape``.

        Parameters are expanded with ``Tensor.expand`` (no data copy). The base
        initializer is called with validation disabled, and the original
        ``_validate_args`` setting is then restored on the new instance.
        """
        new = self._get_checked_instance(Normal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Normal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Sampling runs under ``torch.no_grad()``, so the result is not connected
        to the autograd graph; use :meth:`rsample` for differentiable samples.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw reparameterized samples ``loc + eps * scale``.

        ``eps`` comes from ``_standard_normal`` using the dtype and device of
        ``loc``, so gradients flow back to ``loc`` and ``scale``.
        """
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        """Return the log density of ``value``.

        Computes ``-(value - loc)**2 / (2 * scale**2) - log(scale)
        - log(sqrt(2 * pi))``. The sample is validated first when argument
        validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        var = self.scale**2
        # ``scale`` is always a Tensor here: ``broadcast_all`` in ``__init__``
        # converts Python numbers, and ``expand`` uses ``Tensor.expand``.
        # Take the tensor log directly, matching ``entropy``.
        log_scale = self.scale.log()
        return (
            -((value - self.loc) ** 2) / (2 * var)
            - log_scale
            - math.log(math.sqrt(2 * math.pi))
        )

    def cdf(self, value):
        """Return the cumulative distribution function evaluated at ``value``.

        Uses ``0.5 * (1 + erf((value - loc) / (scale * sqrt(2))))``.
        """
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 * (
            1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2))
        )

    def icdf(self, value):
        """Return the inverse CDF (quantile function) at probability ``value``.

        Uses ``loc + scale * erfinv(2 * value - 1) * sqrt(2)``. ``value`` is not
        validated here; callers are expected to pass probabilities in [0, 1].
        """
        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

    def entropy(self):
        """Return the differential entropy ``0.5 + 0.5 * log(2 * pi) + log(scale)``."""
        return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)

    @property
    def _natural_params(self):
        """Natural parameters ``(loc / scale**2, -1 / (2 * scale**2))``."""
        return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

    def _log_normalizer(self, x, y):
        """Log normalizer ``-x**2 / (4 * y) + 0.5 * log(-pi / y)`` in natural parameters."""
        return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)