xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/nested_function.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class NestedFunction(torch.nn.Module):
    """
    Example: a function defined inside ``forward`` is traced through.

    The inner function rebinds a variable of the enclosing ``forward`` scope
    via ``nonlocal``. Side effects on global captures are not supported though.
    """

    def forward(self, a, b):
        # With the module's example_args, ``b`` (shape (2,)) broadcasts
        # against ``a`` (shape (3, 2)).
        x = a + b
        diff = a - b

        def _bump_and_combine(operand):
            nonlocal x
            # In-place tensor add on the captured ``x``. ``operand`` is the very
            # same tensor object (see the call below), so it observes the
            # increment too: the product is effectively (a + b + 1) ** 2.
            x += 1
            product = x * operand
            return product + diff

        return _bump_and_combine(x)
20
# Sample inputs for the example: ``a`` has shape (3, 2) and ``b`` has shape
# (2,), which broadcasts against ``a`` inside ``forward``.
example_args = (
    torch.randn(3, 2),
    torch.randn(2),
)
# Feature tag(s) this example exercises.
tags = {"python.closure"}
model = NestedFunction()
24