# xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/cond_closed_over_variable.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from functorch.experimental.control_flow import cond
5
6class CondClosedOverVariable(torch.nn.Module):
7    """
8    torch.cond() supports branches closed over arbitrary variables.
9    """
10
11    def forward(self, pred, x):
12        def true_fn(val):
13            return x * 2
14
15        def false_fn(val):
16            return x - 2
17
18        return cond(pred, true_fn, false_fn, [x + 1])
19
# Example inputs: a 0-d boolean predicate tensor and a random 3x2 tensor.
example_args = (
    torch.tensor(True),
    torch.randn(3, 2),
)

# Feature tags this example is labelled with.
tags = {
    "python.closure",
    "torch.cond",
}

# Module instance for this example (presumably run with ``example_args``).
model = CondClosedOverVariable()
23