xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/dynamic_shape_map.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from functorch.experimental.control_flow import map
5
class DynamicShapeMap(torch.nn.Module):
    """
    Export example for functorch ``map()``, which applies a function to each
    slice of a tensor taken along its first dimension.

    ``forward`` adds ``y`` to every slice ``xs[i]``. ``y`` is not mapped over;
    it is handed to each invocation as an extra operand.
    """

    def forward(self, xs, y):
        # Per-slice function: receives one slice of ``xs`` plus the extra
        # operand that ``map`` forwards unchanged on every call.
        def add_to_slice(xs_slice, operand):
            return xs_slice + operand

        mapped = map(add_to_slice, xs, y)
        return mapped
16
# ExportDB metadata: tags categorising this example.
tags = {"torch.map", "torch.dynamic-shape"}
# Sample inputs: a (3, 2) tensor mapped over its first dimension, and a
# length-2 tensor passed as the extra operand to every call.
example_args = (torch.randn((3, 2)), torch.randn((2,)))
model = DynamicShapeMap()
20