xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/cond_predicate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from functorch.experimental.control_flow import cond
5
6class CondPredicate(torch.nn.Module):
7    """
8    The conditional statement (aka predicate) passed to cond() must be one of the following:
9      - torch.Tensor with a single element
10      - boolean expression
11
12    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
13    """
14
15    def forward(self, x):
16        pred = x.dim() > 2 and x.shape[2] > 10
17
18        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
19
# Sample input for export: its third dim is 3, so the predicate in
# CondPredicate.forward evaluates to False and the sin() branch runs.
example_args = (torch.randn((6, 4, 3)),)
tags = {"torch.cond", "torch.dynamic-shape"}
model = CondPredicate()
26