# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _size


__all__ = ["Uniform"]


class Uniform(Distribution):
    r"""
    Continuous uniform distribution on the half-open interval ``[low, high)``.

    The density is the constant ``1 / (high - low)`` inside the interval and
    zero outside of it. ``low`` and ``high`` are broadcast against each other,
    so a single instance can describe a whole batch of independent uniforms.

    Example::

        >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
        >>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
        >>> # xdoctest: +SKIP
        tensor([ 2.3418])

    Args:
        low (float or Tensor): lower end of the range (inclusive).
        high (float or Tensor): upper end of the range (exclusive).
        validate_args (bool, optional): forwarded to :class:`Distribution`;
            when enabled, ``low < high`` is also required elementwise and
            sample values are checked in ``log_prob`` / ``cdf``.
    """

    # TODO allow (loc,scale) parameterization to allow independent constraints.
    # The valid range of each bound depends on the other bound, so both are
    # declared as dependent here; the low < high check is done in __init__.
    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
        "high": constraints.dependent(is_discrete=False, event_dim=0),
    }
    has_rsample = True

    def __init__(self, low, high, validate_args=None):
        self.low, self.high = broadcast_all(low, high)

        # Two plain Python numbers describe one scalar distribution; otherwise
        # the batch shape is the shape of the broadcast parameters.
        scalar_params = isinstance(low, Number) and isinstance(high, Number)
        batch_shape = torch.Size() if scalar_params else self.low.size()
        super().__init__(batch_shape, validate_args=validate_args)

        if self._validate_args:
            if not torch.lt(self.low, self.high).all():
                raise ValueError("Uniform is not defined when low>= high")

    @property
    def mean(self):
        # Midpoint of the interval.
        return (self.low + self.high) / 2

    @property
    def mode(self):
        # Every point of the support is equally likely, so there is no unique
        # mode; NaN is returned with the parameters' shape/dtype/device.
        return self.high * nan

    @property
    def stddev(self):
        width = self.high - self.low
        return width / 12**0.5

    @property
    def variance(self):
        width = self.high - self.low
        return width.pow(2) / 12

    def expand(self, batch_shape, _instance=None):
        expanded = self._get_checked_instance(Uniform, _instance)
        target_shape = torch.Size(batch_shape)
        expanded.low = self.low.expand(target_shape)
        expanded.high = self.high.expand(target_shape)
        # Initialise the base class without validating; the bounds are views
        # of self's, and self's validation flag is copied over afterwards.
        super(Uniform, expanded).__init__(target_shape, validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        # NOTE(review): constraints.interval is closed at both ends, so sample
        # validation accepts value == high even though log_prob returns -inf
        # there; kept as-is for compatibility.
        return constraints.interval(self.low, self.high)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        out_shape = self._extended_shape(sample_shape)
        # Draw from [0, 1) and map affinely onto [low, high); because the
        # randomness enters through this transform, gradients flow to low/high.
        unit = torch.rand(out_shape, dtype=self.low.dtype, device=self.low.device)
        return self.low + unit * (self.high - self.low)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Mask is 1.0 where low <= value < high and 0.0 elsewhere; taking its
        # log turns points outside the half-open support into -inf.
        inside = self.low.le(value).logical_and(self.high.gt(value))
        return torch.log(inside.type_as(self.low)) - torch.log(self.high - self.low)

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Fraction of the interval lying below value, saturated to [0, 1].
        fraction = (value - self.low) / (self.high - self.low)
        return torch.clamp(fraction, min=0, max=1)

    def icdf(self, value):
        # Inverse of the linear cdf: stretch by the width, then shift by low.
        width = self.high - self.low
        return value * width + self.low

    def entropy(self):
        # Differential entropy of a uniform is the log of the interval width.
        return (self.high - self.low).log()