xref: /aosp_15_r20/external/pytorch/torch/distributions/constraints.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerr"""
3*da0073e9SAndroid Build Coastguard WorkerThe following constraints are implemented:
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker- ``constraints.boolean``
6*da0073e9SAndroid Build Coastguard Worker- ``constraints.cat``
7*da0073e9SAndroid Build Coastguard Worker- ``constraints.corr_cholesky``
8*da0073e9SAndroid Build Coastguard Worker- ``constraints.dependent``
9*da0073e9SAndroid Build Coastguard Worker- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
11*da0073e9SAndroid Build Coastguard Worker- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
12*da0073e9SAndroid Build Coastguard Worker- ``constraints.integer_interval(lower_bound, upper_bound)``
13*da0073e9SAndroid Build Coastguard Worker- ``constraints.interval(lower_bound, upper_bound)``
14*da0073e9SAndroid Build Coastguard Worker- ``constraints.less_than(upper_bound)``
15*da0073e9SAndroid Build Coastguard Worker- ``constraints.lower_cholesky``
16*da0073e9SAndroid Build Coastguard Worker- ``constraints.lower_triangular``
17*da0073e9SAndroid Build Coastguard Worker- ``constraints.multinomial``
18*da0073e9SAndroid Build Coastguard Worker- ``constraints.nonnegative``
19*da0073e9SAndroid Build Coastguard Worker- ``constraints.nonnegative_integer``
20*da0073e9SAndroid Build Coastguard Worker- ``constraints.one_hot``
21*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_integer``
22*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive``
23*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_semidefinite``
24*da0073e9SAndroid Build Coastguard Worker- ``constraints.positive_definite``
25*da0073e9SAndroid Build Coastguard Worker- ``constraints.real_vector``
26*da0073e9SAndroid Build Coastguard Worker- ``constraints.real``
- ``constraints.simplex``
29*da0073e9SAndroid Build Coastguard Worker- ``constraints.stack``
30*da0073e9SAndroid Build Coastguard Worker- ``constraints.square``
31*da0073e9SAndroid Build Coastguard Worker- ``constraints.symmetric``
32*da0073e9SAndroid Build Coastguard Worker- ``constraints.unit_interval``
33*da0073e9SAndroid Build Coastguard Worker"""
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Workerimport torch
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
# Public names exported by ``from torch.distributions.constraints import *``.
__all__ = [
    "Constraint",
    "boolean", "cat", "corr_cholesky",
    "dependent", "dependent_property",
    "greater_than", "greater_than_eq",
    "independent", "integer_interval", "interval", "half_open_interval",
    "is_dependent",
    "less_than", "lower_cholesky", "lower_triangular",
    "multinomial",
    "nonnegative", "nonnegative_integer",
    "one_hot",
    "positive", "positive_semidefinite", "positive_definite", "positive_integer",
    "real", "real_vector",
    "simplex", "square", "stack", "symmetric",
    "unit_interval",
]
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker
class Constraint:
    """
    Abstract base class for constraints.

    A constraint describes the region of values a variable may take, e.g.
    the region within which a variable can be optimized.

    Attributes:
        is_discrete (bool): ``True`` when the constrained space is discrete.
            Defaults to ``False``.
        event_dim (int): How many rightmost dimensions jointly form one
            event. :meth:`check` reduces over this many dimensions when
            deciding validity. Defaults to ``0``.
    """

    # Subclasses override these class attributes (or shadow them with
    # properties) to describe their constrained space.
    is_discrete = False  # continuous unless stated otherwise
    event_dim = 0  # univariate unless stated otherwise

    def check(self, value):
        """
        Return a boolean tensor of shape ``sample_shape + batch_shape`` that
        is true wherever the corresponding event of ``value`` satisfies this
        constraint.

        Subclasses must implement this method.
        """
        raise NotImplementedError

    def __repr__(self):
        # Concrete constraints are named with a leading underscore
        # (e.g. ``_Real``); that first character is dropped for display.
        display_name = self.__class__.__name__[1:]
        return f"{display_name}()"
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker
class _Dependent(Constraint):
    """
    Placeholder constraint for variables whose support depends on other
    variables, so that no simple coordinate-wise check is possible.

    Args:
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises ``NotImplementedError``.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises ``NotImplementedError``.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``NotImplemented`` serves as the "not known statically" sentinel.
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self):
        known = self._is_discrete
        if known is not NotImplemented:
            return known
        raise NotImplementedError(".is_discrete cannot be determined statically")

    @property
    def event_dim(self):
        known = self._event_dim
        if known is not NotImplemented:
            return known
        raise NotImplementedError(".event_dim cannot be determined statically")

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Build a new ``_Dependent`` with some static attributes filled in::

            constraints.dependent(is_discrete=True, event_dim=1)

        Any argument left unspecified inherits this instance's stored value.
        """
        return _Dependent(
            is_discrete=(
                self._is_discrete if is_discrete is NotImplemented else is_discrete
            ),
            event_dim=(
                self._event_dim if event_dim is NotImplemented else event_dim
            ),
        )

    def check(self, x):
        # Validity depends on other variables, so it cannot be decided here.
        raise ValueError("Cannot determine validity of dependent constraint")
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def is_dependent(constraint):
    """
    Checks if ``constraint`` is a ``_Dependent`` object.

    Args:
        constraint : A ``Constraint`` object.

    Returns:
        ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise.

    Examples:
        >>> import torch
        >>> from torch.distributions import Bernoulli
        >>> from torch.distributions.constraints import is_dependent

        >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
        >>> constraint1 = dist.arg_constraints["probs"]
        >>> constraint2 = dist.arg_constraints["logits"]

        >>> for constraint in [constraint1, constraint2]:
        ...     if is_dependent(constraint):
        ...         continue
    """
    # isinstance also matches subclasses such as ``_DependentProperty``.
    return isinstance(constraint, _Dependent)
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker
class _DependentProperty(property, _Dependent):
    """
    Decorator that behaves like ``@property`` when accessed on an instance
    and like a ``_Dependent`` constraint when accessed on the class.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high
            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The getter being decorated.
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises ``NotImplementedError``.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises ``NotImplementedError``.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        # ``super()`` resolves to ``property`` (next in the MRO), which takes
        # the getter; the static attributes are then stored directly instead
        # of going through ``_Dependent.__init__``.
        super().__init__(fn)
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Enable parametrised decorator syntax::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...

        Returns a new ``_DependentProperty`` that wraps ``fn`` and carries
        over this instance's static attributes.
        """
        static_attrs = {
            "is_discrete": self._is_discrete,
            "event_dim": self._event_dim,
        }
        return _DependentProperty(fn, **static_attrs)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker
class _IndependentConstraint(Constraint):
    """
    Wraps a constraint and additionally reduces over the rightmost
    ``reinterpreted_batch_ndims`` dimensions in :meth:`check`, so an event is
    valid only when every one of its independent entries is valid.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self):
        # Discreteness is taken unchanged from the wrapped constraint.
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        # The reinterpreted batch dimensions become part of the event.
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        ndims = self.reinterpreted_batch_ndims
        per_entry = self.base_constraint.check(value)
        if per_entry.dim() < ndims:
            expected = self.base_constraint.event_dim + ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        # Collapse the trailing ``ndims`` dims into one and require all True.
        leading_shape = per_entry.shape[: per_entry.dim() - ndims]
        return per_entry.reshape(leading_shape + (-1,)).all(-1)

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}({self.base_constraint!r}, {self.reinterpreted_batch_ndims})"
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        # Element-wise: true where ``value`` is exactly 0 or exactly 1.
        equals_zero = value == 0
        equals_one = value == 1
        return equals_zero | equals_one
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker
class _OneHot(Constraint):
    """
    Constrain to one-hot vectors along the last dimension: every entry is
    0 or 1 and the entries sum to exactly 1.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        entries_binary = ((value == 0) | (value == 1)).all(-1)
        sums_to_one = value.sum(-1).eq(1)
        return entries_binary & sums_to_one
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker
class _IntegerInterval(Constraint):
    """
    Constrain to the integers in the closed interval
    `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        # ``value % 1 == 0`` selects integer-valued entries.
        is_integer = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return is_integer & above_lower & below_upper

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker
class _IntegerLessThan(Constraint):
    """
    Constrain to the integers in the half line `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        # ``value % 1 == 0`` selects integer-valued entries.
        is_integer = value % 1 == 0
        return is_integer & (value <= self.upper_bound)

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(upper_bound={self.upper_bound})"
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker
class _IntegerGreaterThan(Constraint):
    """
    Constrain to the integers in the half line `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        # ``value % 1 == 0`` selects integer-valued entries.
        is_integer = value % 1 == 0
        return is_integer & (value >= self.lower_bound)

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound})"
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`; only NaN
    entries are rejected.
    """

    def check(self, value):
        # NaN is the only value that compares unequal to itself.
        return value == value
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker
class _GreaterThan(Constraint):
    """
    Constrain to the real half line `(lower_bound, inf]` (strictly above the
    bound).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound < value

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound})"
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker
class _GreaterThanEq(Constraint):
    """
    Constrain to the real half line `[lower_bound, inf]` (at or above the
    bound; ``inf`` itself satisfies ``lower_bound <= value``).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound <= value

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound})"
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker
class _LessThan(Constraint):
    """
    Constrain to the real half line `[-inf, upper_bound)` (strictly below
    the bound).
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return value < self.upper_bound

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(upper_bound={self.upper_bound})"
411*da0073e9SAndroid Build Coastguard Worker
412*da0073e9SAndroid Build Coastguard Worker
class _Interval(Constraint):
    """
    Constrain to the closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        # Both endpoints are inclusive.
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker
class _HalfOpenInterval(Constraint):
    """
    Constrain to the half-open real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        # Lower endpoint inclusive, upper endpoint exclusive.
        above_lower = self.lower_bound <= value
        below_upper = value < self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        display_name = self.__class__.__name__[1:]
        return f"{display_name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker
class _Simplex(Constraint):
    """
    Constrain to the unit simplex over the innermost (rightmost) dimension:
    every entry is nonnegative and the entries along that dimension sum to
    one, up to an absolute tolerance of ``1e-6``.
    """

    event_dim = 1

    def check(self, value):
        nonnegative = (value >= 0).all(dim=-1)
        sums_to_one = (value.sum(-1) - 1).abs() < 1e-6
        return nonnegative & sums_to_one
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker
class _Multinomial(Constraint):
    """
    Constrain to nonnegative values along the rightmost dimension whose sum
    is at most ``upper_bound``.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``.

    Integrality of the entries is not verified by :meth:`check`.

    Args:
        upper_bound: maximum allowed total count along the last dimension.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        # Consistent with the other parameterized constraints in this module.
        super().__init__()

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)

    def __repr__(self):
        # Include the parameter, as the interval constraints do.
        return f"{self.__class__.__name__[1:]}(upper_bound={self.upper_bound})"
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.

    The two rightmost dimensions hold the matrix; any leading dimensions are
    batch dimensions, and the result has the batch shape.
    NOTE: squareness itself is not verified here.
    """

    event_dim = 2

    def check(self, value):
        # ``flatten(-2).all(-1)`` rather than ``view(batch + (-1,)).min(-1)[0]``:
        # ``view`` cannot infer ``-1`` for a zero-element batch such as
        # ``(0, n, n)``, and ``min`` raises when reducing an empty dimension
        # (0x0 matrices).
        return (value.tril() == value).flatten(-2).all(-1)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.

    The two rightmost dimensions hold the matrix; any leading dimensions are
    batch dimensions, and the result has the batch shape.
    NOTE: squareness itself is not verified here.
    """

    event_dim = 2

    def check(self, value):
        # ``flatten(-2).all(-1)`` rather than ``view(batch + (-1,)).min(-1)[0]``:
        # ``view`` cannot infer ``-1`` for a zero-element batch such as
        # ``(0, n, n)``, and ``min`` raises when reducing an empty dimension
        # (0x0 matrices).
        lower_triangular = (value.tril() == value).flatten(-2).all(-1)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return lower_triangular & positive_diagonal
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker
class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with a positive diagonal
    whose rows each have unit norm, up to a dtype-dependent tolerance.
    """

    event_dim = 2

    def check(self, value):
        # Tolerance grows with the dtype's machine epsilon and the row length;
        # 10 is an adjustable fudge factor.
        tol = torch.finfo(value.dtype).eps * value.size(-1) * 10
        norms = torch.linalg.norm(value.detach(), dim=-1)
        rows_are_unit = (norms - 1.0).abs().le(tol).all(dim=-1)
        is_lower_cholesky = _LowerCholesky().check(value)
        return is_lower_cholesky & rows_are_unit
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker
class _Square(Constraint):
    """
    Constrain to square matrices, i.e. values whose two rightmost dimensions
    have equal size. The result is a boolean tensor with the batch shape,
    holding the same verdict for every batch element.
    """

    event_dim = 2

    def check(self, value):
        batch_shape = value.shape[:-2]
        is_square = value.shape[-2] == value.shape[-1]
        return torch.full(
            batch_shape, is_square, dtype=torch.bool, device=value.device
        )
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker
class _Symmetric(_Square):
    """
    Constrain to symmetric square matrices.

    A matrix counts as symmetric when it is elementwise close to its
    transpose (``torch.isclose`` with ``atol=1e-6``). Non-square input is
    rejected without comparing entries.
    """

    def check(self, value):
        is_square = super().check(value)
        if is_square.all():
            matches_transpose = torch.isclose(value, value.mT, atol=1e-6)
            return matches_transpose.all(-2).all(-1)
        return is_square
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker
class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices: symmetric matrices whose
    eigenvalues are all nonnegative.

    Each matrix of a batch is judged independently; the result has the
    batch shape.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.any():
            # Nothing symmetric to inspect. This also covers non-square input,
            # on which ``eigvalsh`` would raise.
            return sym_check
        # Previously a single non-symmetric matrix made the whole batch return
        # ``sym_check``, so symmetric but indefinite matrices passed. Now each
        # symmetric matrix is tested. Non-symmetric ones are swapped for the
        # identity so ``eigvalsh`` never sees them (e.g. NaN entries); their
        # verdict stays False through the final ``&``.
        eye = torch.eye(value.size(-1), dtype=value.dtype, device=value.device)
        safe_value = torch.where(sym_check[..., None, None], value, eye)
        return sym_check & torch.linalg.eigvalsh(safe_value).ge(0).all(-1)
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker
class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices: symmetric matrices that admit a
    Cholesky factorization.

    Each matrix of a batch is judged independently; the result has the
    batch shape.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.any():
            # Nothing symmetric to inspect. This also covers non-square input,
            # on which ``cholesky_ex`` would raise.
            return sym_check
        # Previously a single non-symmetric matrix made the whole batch return
        # ``sym_check``, so symmetric but non-definite matrices passed. Now
        # each symmetric matrix is factorized. Non-symmetric ones are swapped
        # for the identity and masked out by the final ``&``.
        eye = torch.eye(value.size(-1), dtype=value.dtype, device=value.device)
        safe_value = torch.where(sym_check[..., None, None], value, eye)
        return sym_check & torch.linalg.cholesky_ex(safe_value).info.eq(0)
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints ``cseq`` to
    consecutive slices of a value along dimension ``dim``, in a way compatible
    with :func:`torch.cat`. The ``i``-th constraint sees a slice of width
    ``lengths[i]``; ``lengths`` defaults to ``1`` for every constraint.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.lengths = [1] * len(self.cseq) if lengths is None else list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        # Discrete as soon as any component constraint is discrete.
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        results = []
        offset = 0
        for constraint, width in zip(self.cseq, self.lengths):
            chunk = value.narrow(self.dim, offset, width)
            results.append(constraint.check(chunk))
            offset = offset + width  # avoid += for jit compat
        return torch.cat(results, self.dim)
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker
class _Stack(Constraint):
    """
    Constraint functor that applies the ``i``-th constraint of ``cseq`` to the
    ``i``-th slice of a value along dimension ``dim``, in a way compatible
    with :func:`torch.stack`.
    """

    def __init__(self, cseq, dim=0):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        # Discrete as soon as any component constraint is discrete.
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        inner = max(c.event_dim for c in self.cseq)
        # A negative stacking dim that reaches past the components' event
        # dims makes the stacked dimension itself part of the event.
        return inner + 1 if self.dim + inner < 0 else inner

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        slices = value.unbind(self.dim)
        return torch.stack(
            [constraint.check(piece) for piece, constraint in zip(slices, self.cseq)],
            self.dim,
        )
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker
# Public interface.
# Names bound to constraint *instances* (e.g. ``real``, ``positive``,
# ``simplex``) are used as-is. Names bound to *classes* (e.g. ``interval``,
# ``greater_than``, ``independent``, ``multinomial``, ``cat``) are factories
# that must be called with their parameters, e.g. ``interval(0.0, 1.0)``.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
# NOTE(review): these names assume ``_IntegerGreaterThan`` treats its bound
# as inclusive; its definition is outside this view — confirm there.
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
# ``real`` wrapped with reinterpreted_batch_ndims=1 (see ``independent``).
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
# Closed interval [0, 1].
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
682