# mypy: allow-untyped-defs
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


__all__ = ["NegativeBinomial"]


class NegativeBinomial(Distribution):
    r"""
    Creates a Negative Binomial distribution, i.e. the distribution of the
    number of successful independent and identical Bernoulli trials performed
    before :attr:`total_count` failures are achieved. The probability
    of success of each Bernoulli trial is :attr:`probs`.

    Args:
        total_count (float or Tensor): non-negative number of Bernoulli failures
            at which sampling stops, although the distribution is still valid
            for real-valued counts
        probs (Tensor): Event probabilities of success in the half-open
            interval ``[0, 1)``
        logits (Tensor): Event log-odds for probabilities of success
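
    Example (a minimal illustrative usage; sampling is random, while the mean
    shown is exact since ``mean = total_count * probs / (1 - probs)``)::

        >>> m = NegativeBinomial(torch.tensor([10.0]), probs=torch.tensor([0.5]))
        >>> m.mean
        tensor([10.])
        >>> m.sample()  # xdoctest: +SKIP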
    """
    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "probs": constraints.half_open_interval(0.0, 1.0),
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(self, total_count, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
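        # Broadcast total_count against whichever parameterization was given,
        # and match its dtype so downstream arithmetic stays consistent.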
        if probs is not None:
            (
                self.total_count,
                self.probs,
            ) = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            (
                self.total_count,
                self.logits,
            ) = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NegativeBinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
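        # Only the parameterization already materialized in __dict__ is
        # expanded; the lazy_property counterpart is recomputed on demand.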
64        if "probs" in self.__dict__:
65            new.probs = self.probs.expand(batch_shape)
66            new._param = new.probs
67        if "logits" in self.__dict__:
68            new.logits = self.logits.expand(batch_shape)
69            new._param = new.logits
70        super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
71        new._validate_args = self._validate_args
72        return new
73
74    def _new(self, *args, **kwargs):
75        return self._param.new(*args, **kwargs)
76
77    @property
78    def mean(self):
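        # E[X] = r * p / (1 - p), written as total_count * exp(logits).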
        return self.total_count * torch.exp(self.logits)

    @property
    def mode(self):
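        # floor((r - 1) * p / (1 - p)) for r > 1, clamped to 0 otherwise.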
        return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0)

    @property
    def variance(self):
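        # Var[X] = mean / (1 - p); sigmoid(-logits) == 1 - p.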
        return self.mean / torch.sigmoid(-self.logits)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    @lazy_property
    def _gamma(self):
        # Note we avoid validating because self.total_count can be zero.
        return torch.distributions.Gamma(
            concentration=self.total_count,
            rate=torch.exp(-self.logits),
            validate_args=False,
        )

    def sample(self, sample_shape=torch.Size()):
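        # Sample via the Gamma-Poisson mixture: draw a rate
        # lambda ~ Gamma(r, (1 - p) / p), then X ~ Poisson(lambda).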
        with torch.no_grad():
            rate = self._gamma.sample(sample_shape=sample_shape)
            return torch.poisson(rate)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

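        # log pmf: log C(k + r - 1, k) + r * log(1 - p) + k * log(p), where
        # logsigmoid(-logits) == log(1 - p) and logsigmoid(logits) == log(p).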
        log_unnormalized_prob = self.total_count * F.logsigmoid(
            -self.logits
        ) + value * F.logsigmoid(self.logits)

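        # -log C(k + r - 1, k), written with lgamma since
        # C(k + r - 1, k) = Gamma(k + r) / (Gamma(k + 1) * Gamma(r)).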
        log_normalization = (
            -torch.lgamma(self.total_count + value)
            + torch.lgamma(1.0 + value)
            + torch.lgamma(self.total_count)
        )
        # The case self.total_count == 0 and value == 0 has probability 1 but
        # lgamma(0) is infinite. Handle this case separately using a function
        # that does not modify tensors in place to allow Jit compilation.
        log_normalization = log_normalization.masked_fill(
            self.total_count + value == 0.0, 0.0
        )

        return log_unnormalized_prob - log_normalization