xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/dynamic_shape_if_guard.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class DynamicShapeIfGuard(torch.nn.Module):
    """
    `if` statement with backed dynamic shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if
    the dimension is marked as dynamic shape from a higher level API.
    """

    def forward(self, x):
        """Return ``x.cos()`` if dim 0 of ``x`` has size 3, else ``x.sin()``.

        Args:
            x: input tensor; only ``x.shape[0]`` is inspected to pick the branch.

        Returns:
            A tensor with the same shape as ``x``: its elementwise cosine when
            ``x.shape[0] == 3``, otherwise its elementwise sine.
        """
        # Plain Python `if` on a shape value: this is the construct the example
        # demonstrates. Per the class docstring, tracing specializes it to one
        # branch plus a guard, so keep it as a statement (not a torch.cond).
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()
16
# Module-level metadata for this example. NOTE(review): presumably read by the
# ExportDB harness (model / sample inputs / category tags) — confirm against
# the loader, which is not visible here.
model = DynamicShapeIfGuard()

# Sample input: a single (3, 2, 2) tensor, so dim 0 equals 3 and the
# `x.cos()` branch of `forward` is the one taken for this input.
example_input = torch.randn(3, 2, 2)
example_args = tuple([example_input])
del example_input

tags = {"torch.dynamic-shape", "python.control-flow"}
20