# mypy: allow-untyped-defs
import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
    broadcast_all,
    clamp_probs,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.types import _size


__all__ = ["ContinuousBernoulli"]


class ContinuousBernoulli(ExponentialFamily):
    r"""
    Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    The distribution is supported in [0, 1] and parameterized by 'probs' (in
    (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
    does not correspond to a probability and 'logits' does not correspond to
    log-odds, but the same names are used due to the similarity with the
    Bernoulli. See [1] for more details.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = ContinuousBernoulli(torch.tensor([0.3]))
        >>> m.sample()
        tensor([ 0.2538])

    Args:
        probs (Number, Tensor): (0,1) valued parameters
        logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
        lims (tuple of float): bounds ``(lims[0], lims[1]]`` of the region of
            'probs' values around 0.5 where the closed-form expressions (which
            divide by ``2 * probs - 1``) are replaced by Taylor expansions.
        validate_args (bool, optional): whether to validate the parameters and
            samples. ``None`` uses the default validation setting.

    [1] The continuous Bernoulli: fixing a pervasive error in variational
    autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
    https://arxiv.org/abs/1907.06845
    """

    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.unit_interval
    _mean_carrier_measure = 0
    has_rsample = True

    def __init__(
        self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
    ):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
            # 'probs' is clamped away from 0 and 1 below for numerical
            # stability. After clamping it would always satisfy its constraint,
            # so it must be validated here, before the clamp.
            # An explicit True/False from the caller wins. With None we fall
            # back to the default setting: no instance attribute exists yet
            # because the base __init__ has not run, so this reads the
            # class-level ``_validate_args`` (assumed to be the
            # Distribution-wide default; getattr keeps the old "don't validate"
            # behaviour if it is absent).
            if validate_args is None:
                should_validate = getattr(self, "_validate_args", False)
            else:
                should_validate = validate_args
            if should_validate:
                if not self.arg_constraints["probs"].check(self.probs).all():
                    raise ValueError("The parameter probs has invalid values")
            self.probs = clamp_probs(self.probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # Keep a handle on whichever parameter was given; it fixes the batch
        # shape and the dtype/device used by ``_new``.
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        self._lims = lims
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``.

        Only the parameters already materialized in ``__dict__`` (``probs``
        and/or ``logits``) are expanded. The validation flag and ``lims`` are
        copied from ``self``.
        """
        new = self._get_checked_instance(ContinuousBernoulli, _instance)
        new._lims = self._lims
        batch_shape = torch.Size(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Create a new tensor via ``self._param.new`` (same dtype/device)."""
        return self._param.new(*args, **kwargs)

    def _outside_unstable_region(self):
        """Boolean mask: True where 'probs' lies outside ``(lims[0], lims[1]]``.

        Elements inside that interval are close to 0.5, where the closed-form
        expressions divide by ``2 * probs - 1``. Those elements use Taylor
        expansions instead.
        """
        return torch.logical_or(
            torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
        )

    def _cut_probs(self):
        """'probs' with every element in the unstable region replaced by ``lims[0]``.

        ``torch.where`` evaluates both branches, so the closed-form branch is
        computed with these cut values to keep it finite for elements whose
        result will be taken from the Taylor branch anyway.
        """
        return torch.where(
            self._outside_unstable_region(),
            self.probs,
            self._lims[0] * torch.ones_like(self.probs),
        )

    def _cont_bern_log_norm(self):
        """computes the log normalizing constant as a function of the 'probs' parameter"""
        cut_probs = self._cut_probs()
        # Split into the halves below/above 0.5 (with harmless fill values) so
        # each log argument in the torch.where below stays positive.
        cut_probs_below_half = torch.where(
            torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
        )
        cut_probs_above_half = torch.where(
            torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
        )
        # log |log(1 - p) - log p| - log |1 - 2p|, with |1 - 2p| written as
        # log1p(-2p) below 0.5 and log(2p - 1) above it.
        log_norm = torch.log(
            torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
        ) - torch.where(
            torch.le(cut_probs, 0.5),
            torch.log1p(-2.0 * cut_probs_below_half),
            torch.log(2.0 * cut_probs_above_half - 1.0),
        )
        # Taylor expansion in x = (p - 0.5)^2 used inside the unstable region.
        x = torch.pow(self.probs - 0.5, 2)
        taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
        return torch.where(self._outside_unstable_region(), log_norm, taylor)

    @property
    def mean(self):
        """Mean, using a Taylor expansion around probs = 0.5 in the unstable region."""
        cut_probs = self._cut_probs()
        mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
            torch.log1p(-cut_probs) - torch.log(cut_probs)
        )
        x = self.probs - 0.5
        taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
        return torch.where(self._outside_unstable_region(), mus, taylor)

    @property
    def stddev(self):
        """Standard deviation: square root of :attr:`variance`."""
        return torch.sqrt(self.variance)

    @property
    def variance(self):
        """Variance, using a Taylor expansion around probs = 0.5 in the unstable region."""
        cut_probs = self._cut_probs()
        vars = cut_probs * (cut_probs - 1.0) / torch.pow(
            1.0 - 2.0 * cut_probs, 2
        ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
        x = torch.pow(self.probs - 0.5, 2)
        taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
        return torch.where(self._outside_unstable_region(), vars, taylor)

    @lazy_property
    def logits(self):
        """Logits derived lazily from 'probs' (binary parameterization)."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """'probs' derived lazily from logits, clamped away from 0 and 1."""
        return clamp_probs(logits_to_probs(self.logits, is_binary=True))

    @property
    def param_shape(self):
        """Shape of the parameter tensor the distribution was built from."""
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        """Draw samples by inverse-CDF transform of uniforms, without gradients."""
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
        with torch.no_grad():
            return self.icdf(u)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Reparameterized samples: inverse CDF of uniforms, differentiable in 'probs'."""
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
        return self.icdf(u)

    def log_prob(self, value):
        """Log density: ``value*log(p) + (1-value)*log(1-p) + log C(p)``."""
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # -BCE-with-logits is exactly the Bernoulli-style log-likelihood term.
        return (
            -binary_cross_entropy_with_logits(logits, value, reduction="none")
            + self._cont_bern_log_norm()
        )

    def cdf(self, value):
        """Cumulative distribution function; 0 for value <= 0 and 1 for value >= 1.

        Inside the unstable region the CDF is approximated by ``value`` itself
        (the uniform CDF).
        """
        if self._validate_args:
            self._validate_sample(value)
        cut_probs = self._cut_probs()
        cdfs = (
            torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
            + cut_probs
            - 1.0
        ) / (2.0 * cut_probs - 1.0)
        unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
        return torch.where(
            torch.le(value, 0.0),
            torch.zeros_like(value),
            torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
        )

    def icdf(self, value):
        """Inverse CDF; returns ``value`` unchanged inside the unstable region."""
        cut_probs = self._cut_probs()
        return torch.where(
            self._outside_unstable_region(),
            (
                torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
                - torch.log1p(-cut_probs)
            )
            / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
            value,
        )

    def entropy(self):
        """Entropy: ``mean*(log(1-p) - log p) - log C(p) - log(1-p)``."""
        log_probs0 = torch.log1p(-self.probs)
        log_probs1 = torch.log(self.probs)
        return (
            self.mean * (log_probs0 - log_probs1)
            - self._cont_bern_log_norm()
            - log_probs0
        )

    @property
    def _natural_params(self):
        """Natural parameters of the exponential-family form: ``(logits,)``."""
        return (self.logits,)

    def _log_normalizer(self, x):
        """computes the log normalizing constant as a function of the natural parameter"""
        # Unstable region in natural-parameter space, offset by 0.5 from the
        # 'probs' limits.
        out_unst_reg = torch.logical_or(
            torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
        )
        # Replace unstable entries so the closed form never evaluates 0/0.
        cut_nat_params = torch.where(
            out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
        )
        # log |(exp(x) - 1) / x|
        log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
            torch.abs(cut_nat_params)
        )
        # Taylor expansion of log((exp(x) - 1) / x) around x = 0.
        taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
        return torch.where(out_unst_reg, log_norm, taylor)