xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/assume_constant_result.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3import torch._dynamo as torchdynamo
4
5
class AssumeConstantResult(torch.nn.Module):
    """
    Use the `assume_constant_result` decorator to treat the result of
    non-traceable code as a constant.
    """

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # Convert `y` to an integer tensor, then pull out its Python scalar.
        # The decorator marks this method so dynamo treats the returned value
        # as a constant instead of tracing the `.item()` call.
        as_int = y.int()
        return as_int.item()

    def forward(self, x, y):
        # Keep the leading `stop` entries of `x` along dim 0, where `stop`
        # is the (assumed-constant) integer taken from `y`.
        stop = self.get_item(y)
        return x[:stop]
17
# ExportDB metadata and inputs for this example.
tags = {"torch.escape-hatch"}
example_args = (
    torch.randn(3, 2),  # x: tensor to be sliced
    torch.tensor(4),  # y: scalar tensor giving the slice bound
)
model = AssumeConstantResult()
21