# mypy: allow-untyped-defs
import torch


class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement whose predicate depends on a backed dynamic shape will be
    specialized into one particular branch and generate a guard. However, export
    will fail if the dimension is marked as dynamic from a higher-level API.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The predicate reads a (possibly dynamic) size of `x`. When traced with
        # `example_args` below (size 3 in dim 0), only the `cos` branch is taken.
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()


example_args = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()