xref: /aosp_15_r20/external/pytorch/torch/distributions/multinomial.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch import inf
4from torch.distributions import Categorical, constraints
5from torch.distributions.binomial import Binomial
6from torch.distributions.distribution import Distribution
7from torch.distributions.utils import broadcast_all
8
9
# Public names exported by this module (the only one is the Multinomial class).
__all__ = ["Multinomial"]
11
12
13class Multinomial(Distribution):
14    r"""
15    Creates a Multinomial distribution parameterized by :attr:`total_count` and
16    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
17    :attr:`probs` indexes over categories. All other dimensions index over batches.
18
19    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
20    called (see example below)
21
22    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
23              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
24              will return this normalized value.
25              The `logits` argument will be interpreted as unnormalized log probabilities
26              and can therefore be any real number. It will likewise be normalized so that
27              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
28              will return this normalized value.
29
30    -   :meth:`sample` requires a single shared `total_count` for all
31        parameters and samples.
32    -   :meth:`log_prob` allows different `total_count` for each parameter and
33        sample.
34
35    Example::
36
37        >>> # xdoctest: +SKIP("FIXME: found invalid values")
38        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
39        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
40        tensor([ 21.,  24.,  30.,  25.])
41
42        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
43        tensor([-4.1338])
44
45    Args:
46        total_count (int): number of trials
47        probs (Tensor): event probabilities
48        logits (Tensor): event log probabilities (unnormalized)
49    """
50    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
51    total_count: int
52
53    @property
54    def mean(self):
55        return self.probs * self.total_count
56
57    @property
58    def variance(self):
59        return self.total_count * self.probs * (1 - self.probs)
60
61    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
62        if not isinstance(total_count, int):
63            raise NotImplementedError("inhomogeneous total_count is not supported")
64        self.total_count = total_count
65        self._categorical = Categorical(probs=probs, logits=logits)
66        self._binomial = Binomial(total_count=total_count, probs=self.probs)
67        batch_shape = self._categorical.batch_shape
68        event_shape = self._categorical.param_shape[-1:]
69        super().__init__(batch_shape, event_shape, validate_args=validate_args)
70
71    def expand(self, batch_shape, _instance=None):
72        new = self._get_checked_instance(Multinomial, _instance)
73        batch_shape = torch.Size(batch_shape)
74        new.total_count = self.total_count
75        new._categorical = self._categorical.expand(batch_shape)
76        super(Multinomial, new).__init__(
77            batch_shape, self.event_shape, validate_args=False
78        )
79        new._validate_args = self._validate_args
80        return new
81
82    def _new(self, *args, **kwargs):
83        return self._categorical._new(*args, **kwargs)
84
85    @constraints.dependent_property(is_discrete=True, event_dim=1)
86    def support(self):
87        return constraints.multinomial(self.total_count)
88
89    @property
90    def logits(self):
91        return self._categorical.logits
92
93    @property
94    def probs(self):
95        return self._categorical.probs
96
97    @property
98    def param_shape(self):
99        return self._categorical.param_shape
100
101    def sample(self, sample_shape=torch.Size()):
102        sample_shape = torch.Size(sample_shape)
103        samples = self._categorical.sample(
104            torch.Size((self.total_count,)) + sample_shape
105        )
106        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
107        # (sample_shape, batch_shape, total_count)
108        shifted_idx = list(range(samples.dim()))
109        shifted_idx.append(shifted_idx.pop(0))
110        samples = samples.permute(*shifted_idx)
111        counts = samples.new(self._extended_shape(sample_shape)).zero_()
112        counts.scatter_add_(-1, samples, torch.ones_like(samples))
113        return counts.type_as(self.probs)
114
115    def entropy(self):
116        n = torch.tensor(self.total_count)
117
118        cat_entropy = self._categorical.entropy()
119        term1 = n * cat_entropy - torch.lgamma(n + 1)
120
121        support = self._binomial.enumerate_support(expand=False)[1:]
122        binomial_probs = torch.exp(self._binomial.log_prob(support))
123        weights = torch.lgamma(support + 1)
124        term2 = (binomial_probs * weights).sum([0, -1])
125
126        return term1 + term2
127
128    def log_prob(self, value):
129        if self._validate_args:
130            self._validate_sample(value)
131        logits, value = broadcast_all(self.logits, value)
132        logits = logits.clone(memory_format=torch.contiguous_format)
133        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
134        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
135        logits[(value == 0) & (logits == -inf)] = 0
136        log_powers = (logits * value).sum(-1)
137        return log_factorial_n - log_factorial_xs + log_powers
138