# mypy: allow-untyped-defs
import torch


class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.
    """

    def forward(self, x):
        """Split the last dimension of ``x`` into ``(2, 5)`` and swap the two new axes.

        Args:
            x: A 2-D tensor whose last dimension holds exactly 10 elements
                (``2 * 5``); ``view`` raises otherwise, and ``permute(0, 2, 1)``
                requires the viewed tensor to have exactly three dimensions.

        Returns:
            A tensor of shape ``(x.size(0), 5, 2)``.
        """
        # The leading dimensions are read from the input at call time rather
        # than written as literals; only the trailing (2, 5) split is fixed.
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        # Swap the two axes produced by the split: (N, 2, 5) -> (N, 5, 2).
        return x.permute(0, 2, 1)


# Sample positional inputs for the module: a single 10x10 random tensor.
example_args = (torch.randn(10, 10),)
# Tag set attached to this example.
tags = {"torch.dynamic-shape"}
# Module instance exposed at module level.
model = DynamicShapeView()