# mypy: allow-untyped-defs
import torch


class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through, including closures that rebind an
    enclosing local via ``nonlocal``. Side effects on global captures are
    not supported, though.
    """

    def forward(self, a, b):
        # Two locals captured by the nested function below: ``diff`` is only
        # read, while ``total`` is rebound through ``nonlocal``.
        diff = a - b
        total = a + b

        def bump_and_combine(factor):
            nonlocal total
            # ``+=`` on a tensor is in-place, so ``factor`` (the very same
            # object as ``total`` at the call site) also observes the +1.
            # ``total`` is a fresh result of ``a + b``, so the inputs are
            # left untouched.
            total += 1
            return total * factor + diff

        return bump_and_combine(total)


# Sample inputs: a (3, 2) tensor and a (2,) tensor; the second broadcasts
# across the rows of the first in both ``a + b`` and ``a - b``.
example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()