# mypy: allow-untyped-defs
import torch

# NOTE: this import shadows the builtin ``map`` for the rest of this module.
from functorch.experimental.control_flow import map


class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.

    This module adds ``y`` to every slice of ``xs`` by routing the addition
    through ``map`` instead of relying on a plain broadcasted ``xs + y``.
    """

    def forward(self, xs, y):
        """Apply ``x + y`` to ``xs`` via ``map`` and return the result.

        Args:
            xs: Tensor that ``map`` iterates over; ``example_args`` below
                supplies a tensor of shape ``(3, 2)``.
            y: Extra operand forwarded to every ``body`` call;
                ``example_args`` below supplies a tensor of shape ``(2,)``.

        Returns:
            Whatever ``map(body, xs, y)`` returns.
        """

        def body(x, y):
            # Per-slice computation handed to ``map``.
            return x + y

        return map(body, xs, y)


# Sample positional inputs for ``DynamicShapeMap.forward``: (xs, y).
example_args = (torch.randn(3, 2), torch.randn(2))
# Labels attached to this example.
tags = {"torch.dynamic-shape", "torch.map"}
# Module instance exposed at module level.
model = DynamicShapeMap()