# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits, softplus


__all__ = ["Bernoulli"]


class Bernoulli(ExponentialFamily):
    r"""
    Creates a Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    _mean_carrier_measure = 0

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Exactly one of `probs` / `logits` must be given; the other
        # parameterization is derived lazily (see the lazy_propertys below).
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # `_param` always refers to whichever tensor the caller supplied,
        # so dtype/device/shape queries never trigger the lazy conversion.
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new Bernoulli with parameters expanded to `batch_shape`."""
        new = self._get_checked_instance(Bernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        # Check `__dict__` directly (not `hasattr`) so we only expand the
        # parameterization(s) that have actually been materialized, without
        # triggering the lazy probs<->logits conversion.
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(Bernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # New tensor with the same dtype/device as the stored parameter.
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        return self.probs

    @property
    def mode(self):
        # The mode is 1 where p > 0.5 and 0 where p < 0.5; at exactly
        # p == 0.5 both values are equally likely, so report NaN.
        mode = (self.probs >= 0.5).to(self.probs)
        mode[self.probs == 0.5] = nan
        return mode

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        # Derived on first access when the distribution was built from probs.
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        # Derived on first access when the distribution was built from logits.
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.bernoulli(self.probs.expand(shape))

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # BCE-with-logits is exactly the negative Bernoulli log-likelihood,
        # computed stably in the logit parameterization.
        return -binary_cross_entropy_with_logits(logits, value, reduction="none")

    def entropy(self):
        # H(p) = -(p*log(p) + (1-p)*log(1-p)), i.e. the BCE of p against
        # its own logits.
        return binary_cross_entropy_with_logits(
            self.logits, self.probs, reduction="none"
        )

    def enumerate_support(self, expand=True):
        """Return the support {0, 1}, shaped (2, *batch_shape) if `expand`."""
        values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values

    @property
    def _natural_params(self):
        # Natural parameter of the Bernoulli is the log-odds.
        return (torch.logit(self.probs),)

    def _log_normalizer(self, x):
        # A(eta) = log(1 + exp(eta)). Use softplus instead of
        # log1p(exp(x)): the latter overflows to inf for large x
        # (x > ~88 in float32), while softplus(x) -> x is exact there.
        return softplus(x)