# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all


__all__ = ["Poisson"]


class Poisson(ExponentialFamily):
    r"""
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.

    Samples are nonnegative integers, with a pmf given by

    .. math::
        \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

    Example::

        >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
        >>> m = Poisson(torch.tensor([4]))
        >>> m.sample()
        tensor([ 3.])

    Args:
        rate (Number, Tensor): the rate parameter
    """
    arg_constraints = {"rate": constraints.nonnegative}
    support = constraints.nonnegative_integer

    @property
    def mean(self):
        return self.rate

    @property
    def mode(self):
        return self.rate.floor()

    @property
    def variance(self):
        return self.rate

    def __init__(self, rate, validate_args=None):
        (self.rate,) = broadcast_all(rate)
        if isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.rate.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Poisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        super(Poisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        # torch.poisson is not reparameterizable, so draw without tracking gradients.
        with torch.no_grad():
            return torch.poisson(self.rate.expand(shape))

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        # log p(k) = k * log(rate) - rate - log(k!); xlogy handles value == 0
        # with rate == 0 (0 * log 0 = 0), and lgamma(value + 1) computes log(value!).
        return value.xlogy(rate) - rate - (value + 1).lgamma()

    @property
    def _natural_params(self):
        # Natural parameter of the exponential-family form: eta = log(rate).
        return (torch.log(self.rate),)

    def _log_normalizer(self, x):
        # Log normalizer A(eta) = exp(eta) = rate.
        return torch.exp(x)
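

# A minimal usage sketch (an assumption-free demo of the class above, guarded
# so it only runs when this file is executed directly as a script; the numeric
# outputs noted in comments are illustrative). It shows sampling, evaluating
# log_prob, and the fact that log_prob is differentiable with respect to `rate`.
if __name__ == "__main__":
    rate = torch.tensor([4.0], requires_grad=True)
    m = Poisson(rate)
    print(m.sample((5,)))  # five draws, shape (5, 1); values vary per run
    # log P(k = 3) = 3 * log(4) - 4 - log(3!) ≈ -1.6329
    lp = m.log_prob(torch.tensor([3.0]))
    lp.backward()
    # d/d(rate) [k * log(rate) - rate] = k / rate - 1 = 3/4 - 1 = -0.25
    print(lp.item(), rate.grad)  # ≈ -1.6329, tensor([-0.2500])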