xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/dynamic_shape_view.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.
    """

    def forward(self, x):
        """Split the last dimension of ``x`` into ``(2, 5)``, then swap the two trailing axes.

        Args:
            x: a 2-D tensor whose last dimension has size 10. ``view`` needs
                exactly ``2 * 5`` elements in that dimension, and
                ``permute(0, 2, 1)`` needs the reshaped tensor to be 3-D.

        Returns:
            A tensor of shape ``(x.size(0), 5, 2)``.
        """
        # Build the target shape from ``x.size()`` instead of hard-coding the
        # leading size, so the leading dimension stays a (possibly symbolic)
        # size of the input rather than a constant taken from the example.
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        # Swap the two trailing axes: (N, 2, 5) -> (N, 5, 2).
        return x.permute(0, 2, 1)
14
# Sample input for ``model``: a 2-D tensor whose last dimension is 10, as
# DynamicShapeView.forward requires (it views that dimension as (2, 5)).
example_args = (torch.randn(10, 10),)
# Tag string for this example; presumably used to categorize it in the
# export examples collection (NOTE(review): confirm against the consumer).
tags = {"torch.dynamic-shape"}
# Module instance paired with ``example_args`` above.
model = DynamicShapeView()
18