xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/pytree_flatten.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from torch.utils import _pytree as pytree
5
class PytreeFlatten(torch.nn.Module):
    """
    Pytree from PyTorch can be captured by TorchDynamo.

    ``forward`` flattens its (possibly nested) input container with
    ``pytree.tree_flatten`` and returns the first leaf incremented by one.
    """

    def forward(self, x):
        # tree_flatten yields a pair: the list of leaves and a spec describing
        # the container structure. Only the leaves are needed here.
        leaves, _treespec = pytree.tree_flatten(x)
        first_leaf = leaves[0]
        return first_leaf + 1
14
# Positional arguments for ``PytreeFlatten.forward``: a 1-tuple whose single
# element is the dict pytree bound to ``x``. There is deliberately no comma
# after the closing parenthesis; one there would wrap this in a second tuple
# and pass ``forward`` a tuple-of-dict instead of the dict itself.
example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()
17