xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/android_api_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Dict, List, Optional, Tuple
2
3import torch
4from torch import Tensor
5
6
class AndroidAPIModule(torch.jit.ScriptModule):
    """Scripted module whose methods echo or transform simple typed values.

    Every method is a ``torch.jit.script_method``, so each body is compiled
    by TorchScript and must stay within its supported Python subset.
    NOTE(review): presumably driven by Android-side API tests that compare
    the returned values — confirm against the caller.
    """

    @torch.jit.script_method
    def forward(self, input):
        # Unannotated `input` is treated as a Tensor by TorchScript; it is
        # ignored and nothing is returned.
        return None

    @torch.jit.script_method
    def eqBool(self, input: bool) -> bool:
        # Echo the bool back unchanged.
        return input

    @torch.jit.script_method
    def eqInt(self, input: int) -> int:
        # Echo the int back unchanged.
        return input

    @torch.jit.script_method
    def eqFloat(self, input: float) -> float:
        # Echo the float back unchanged.
        return input

    @torch.jit.script_method
    def eqStr(self, input: str) -> str:
        # Echo the string back unchanged.
        return input

    @torch.jit.script_method
    def eqTensor(self, input: Tensor) -> Tensor:
        # Echo the tensor back unchanged (same object, no copy).
        return input

    @torch.jit.script_method
    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
        # Echo a str -> int dict back unchanged.
        return input

    @torch.jit.script_method
    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
        # Echo an int -> int dict back unchanged.
        return input

    @torch.jit.script_method
    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
        # Echo a float -> int dict back unchanged.
        return input

    @torch.jit.script_method
    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
        # Return the list itself together with the sum of its elements
        # (0 for an empty list).
        total = 0
        for value in input:
            total += value
        return input, total

    @torch.jit.script_method
    def listBoolConjunction(self, input: List[bool]) -> bool:
        # Logical AND over all flags; an empty list yields True.
        acc = True
        for flag in input:
            acc = acc and flag
        return acc

    @torch.jit.script_method
    def listBoolDisjunction(self, input: List[bool]) -> bool:
        # Logical OR over all flags; an empty list yields False.
        acc = False
        for flag in input:
            acc = acc or flag
        return acc

    @torch.jit.script_method
    def tupleIntSumReturnTuple(
        self, input: Tuple[int, int, int]
    ) -> Tuple[Tuple[int, int, int], int]:
        # Return the 3-tuple itself together with the sum of its elements.
        first, second, third = input
        return input, first + second + third

    @torch.jit.script_method
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
        # True exactly when no int was supplied.
        return input is None

    @torch.jit.script_method
    def intEq0None(self, input: int) -> Optional[int]:
        # Map 0 to None; any other int is returned as-is.
        if input != 0:
            return input
        return None

    @torch.jit.script_method
    def str3Concat(self, input: str) -> str:
        # The input string repeated three times.
        return input + input + input

    @torch.jit.script_method
    def newEmptyShapeWithItem(self, input):
        # Truncate the scalar held by `input` to an int, then index into a
        # fresh one-element tensor, which yields a 0-dim tensor.
        as_int = int(input.item())
        return torch.tensor([as_int])[0]

    @torch.jit.script_method
    def testAliasWithOffset(self) -> List[Tensor]:
        # Two element views taken from the same base tensor; the second
        # view starts at a non-zero offset into that tensor.
        base = torch.tensor([100, 200])
        return [base[0], base[1]]

    @torch.jit.script_method
    def testNonContiguous(self):
        # A stride-2 slice of [100, 200, 300] is [100, 300] and is not
        # contiguous; the asserts check that before returning the slice.
        strided = torch.tensor([100, 200, 300])[::2]
        assert not strided.is_contiguous()
        assert strided[0] == 100
        assert strided[1] == 300
        return strided

    @torch.jit.script_method
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        # 2-D convolution of `x` with weight `w` (no bias). The result is
        # made contiguous in channels-last layout when requested, otherwise
        # in the default contiguous layout.
        out = torch.nn.functional.conv2d(x, w)
        if toChannelsLast:
            return out.contiguous(memory_format=torch.channels_last)
        return out.contiguous()

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        # 3-D counterpart of `conv2d`; uses `channels_last_3d` when requested.
        out = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            return out.contiguous(memory_format=torch.channels_last_3d)
        return out.contiguous()

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        # `x` in the default contiguous memory layout.
        return x.contiguous()

    @torch.jit.script_method
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
        # `x` in channels-last memory layout.
        return x.contiguous(memory_format=torch.channels_last)

    @torch.jit.script_method
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
        # `x` in channels-last-3d memory layout.
        return x.contiguous(memory_format=torch.channels_last_3d)