1# mypy: allow-untyped-defs 2import torch 3 4from torch.export import Dim 5 6x = torch.randn(3, 2) 7dim1_x = Dim("dim1_x") 8 9class ScalarOutput(torch.nn.Module): 10 """ 11 Returning scalar values from the graph is supported, in addition to Tensor 12 outputs. Symbolic shapes are captured and rank is specialized. 13 """ 14 def __init__(self) -> None: 15 super().__init__() 16 17 def forward(self, x): 18 return x.shape[1] + 1 19 20example_args = (x,) 21tags = {"torch.dynamic-shape"} 22dynamic_shapes = {"x": {1: dim1_x}} 23model = ScalarOutput() 24