1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerr""" 3*da0073e9SAndroid Build Coastguard WorkerThe following constraints are implemented: 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker- ``constraints.boolean`` 6*da0073e9SAndroid Build Coastguard Worker- ``constraints.cat`` 7*da0073e9SAndroid Build Coastguard Worker- ``constraints.corr_cholesky`` 8*da0073e9SAndroid Build Coastguard Worker- ``constraints.dependent`` 9*da0073e9SAndroid Build Coastguard Worker- ``constraints.greater_than(lower_bound)`` 10*da0073e9SAndroid Build Coastguard Worker- ``constraints.greater_than_eq(lower_bound)`` 11*da0073e9SAndroid Build Coastguard Worker- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` 12*da0073e9SAndroid Build Coastguard Worker- ``constraints.integer_interval(lower_bound, upper_bound)`` 13*da0073e9SAndroid Build Coastguard Worker- ``constraints.interval(lower_bound, upper_bound)`` 14*da0073e9SAndroid Build Coastguard Worker- ``constraints.less_than(upper_bound)`` 15*da0073e9SAndroid Build Coastguard Worker- ``constraints.lower_cholesky`` 16*da0073e9SAndroid Build Coastguard Worker- ``constraints.lower_triangular`` 17*da0073e9SAndroid Build Coastguard Worker- ``constraints.multinomial`` 18*da0073e9SAndroid Build Coastguard Worker- ``constraints.nonnegative`` 19*da0073e9SAndroid Build Coastguard Worker- ``constraints.nonnegative_integer`` 20*da0073e9SAndroid Build Coastguard Worker- ``constraints.one_hot`` 21*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_integer`` 22*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive`` 23*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_semidefinite`` 24*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_definite`` 25*da0073e9SAndroid Build Coastguard Worker- ``constraints.real_vector`` 26*da0073e9SAndroid Build Coastguard Worker- ``constraints.real`` 27*da0073e9SAndroid Build Coastguard Worker- ``constraints.simplex`` 28*da0073e9SAndroid Build Coastguard Worker- ``constraints.symmetric`` 29*da0073e9SAndroid Build Coastguard Worker- ``constraints.stack`` 30*da0073e9SAndroid Build Coastguard Worker- ``constraints.square`` 31*da0073e9SAndroid Build Coastguard Worker- ``constraints.symmetric`` 32*da0073e9SAndroid Build Coastguard Worker- ``constraints.unit_interval`` 33*da0073e9SAndroid Build Coastguard Worker""" 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Workerimport torch 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker__all__ = [ 39*da0073e9SAndroid Build Coastguard Worker "Constraint", 40*da0073e9SAndroid Build Coastguard Worker "boolean", 41*da0073e9SAndroid Build Coastguard Worker "cat", 42*da0073e9SAndroid Build Coastguard Worker "corr_cholesky", 43*da0073e9SAndroid Build Coastguard Worker "dependent", 44*da0073e9SAndroid Build Coastguard Worker "dependent_property", 45*da0073e9SAndroid Build Coastguard Worker "greater_than", 46*da0073e9SAndroid Build Coastguard Worker "greater_than_eq", 47*da0073e9SAndroid Build Coastguard Worker "independent", 48*da0073e9SAndroid Build Coastguard Worker "integer_interval", 49*da0073e9SAndroid Build Coastguard Worker "interval", 50*da0073e9SAndroid Build Coastguard Worker "half_open_interval", 51*da0073e9SAndroid Build Coastguard Worker "is_dependent", 52*da0073e9SAndroid Build Coastguard Worker "less_than", 53*da0073e9SAndroid Build Coastguard Worker "lower_cholesky", 54*da0073e9SAndroid Build Coastguard Worker "lower_triangular", 55*da0073e9SAndroid Build Coastguard Worker "multinomial", 56*da0073e9SAndroid Build Coastguard Worker "nonnegative", 57*da0073e9SAndroid Build Coastguard Worker "nonnegative_integer", 58*da0073e9SAndroid Build Coastguard Worker "one_hot", 59*da0073e9SAndroid Build Coastguard Worker "positive", 60*da0073e9SAndroid Build Coastguard Worker "positive_semidefinite", 61*da0073e9SAndroid Build Coastguard Worker "positive_definite", 62*da0073e9SAndroid Build Coastguard Worker "positive_integer", 63*da0073e9SAndroid Build Coastguard Worker "real", 64*da0073e9SAndroid Build Coastguard Worker "real_vector", 65*da0073e9SAndroid Build Coastguard Worker "simplex", 66*da0073e9SAndroid Build Coastguard Worker "square", 67*da0073e9SAndroid Build Coastguard Worker "stack", 68*da0073e9SAndroid Build Coastguard Worker "symmetric", 69*da0073e9SAndroid Build Coastguard Worker "unit_interval", 70*da0073e9SAndroid Build Coastguard Worker] 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Workerclass Constraint: 74*da0073e9SAndroid Build Coastguard Worker """ 75*da0073e9SAndroid Build Coastguard Worker Abstract base class for constraints. 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker A constraint object represents a region over which a variable is valid, 78*da0073e9SAndroid Build Coastguard Worker e.g. within which a variable can be optimized. 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker Attributes: 81*da0073e9SAndroid Build Coastguard Worker is_discrete (bool): Whether constrained space is discrete. 82*da0073e9SAndroid Build Coastguard Worker Defaults to False. 83*da0073e9SAndroid Build Coastguard Worker event_dim (int): Number of rightmost dimensions that together define 84*da0073e9SAndroid Build Coastguard Worker an event. The :meth:`check` method will remove this many dimensions 85*da0073e9SAndroid Build Coastguard Worker when computing validity. 86*da0073e9SAndroid Build Coastguard Worker """ 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker is_discrete = False # Default to continuous. 89*da0073e9SAndroid Build Coastguard Worker event_dim = 0 # Default to univariate. 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker def check(self, value): 92*da0073e9SAndroid Build Coastguard Worker """ 93*da0073e9SAndroid Build Coastguard Worker Returns a byte tensor of ``sample_shape + batch_shape`` indicating 94*da0073e9SAndroid Build Coastguard Worker whether each event in value satisfies this constraint. 95*da0073e9SAndroid Build Coastguard Worker """ 96*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 99*da0073e9SAndroid Build Coastguard Worker return self.__class__.__name__[1:] + "()" 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Workerclass _Dependent(Constraint): 103*da0073e9SAndroid Build Coastguard Worker """ 104*da0073e9SAndroid Build Coastguard Worker Placeholder for variables whose support depends on other variables. 105*da0073e9SAndroid Build Coastguard Worker These variables obey no simple coordinate-wise constraints. 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker Args: 108*da0073e9SAndroid Build Coastguard Worker is_discrete (bool): Optional value of ``.is_discrete`` in case this 109*da0073e9SAndroid Build Coastguard Worker can be computed statically. If not provided, access to the 110*da0073e9SAndroid Build Coastguard Worker ``.is_discrete`` attribute will raise a NotImplementedError. 111*da0073e9SAndroid Build Coastguard Worker event_dim (int): Optional value of ``.event_dim`` in case this 112*da0073e9SAndroid Build Coastguard Worker can be computed statically. If not provided, access to the 113*da0073e9SAndroid Build Coastguard Worker ``.event_dim`` attribute will raise a NotImplementedError. 114*da0073e9SAndroid Build Coastguard Worker """ 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): 117*da0073e9SAndroid Build Coastguard Worker self._is_discrete = is_discrete 118*da0073e9SAndroid Build Coastguard Worker self._event_dim = event_dim 119*da0073e9SAndroid Build Coastguard Worker super().__init__() 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker @property 122*da0073e9SAndroid Build Coastguard Worker def is_discrete(self): 123*da0073e9SAndroid Build Coastguard Worker if self._is_discrete is NotImplemented: 124*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(".is_discrete cannot be determined statically") 125*da0073e9SAndroid Build Coastguard Worker return self._is_discrete 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker @property 128*da0073e9SAndroid Build Coastguard Worker def event_dim(self): 129*da0073e9SAndroid Build Coastguard Worker if self._event_dim is NotImplemented: 130*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(".event_dim cannot be determined statically") 131*da0073e9SAndroid Build Coastguard Worker return self._event_dim 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): 134*da0073e9SAndroid Build Coastguard Worker """ 135*da0073e9SAndroid Build Coastguard Worker Support for syntax to customize static attributes:: 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker constraints.dependent(is_discrete=True, event_dim=1) 138*da0073e9SAndroid Build Coastguard Worker """ 139*da0073e9SAndroid Build Coastguard Worker if is_discrete is NotImplemented: 140*da0073e9SAndroid Build Coastguard Worker is_discrete = self._is_discrete 141*da0073e9SAndroid Build Coastguard Worker if event_dim is NotImplemented: 142*da0073e9SAndroid Build Coastguard Worker event_dim = self._event_dim 143*da0073e9SAndroid Build Coastguard Worker return _Dependent(is_discrete=is_discrete, event_dim=event_dim) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker def check(self, x): 146*da0073e9SAndroid Build Coastguard Worker raise ValueError("Cannot determine validity of dependent constraint") 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Workerdef is_dependent(constraint): 150*da0073e9SAndroid Build Coastguard Worker """ 151*da0073e9SAndroid Build Coastguard Worker Checks if ``constraint`` is a ``_Dependent`` object. 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker Args: 154*da0073e9SAndroid Build Coastguard Worker constraint : A ``Constraint`` object. 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker Returns: 157*da0073e9SAndroid Build Coastguard Worker ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise. 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker Examples: 160*da0073e9SAndroid Build Coastguard Worker >>> import torch 161*da0073e9SAndroid Build Coastguard Worker >>> from torch.distributions import Bernoulli 162*da0073e9SAndroid Build Coastguard Worker >>> from torch.distributions.constraints import is_dependent 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker >>> dist = Bernoulli(probs = torch.tensor([0.6], requires_grad=True)) 165*da0073e9SAndroid Build Coastguard Worker >>> constraint1 = dist.arg_constraints["probs"] 166*da0073e9SAndroid Build Coastguard Worker >>> constraint2 = dist.arg_constraints["logits"] 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker >>> for constraint in [constraint1, constraint2]: 169*da0073e9SAndroid Build Coastguard Worker >>> if is_dependent(constraint): 170*da0073e9SAndroid Build Coastguard Worker >>> continue 171*da0073e9SAndroid Build Coastguard Worker """ 172*da0073e9SAndroid Build Coastguard Worker return isinstance(constraint, _Dependent) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Workerclass _DependentProperty(property, _Dependent): 176*da0073e9SAndroid Build Coastguard Worker """ 177*da0073e9SAndroid Build Coastguard Worker Decorator that extends @property to act like a `Dependent` constraint when 178*da0073e9SAndroid Build Coastguard Worker called on a class and act like a property when called on an object. 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker Example:: 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker class Uniform(Distribution): 183*da0073e9SAndroid Build Coastguard Worker def __init__(self, low, high): 184*da0073e9SAndroid Build Coastguard Worker self.low = low 185*da0073e9SAndroid Build Coastguard Worker self.high = high 186*da0073e9SAndroid Build Coastguard Worker @constraints.dependent_property(is_discrete=False, event_dim=0) 187*da0073e9SAndroid Build Coastguard Worker def support(self): 188*da0073e9SAndroid Build Coastguard Worker return constraints.interval(self.low, self.high) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker Args: 191*da0073e9SAndroid Build Coastguard Worker fn (Callable): The function to be decorated. 192*da0073e9SAndroid Build Coastguard Worker is_discrete (bool): Optional value of ``.is_discrete`` in case this 193*da0073e9SAndroid Build Coastguard Worker can be computed statically. If not provided, access to the 194*da0073e9SAndroid Build Coastguard Worker ``.is_discrete`` attribute will raise a NotImplementedError. 195*da0073e9SAndroid Build Coastguard Worker event_dim (int): Optional value of ``.event_dim`` in case this 196*da0073e9SAndroid Build Coastguard Worker can be computed statically. If not provided, access to the 197*da0073e9SAndroid Build Coastguard Worker ``.event_dim`` attribute will raise a NotImplementedError. 198*da0073e9SAndroid Build Coastguard Worker """ 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker def __init__( 201*da0073e9SAndroid Build Coastguard Worker self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented 202*da0073e9SAndroid Build Coastguard Worker ): 203*da0073e9SAndroid Build Coastguard Worker super().__init__(fn) 204*da0073e9SAndroid Build Coastguard Worker self._is_discrete = is_discrete 205*da0073e9SAndroid Build Coastguard Worker self._event_dim = event_dim 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def __call__(self, fn): 208*da0073e9SAndroid Build Coastguard Worker """ 209*da0073e9SAndroid Build Coastguard Worker Support for syntax to customize static attributes:: 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker @constraints.dependent_property(is_discrete=True, event_dim=1) 212*da0073e9SAndroid Build Coastguard Worker def support(self): 213*da0073e9SAndroid Build Coastguard Worker ... 214*da0073e9SAndroid Build Coastguard Worker """ 215*da0073e9SAndroid Build Coastguard Worker return _DependentProperty( 216*da0073e9SAndroid Build Coastguard Worker fn, is_discrete=self._is_discrete, event_dim=self._event_dim 217*da0073e9SAndroid Build Coastguard Worker ) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Workerclass _IndependentConstraint(Constraint): 221*da0073e9SAndroid Build Coastguard Worker """ 222*da0073e9SAndroid Build Coastguard Worker Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many 223*da0073e9SAndroid Build Coastguard Worker dims in :meth:`check`, so that an event is valid only if all its 224*da0073e9SAndroid Build Coastguard Worker independent entries are valid. 225*da0073e9SAndroid Build Coastguard Worker """ 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker def __init__(self, base_constraint, reinterpreted_batch_ndims): 228*da0073e9SAndroid Build Coastguard Worker assert isinstance(base_constraint, Constraint) 229*da0073e9SAndroid Build Coastguard Worker assert isinstance(reinterpreted_batch_ndims, int) 230*da0073e9SAndroid Build Coastguard Worker assert reinterpreted_batch_ndims >= 0 231*da0073e9SAndroid Build Coastguard Worker self.base_constraint = base_constraint 232*da0073e9SAndroid Build Coastguard Worker self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 233*da0073e9SAndroid Build Coastguard Worker super().__init__() 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker @property 236*da0073e9SAndroid Build Coastguard Worker def is_discrete(self): 237*da0073e9SAndroid Build Coastguard Worker return self.base_constraint.is_discrete 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker @property 240*da0073e9SAndroid Build Coastguard Worker def event_dim(self): 241*da0073e9SAndroid Build Coastguard Worker return self.base_constraint.event_dim + self.reinterpreted_batch_ndims 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def check(self, value): 244*da0073e9SAndroid Build Coastguard Worker result = self.base_constraint.check(value) 245*da0073e9SAndroid Build Coastguard Worker if result.dim() < self.reinterpreted_batch_ndims: 246*da0073e9SAndroid Build Coastguard Worker expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims 247*da0073e9SAndroid Build Coastguard Worker raise ValueError( 248*da0073e9SAndroid Build Coastguard Worker f"Expected value.dim() >= {expected} but got {value.dim()}" 249*da0073e9SAndroid Build Coastguard Worker ) 250*da0073e9SAndroid Build Coastguard Worker result = result.reshape( 251*da0073e9SAndroid Build Coastguard Worker result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) 252*da0073e9SAndroid Build Coastguard Worker ) 253*da0073e9SAndroid Build Coastguard Worker result = result.all(-1) 254*da0073e9SAndroid Build Coastguard Worker return result 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 257*da0073e9SAndroid Build Coastguard Worker return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})" 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Workerclass _Boolean(Constraint): 261*da0073e9SAndroid Build Coastguard Worker """ 262*da0073e9SAndroid Build Coastguard Worker Constrain to the two values `{0, 1}`. 263*da0073e9SAndroid Build Coastguard Worker """ 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker is_discrete = True 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker def check(self, value): 268*da0073e9SAndroid Build Coastguard Worker return (value == 0) | (value == 1) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Workerclass _OneHot(Constraint): 272*da0073e9SAndroid Build Coastguard Worker """ 273*da0073e9SAndroid Build Coastguard Worker Constrain to one-hot vectors. 274*da0073e9SAndroid Build Coastguard Worker """ 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker is_discrete = True 277*da0073e9SAndroid Build Coastguard Worker event_dim = 1 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def check(self, value): 280*da0073e9SAndroid Build Coastguard Worker is_boolean = (value == 0) | (value == 1) 281*da0073e9SAndroid Build Coastguard Worker is_normalized = value.sum(-1).eq(1) 282*da0073e9SAndroid Build Coastguard Worker return is_boolean.all(-1) & is_normalized 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Workerclass _IntegerInterval(Constraint): 286*da0073e9SAndroid Build Coastguard Worker """ 287*da0073e9SAndroid Build Coastguard Worker Constrain to an integer interval `[lower_bound, upper_bound]`. 288*da0073e9SAndroid Build Coastguard Worker """ 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker is_discrete = True 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound, upper_bound): 293*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 294*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 295*da0073e9SAndroid Build Coastguard Worker super().__init__() 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def check(self, value): 298*da0073e9SAndroid Build Coastguard Worker return ( 299*da0073e9SAndroid Build Coastguard Worker (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) 300*da0073e9SAndroid Build Coastguard Worker ) 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 303*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 304*da0073e9SAndroid Build Coastguard Worker fmt_string += ( 305*da0073e9SAndroid Build Coastguard Worker f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" 306*da0073e9SAndroid Build Coastguard Worker ) 307*da0073e9SAndroid Build Coastguard Worker return fmt_string 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Workerclass _IntegerLessThan(Constraint): 311*da0073e9SAndroid Build Coastguard Worker """ 312*da0073e9SAndroid Build Coastguard Worker Constrain to an integer interval `(-inf, upper_bound]`. 313*da0073e9SAndroid Build Coastguard Worker """ 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker is_discrete = True 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker def __init__(self, upper_bound): 318*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 319*da0073e9SAndroid Build Coastguard Worker super().__init__() 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def check(self, value): 322*da0073e9SAndroid Build Coastguard Worker return (value % 1 == 0) & (value <= self.upper_bound) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 325*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 326*da0073e9SAndroid Build Coastguard Worker fmt_string += f"(upper_bound={self.upper_bound})" 327*da0073e9SAndroid Build Coastguard Worker return fmt_string 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Workerclass _IntegerGreaterThan(Constraint): 331*da0073e9SAndroid Build Coastguard Worker """ 332*da0073e9SAndroid Build Coastguard Worker Constrain to an integer interval `[lower_bound, inf)`. 333*da0073e9SAndroid Build Coastguard Worker """ 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker is_discrete = True 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound): 338*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 339*da0073e9SAndroid Build Coastguard Worker super().__init__() 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker def check(self, value): 342*da0073e9SAndroid Build Coastguard Worker return (value % 1 == 0) & (value >= self.lower_bound) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 345*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 346*da0073e9SAndroid Build Coastguard Worker fmt_string += f"(lower_bound={self.lower_bound})" 347*da0073e9SAndroid Build Coastguard Worker return fmt_string 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Workerclass _Real(Constraint): 351*da0073e9SAndroid Build Coastguard Worker """ 352*da0073e9SAndroid Build Coastguard Worker Trivially constrain to the extended real line `[-inf, inf]`. 353*da0073e9SAndroid Build Coastguard Worker """ 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker def check(self, value): 356*da0073e9SAndroid Build Coastguard Worker return value == value # False for NANs. 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Workerclass _GreaterThan(Constraint): 360*da0073e9SAndroid Build Coastguard Worker """ 361*da0073e9SAndroid Build Coastguard Worker Constrain to a real half line `(lower_bound, inf]`. 362*da0073e9SAndroid Build Coastguard Worker """ 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound): 365*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 366*da0073e9SAndroid Build Coastguard Worker super().__init__() 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker def check(self, value): 369*da0073e9SAndroid Build Coastguard Worker return self.lower_bound < value 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 372*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 373*da0073e9SAndroid Build Coastguard Worker fmt_string += f"(lower_bound={self.lower_bound})" 374*da0073e9SAndroid Build Coastguard Worker return fmt_string 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Workerclass _GreaterThanEq(Constraint): 378*da0073e9SAndroid Build Coastguard Worker """ 379*da0073e9SAndroid Build Coastguard Worker Constrain to a real half line `[lower_bound, inf)`. 380*da0073e9SAndroid Build Coastguard Worker """ 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound): 383*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 384*da0073e9SAndroid Build Coastguard Worker super().__init__() 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker def check(self, value): 387*da0073e9SAndroid Build Coastguard Worker return self.lower_bound <= value 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 390*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 391*da0073e9SAndroid Build Coastguard Worker fmt_string += f"(lower_bound={self.lower_bound})" 392*da0073e9SAndroid Build Coastguard Worker return fmt_string 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Workerclass _LessThan(Constraint): 396*da0073e9SAndroid Build Coastguard Worker """ 397*da0073e9SAndroid Build Coastguard Worker Constrain to a real half line `[-inf, upper_bound)`. 398*da0073e9SAndroid Build Coastguard Worker """ 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker def __init__(self, upper_bound): 401*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 402*da0073e9SAndroid Build Coastguard Worker super().__init__() 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker def check(self, value): 405*da0073e9SAndroid Build Coastguard Worker return value < self.upper_bound 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 408*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 409*da0073e9SAndroid Build Coastguard Worker fmt_string += f"(upper_bound={self.upper_bound})" 410*da0073e9SAndroid Build Coastguard Worker return fmt_string 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Workerclass _Interval(Constraint): 414*da0073e9SAndroid Build Coastguard Worker """ 415*da0073e9SAndroid Build Coastguard Worker Constrain to a real interval `[lower_bound, upper_bound]`. 416*da0073e9SAndroid Build Coastguard Worker """ 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound, upper_bound): 419*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 420*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 421*da0073e9SAndroid Build Coastguard Worker super().__init__() 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker def check(self, value): 424*da0073e9SAndroid Build Coastguard Worker return (self.lower_bound <= value) & (value <= self.upper_bound) 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 427*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 428*da0073e9SAndroid Build Coastguard Worker fmt_string += ( 429*da0073e9SAndroid Build Coastguard Worker f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" 430*da0073e9SAndroid Build Coastguard Worker ) 431*da0073e9SAndroid Build Coastguard Worker return fmt_string 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Workerclass _HalfOpenInterval(Constraint): 435*da0073e9SAndroid Build Coastguard Worker """ 436*da0073e9SAndroid Build Coastguard Worker Constrain to a real interval `[lower_bound, upper_bound)`. 437*da0073e9SAndroid Build Coastguard Worker """ 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker def __init__(self, lower_bound, upper_bound): 440*da0073e9SAndroid Build Coastguard Worker self.lower_bound = lower_bound 441*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 442*da0073e9SAndroid Build Coastguard Worker super().__init__() 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker def check(self, value): 445*da0073e9SAndroid Build Coastguard Worker return (self.lower_bound <= value) & (value < self.upper_bound) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 448*da0073e9SAndroid Build Coastguard Worker fmt_string = self.__class__.__name__[1:] 449*da0073e9SAndroid Build Coastguard Worker fmt_string += ( 450*da0073e9SAndroid Build Coastguard Worker f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" 451*da0073e9SAndroid Build Coastguard Worker ) 452*da0073e9SAndroid Build Coastguard Worker return fmt_string 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Workerclass _Simplex(Constraint): 456*da0073e9SAndroid Build Coastguard Worker """ 457*da0073e9SAndroid Build Coastguard Worker Constrain to the unit simplex in the innermost (rightmost) dimension. 458*da0073e9SAndroid Build Coastguard Worker Specifically: `x >= 0` and `x.sum(-1) == 1`. 459*da0073e9SAndroid Build Coastguard Worker """ 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker event_dim = 1 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker def check(self, value): 464*da0073e9SAndroid Build Coastguard Worker return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Workerclass _Multinomial(Constraint): 468*da0073e9SAndroid Build Coastguard Worker """ 469*da0073e9SAndroid Build Coastguard Worker Constrain to nonnegative integer values summing to at most an upper bound. 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker Note due to limitations of the Multinomial distribution, this currently 472*da0073e9SAndroid Build Coastguard Worker checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future 473*da0073e9SAndroid Build Coastguard Worker this may be strengthened to ``value.sum(-1) == upper_bound``. 474*da0073e9SAndroid Build Coastguard Worker """ 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker is_discrete = True 477*da0073e9SAndroid Build Coastguard Worker event_dim = 1 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker def __init__(self, upper_bound): 480*da0073e9SAndroid Build Coastguard Worker self.upper_bound = upper_bound 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker def check(self, x): 483*da0073e9SAndroid Build Coastguard Worker return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Workerclass _LowerTriangular(Constraint): 487*da0073e9SAndroid Build Coastguard Worker """ 488*da0073e9SAndroid Build Coastguard Worker Constrain to lower-triangular square matrices. 489*da0073e9SAndroid Build Coastguard Worker """ 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker event_dim = 2 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker def check(self, value): 494*da0073e9SAndroid Build Coastguard Worker value_tril = value.tril() 495*da0073e9SAndroid Build Coastguard Worker return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Workerclass _LowerCholesky(Constraint): 499*da0073e9SAndroid Build Coastguard Worker """ 500*da0073e9SAndroid Build Coastguard Worker Constrain to lower-triangular square matrices with positive diagonals. 501*da0073e9SAndroid Build Coastguard Worker """ 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker event_dim = 2 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker def check(self, value): 506*da0073e9SAndroid Build Coastguard Worker value_tril = value.tril() 507*da0073e9SAndroid Build Coastguard Worker lower_triangular = ( 508*da0073e9SAndroid Build Coastguard Worker (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] 509*da0073e9SAndroid Build Coastguard Worker ) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] 512*da0073e9SAndroid Build Coastguard Worker return lower_triangular & positive_diagonal 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Workerclass _CorrCholesky(Constraint): 516*da0073e9SAndroid Build Coastguard Worker """ 517*da0073e9SAndroid Build Coastguard Worker Constrain to lower-triangular square matrices with positive diagonals and each 518*da0073e9SAndroid Build Coastguard Worker row vector being of unit length. 519*da0073e9SAndroid Build Coastguard Worker """ 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker event_dim = 2 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker def check(self, value): 524*da0073e9SAndroid Build Coastguard Worker tol = ( 525*da0073e9SAndroid Build Coastguard Worker torch.finfo(value.dtype).eps * value.size(-1) * 10 526*da0073e9SAndroid Build Coastguard Worker ) # 10 is an adjustable fudge factor 527*da0073e9SAndroid Build Coastguard Worker row_norm = torch.linalg.norm(value.detach(), dim=-1) 528*da0073e9SAndroid Build Coastguard Worker unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1) 529*da0073e9SAndroid Build Coastguard Worker return _LowerCholesky().check(value) & unit_row_norm 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Workerclass _Square(Constraint): 533*da0073e9SAndroid Build Coastguard Worker """ 534*da0073e9SAndroid Build Coastguard Worker Constrain to square matrices. 535*da0073e9SAndroid Build Coastguard Worker """ 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker event_dim = 2 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker def check(self, value): 540*da0073e9SAndroid Build Coastguard Worker return torch.full( 541*da0073e9SAndroid Build Coastguard Worker size=value.shape[:-2], 542*da0073e9SAndroid Build Coastguard Worker fill_value=(value.shape[-2] == value.shape[-1]), 543*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 544*da0073e9SAndroid Build Coastguard Worker device=value.device, 545*da0073e9SAndroid Build Coastguard Worker ) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Workerclass _Symmetric(_Square): 549*da0073e9SAndroid Build Coastguard Worker """ 550*da0073e9SAndroid Build Coastguard Worker Constrain to Symmetric square matrices. 551*da0073e9SAndroid Build Coastguard Worker """ 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker def check(self, value): 554*da0073e9SAndroid Build Coastguard Worker square_check = super().check(value) 555*da0073e9SAndroid Build Coastguard Worker if not square_check.all(): 556*da0073e9SAndroid Build Coastguard Worker return square_check 557*da0073e9SAndroid Build Coastguard Worker return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Workerclass _PositiveSemidefinite(_Symmetric): 561*da0073e9SAndroid Build Coastguard Worker """ 562*da0073e9SAndroid Build Coastguard Worker Constrain to positive-semidefinite matrices. 563*da0073e9SAndroid Build Coastguard Worker """ 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker def check(self, value): 566*da0073e9SAndroid Build Coastguard Worker sym_check = super().check(value) 567*da0073e9SAndroid Build Coastguard Worker if not sym_check.all(): 568*da0073e9SAndroid Build Coastguard Worker return sym_check 569*da0073e9SAndroid Build Coastguard Worker return torch.linalg.eigvalsh(value).ge(0).all(-1) 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Workerclass _PositiveDefinite(_Symmetric): 573*da0073e9SAndroid Build Coastguard Worker """ 574*da0073e9SAndroid Build Coastguard Worker Constrain to positive-definite matrices. 575*da0073e9SAndroid Build Coastguard Worker """ 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker def check(self, value): 578*da0073e9SAndroid Build Coastguard Worker sym_check = super().check(value) 579*da0073e9SAndroid Build Coastguard Worker if not sym_check.all(): 580*da0073e9SAndroid Build Coastguard Worker return sym_check 581*da0073e9SAndroid Build Coastguard Worker return torch.linalg.cholesky_ex(value).info.eq(0) 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Workerclass _Cat(Constraint): 585*da0073e9SAndroid Build Coastguard Worker """ 586*da0073e9SAndroid Build Coastguard Worker Constraint functor that applies a sequence of constraints 587*da0073e9SAndroid Build Coastguard Worker `cseq` at the submatrices at dimension `dim`, 588*da0073e9SAndroid Build Coastguard Worker each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`. 589*da0073e9SAndroid Build Coastguard Worker """ 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker def __init__(self, cseq, dim=0, lengths=None): 592*da0073e9SAndroid Build Coastguard Worker assert all(isinstance(c, Constraint) for c in cseq) 593*da0073e9SAndroid Build Coastguard Worker self.cseq = list(cseq) 594*da0073e9SAndroid Build Coastguard Worker if lengths is None: 595*da0073e9SAndroid Build Coastguard Worker lengths = [1] * len(self.cseq) 596*da0073e9SAndroid Build Coastguard Worker self.lengths = list(lengths) 597*da0073e9SAndroid Build Coastguard Worker assert len(self.lengths) == len(self.cseq) 598*da0073e9SAndroid Build Coastguard Worker self.dim = dim 599*da0073e9SAndroid Build Coastguard Worker super().__init__() 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker @property 602*da0073e9SAndroid Build Coastguard Worker def is_discrete(self): 603*da0073e9SAndroid Build Coastguard Worker return any(c.is_discrete for c in self.cseq) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker @property 606*da0073e9SAndroid Build Coastguard Worker def event_dim(self): 607*da0073e9SAndroid Build Coastguard Worker return max(c.event_dim for c in self.cseq) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker def check(self, value): 610*da0073e9SAndroid Build Coastguard Worker assert -value.dim() <= self.dim < value.dim() 611*da0073e9SAndroid Build Coastguard Worker checks = [] 612*da0073e9SAndroid Build Coastguard Worker start = 0 613*da0073e9SAndroid Build Coastguard Worker for constr, length in zip(self.cseq, self.lengths): 614*da0073e9SAndroid Build Coastguard Worker v = value.narrow(self.dim, start, length) 615*da0073e9SAndroid Build Coastguard Worker checks.append(constr.check(v)) 616*da0073e9SAndroid Build Coastguard Worker start = start + length # avoid += for jit compat 617*da0073e9SAndroid Build Coastguard Worker return torch.cat(checks, self.dim) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Workerclass _Stack(Constraint): 621*da0073e9SAndroid Build Coastguard Worker """ 622*da0073e9SAndroid Build Coastguard Worker Constraint functor that applies a sequence of constraints 623*da0073e9SAndroid Build Coastguard Worker `cseq` at the submatrices at dimension `dim`, 624*da0073e9SAndroid Build Coastguard Worker in a way compatible with :func:`torch.stack`. 625*da0073e9SAndroid Build Coastguard Worker """ 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker def __init__(self, cseq, dim=0): 628*da0073e9SAndroid Build Coastguard Worker assert all(isinstance(c, Constraint) for c in cseq) 629*da0073e9SAndroid Build Coastguard Worker self.cseq = list(cseq) 630*da0073e9SAndroid Build Coastguard Worker self.dim = dim 631*da0073e9SAndroid Build Coastguard Worker super().__init__() 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker @property 634*da0073e9SAndroid Build Coastguard Worker def is_discrete(self): 635*da0073e9SAndroid Build Coastguard Worker return any(c.is_discrete for c in self.cseq) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker @property 638*da0073e9SAndroid Build Coastguard Worker def event_dim(self): 639*da0073e9SAndroid Build Coastguard Worker dim = max(c.event_dim for c in self.cseq) 640*da0073e9SAndroid Build Coastguard Worker if self.dim + dim < 0: 641*da0073e9SAndroid Build Coastguard Worker dim += 1 642*da0073e9SAndroid Build Coastguard Worker return dim 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker def check(self, value): 645*da0073e9SAndroid Build Coastguard Worker assert -value.dim() <= self.dim < value.dim() 646*da0073e9SAndroid Build Coastguard Worker vs = [value.select(self.dim, i) for i in range(value.size(self.dim))] 647*da0073e9SAndroid Build Coastguard Worker return torch.stack( 648*da0073e9SAndroid Build Coastguard Worker [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim 649*da0073e9SAndroid Build Coastguard Worker ) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker# Public interface. 653*da0073e9SAndroid Build Coastguard Workerdependent = _Dependent() 654*da0073e9SAndroid Build Coastguard Workerdependent_property = _DependentProperty 655*da0073e9SAndroid Build Coastguard Workerindependent = _IndependentConstraint 656*da0073e9SAndroid Build Coastguard Workerboolean = _Boolean() 657*da0073e9SAndroid Build Coastguard Workerone_hot = _OneHot() 658*da0073e9SAndroid Build Coastguard Workernonnegative_integer = _IntegerGreaterThan(0) 659*da0073e9SAndroid Build Coastguard Workerpositive_integer = _IntegerGreaterThan(1) 660*da0073e9SAndroid Build Coastguard Workerinteger_interval = _IntegerInterval 661*da0073e9SAndroid Build Coastguard Workerreal = _Real() 662*da0073e9SAndroid Build Coastguard Workerreal_vector = independent(real, 1) 663*da0073e9SAndroid Build Coastguard Workerpositive = _GreaterThan(0.0) 664*da0073e9SAndroid Build Coastguard Workernonnegative = _GreaterThanEq(0.0) 665*da0073e9SAndroid Build Coastguard Workergreater_than = _GreaterThan 666*da0073e9SAndroid Build Coastguard Workergreater_than_eq = _GreaterThanEq 667*da0073e9SAndroid Build Coastguard Workerless_than = _LessThan 668*da0073e9SAndroid Build Coastguard Workermultinomial = _Multinomial 669*da0073e9SAndroid Build Coastguard Workerunit_interval = _Interval(0.0, 1.0) 670*da0073e9SAndroid Build Coastguard Workerinterval = _Interval 671*da0073e9SAndroid Build Coastguard Workerhalf_open_interval = _HalfOpenInterval 672*da0073e9SAndroid Build Coastguard Workersimplex = _Simplex() 673*da0073e9SAndroid Build Coastguard Workerlower_triangular = _LowerTriangular() 674*da0073e9SAndroid Build Coastguard Workerlower_cholesky = _LowerCholesky() 675*da0073e9SAndroid Build Coastguard Workercorr_cholesky = _CorrCholesky() 676*da0073e9SAndroid Build Coastguard Workersquare = _Square() 677*da0073e9SAndroid Build Coastguard Workersymmetric = _Symmetric() 678*da0073e9SAndroid Build Coastguard Workerpositive_semidefinite = _PositiveSemidefinite() 679*da0073e9SAndroid Build Coastguard Workerpositive_definite = _PositiveDefinite() 680*da0073e9SAndroid Build Coastguard Workercat = _Cat 681*da0073e9SAndroid Build Coastguard Workerstack = _Stack 682