1# mypy: allow-untyped-defs 2import torch 3import torch.nn.functional as F 4from torch.distributions import constraints 5from torch.distributions.distribution import Distribution 6from torch.distributions.utils import ( 7 broadcast_all, 8 lazy_property, 9 logits_to_probs, 10 probs_to_logits, 11) 12 13 14__all__ = ["NegativeBinomial"] 15 16 17class NegativeBinomial(Distribution): 18 r""" 19 Creates a Negative Binomial distribution, i.e. distribution 20 of the number of successful independent and identical Bernoulli trials 21 before :attr:`total_count` failures are achieved. The probability 22 of success of each Bernoulli trial is :attr:`probs`. 23 24 Args: 25 total_count (float or Tensor): non-negative number of negative Bernoulli 26 trials to stop, although the distribution is still valid for real 27 valued count 28 probs (Tensor): Event probabilities of success in the half open interval [0, 1) 29 logits (Tensor): Event log-odds for probabilities of success 30 """ 31 arg_constraints = { 32 "total_count": constraints.greater_than_eq(0), 33 "probs": constraints.half_open_interval(0.0, 1.0), 34 "logits": constraints.real, 35 } 36 support = constraints.nonnegative_integer 37 38 def __init__(self, total_count, probs=None, logits=None, validate_args=None): 39 if (probs is None) == (logits is None): 40 raise ValueError( 41 "Either `probs` or `logits` must be specified, but not both." 42 ) 43 if probs is not None: 44 ( 45 self.total_count, 46 self.probs, 47 ) = broadcast_all(total_count, probs) 48 self.total_count = self.total_count.type_as(self.probs) 49 else: 50 ( 51 self.total_count, 52 self.logits, 53 ) = broadcast_all(total_count, logits) 54 self.total_count = self.total_count.type_as(self.logits) 55 56 self._param = self.probs if probs is not None else self.logits 57 batch_shape = self._param.size() 58 super().__init__(batch_shape, validate_args=validate_args) 59 60 def expand(self, batch_shape, _instance=None): 61 new = self._get_checked_instance(NegativeBinomial, _instance) 62 batch_shape = torch.Size(batch_shape) 63 new.total_count = self.total_count.expand(batch_shape) 64 if "probs" in self.__dict__: 65 new.probs = self.probs.expand(batch_shape) 66 new._param = new.probs 67 if "logits" in self.__dict__: 68 new.logits = self.logits.expand(batch_shape) 69 new._param = new.logits 70 super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) 71 new._validate_args = self._validate_args 72 return new 73 74 def _new(self, *args, **kwargs): 75 return self._param.new(*args, **kwargs) 76 77 @property 78 def mean(self): 79 return self.total_count * torch.exp(self.logits) 80 81 @property 82 def mode(self): 83 return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0) 84 85 @property 86 def variance(self): 87 return self.mean / torch.sigmoid(-self.logits) 88 89 @lazy_property 90 def logits(self): 91 return probs_to_logits(self.probs, is_binary=True) 92 93 @lazy_property 94 def probs(self): 95 return logits_to_probs(self.logits, is_binary=True) 96 97 @property 98 def param_shape(self): 99 return self._param.size() 100 101 @lazy_property 102 def _gamma(self): 103 # Note we avoid validating because self.total_count can be zero. 104 return torch.distributions.Gamma( 105 concentration=self.total_count, 106 rate=torch.exp(-self.logits), 107 validate_args=False, 108 ) 109 110 def sample(self, sample_shape=torch.Size()): 111 with torch.no_grad(): 112 rate = self._gamma.sample(sample_shape=sample_shape) 113 return torch.poisson(rate) 114 115 def log_prob(self, value): 116 if self._validate_args: 117 self._validate_sample(value) 118 119 log_unnormalized_prob = self.total_count * F.logsigmoid( 120 -self.logits 121 ) + value * F.logsigmoid(self.logits) 122 123 log_normalization = ( 124 -torch.lgamma(self.total_count + value) 125 + torch.lgamma(1.0 + value) 126 + torch.lgamma(self.total_count) 127 ) 128 # The case self.total_count == 0 and value == 0 has probability 1 but 129 # lgamma(0) is infinite. Handle this case separately using a function 130 # that does not modify tensors in place to allow Jit compilation. 131 log_normalization = log_normalization.masked_fill( 132 self.total_count + value == 0.0, 0.0 133 ) 134 135 return log_unnormalized_prob - log_normalization 136