# mypy: allow-untyped-defs
import torch
import torch._dynamo as torchdynamo


class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to treat the result of
    non-traceable code as a constant.

    ``get_item`` turns a tensor into a Python ``int`` via ``.item()``, a
    data-dependent operation. Decorating it with
    ``torch._dynamo.assume_constant_result`` marks its result as constant, so
    the value observed while tracing is used in place of the call. That
    assumption is not validated by the compiler: the traced program is only
    correct for inputs that produce the same value.
    """

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        """Return ``y`` as a Python ``int`` (cast with ``.int()``, then ``.item()``).

        ``y`` must contain exactly one element, as required by ``.item()``.
        """
        return y.int().item()

    def forward(self, x, y):
        """Return the first ``int(y)`` entries of ``x`` along dim 0.

        Standard slice semantics apply, so a bound larger than ``x.shape[0]``
        returns all of ``x``.
        """
        return x[: self.get_item(y)]


# y (= 4) is larger than x's first dimension (3), so the slice in
# ``forward`` is clamped and the full ``x`` is returned.
example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()