xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/dynamic_shape_round.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from torch._export.db.case import SupportLevel
5from torch.export import Dim
6
class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        # Slice off roughly the first half of the rows. Python's built-in
        # ``round`` rounds halves to even, so 3 rows -> 2, 5 rows -> 2,
        # and 1 row -> 0.
        num_rows = x.shape[0]
        half = round(num_rows / 2)
        return x[:half]
14
# Example input for the case; forward's parameter is named ``x``.
x = torch.randn(3, 2)
example_args = (x,)

# Declare dimension 0 of ``x`` as dynamic when exporting.
dim0_x = Dim("dim0_x")
dynamic_shapes = {"x": {0: dim0_x}}

# Case metadata: this pattern is currently expected to fail export.
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET

model = DynamicShapeRound()
22