# mypy: allow-untyped-defs
import torch

from torch.utils import _pytree as pytree


class PytreeFlatten(torch.nn.Module):
    """
    Pytree from PyTorch can be captured by TorchDynamo.

    ``forward`` flattens its input with ``pytree.tree_flatten`` and returns
    the first leaf plus one.
    """

    def forward(self, x):
        """Return the first leaf of the pytree ``x`` incremented by 1.

        Args:
            x: Any structure accepted by ``pytree.tree_flatten``.

        Returns:
            ``leaves[0] + 1``, where ``leaves`` is the flattened leaf list.

        Raises:
            IndexError: If ``x`` flattens to an empty list of leaves.
        """
        # Only the flat list of leaves is used; the tree spec is discarded.
        leaves, _ = pytree.tree_flatten(x)
        return leaves[0] + 1


# Positional arguments for ``model``: a one-element tuple holding the dict
# pytree. Do not add a trailing comma after the closing parenthesis; that
# would wrap this tuple in a second tuple.
example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()