# mypy: allow-untyped-defs
import torch


class StaticIf(torch.nn.Module):
    """
    Example of an ``if`` statement whose predicate value is static.

    The predicate only inspects the rank of the input (``len(x.shape)``), not
    its data, so tracing should follow just the branch that is actually taken.
    """

    def forward(self, x):
        ndims = len(x.shape)
        if ndims != 3:
            # Any input that is not rank 3 is returned unchanged (same object).
            return x

        # Rank-3 input: add a (1, 1, 1) tensor of ones, which broadcasts over x.
        return x + torch.ones(1, 1, 1)


# Sample input; it has rank 3, so it exercises the addition branch.
example_args = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()