xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/dynamic_shape_slicing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class DynamicShapeSlicing(torch.nn.Module):
    """
    Export example: slice bounds computed from the input's shape.

    The bounds used in ``forward`` are derived from ``x.shape``, so they
    should be captured into the graph rather than being baked in as
    constants.
    """

    def forward(self, x):
        """Slice ``x`` using bounds taken from its own shape.

        Rows: everything except the last two (``:n_rows - 2``).
        Columns: every second column starting at index ``n_cols - 1``.
        Because that start is the last column index, the column slice picks
        at most one column.
        """
        n_rows = x.shape[0]
        n_cols = x.shape[1]
        row_stop = n_rows - 2
        col_start = n_cols - 1
        return x[:row_stop, col_start::2]
12
# Tags used to categorise this example in the export example database.
tags = {"torch.dynamic-shape"}
# Module instance that is exported with the inputs below.
model = DynamicShapeSlicing()
# Sample inputs: a single 3x2 tensor, so forward's slice keeps 1 row, 1 column.
example_args = (torch.randn(3, 2),)
16