xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/scalar_output.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from torch.export import Dim
5
6x = torch.randn(3, 2)
7dim1_x = Dim("dim1_x")
8
class ScalarOutput(torch.nn.Module):
    """
    A graph may return scalar values as well as Tensor outputs; this module
    returns a scalar. Symbolic shapes are captured and rank is specialized.

    ``forward`` returns the size of dimension 1 of its input, plus one.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Return ``x.shape[1] + 1``; only the shape of ``x`` is read."""
        # Look up the size of dimension 1 first, then offset it by one.
        dim1_size = x.shape[1]
        return dim1_size + 1
19
# The module instance for this example.
model = ScalarOutput()

# Positional arguments for the module: just the sample tensor `x`.
example_args = (x,)

# Dimension 1 of the input named "x" is bound to the symbolic dim `dim1_x`.
dynamic_shapes = {"x": {1: dim1_x}}

# Tag set for this example.
# NOTE(review): presumably read by the example-database tooling — confirm.
tags = {"torch.dynamic-shape"}
24