from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor


class AndroidAPIModule(torch.jit.ScriptModule):
    """TorchScript module whose methods round-trip the argument and return types
    exercised by the PyTorch Android API tests: primitives, strings, tensors,
    dicts, lists, tuples, optionals, and tensor memory formats."""

    @torch.jit.script_method
    def forward(self, input):
        return None

    # Identity methods: each echoes its input for one supported scalar,
    # string, tensor, or dict type.
    @torch.jit.script_method
    def eqBool(self, input: bool) -> bool:
        return input

    @torch.jit.script_method
    def eqInt(self, input: int) -> int:
        return input

    @torch.jit.script_method
    def eqFloat(self, input: float) -> float:
        return input

    @torch.jit.script_method
    def eqStr(self, input: str) -> str:
        return input

    @torch.jit.script_method
    def eqTensor(self, input: Tensor) -> Tensor:
        return input

    @torch.jit.script_method
    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
        return input

    @torch.jit.script_method
    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
        return input

    @torch.jit.script_method
    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
        return input

    @torch.jit.script_method
    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def listBoolConjunction(self, input: List[bool]) -> bool:
        res = True
        for x in input:
            res = res and x
        return res

    @torch.jit.script_method
    def listBoolDisjunction(self, input: List[bool]) -> bool:
        res = False
        for x in input:
            res = res or x
        return res

    @torch.jit.script_method
    def tupleIntSumReturnTuple(
        self, input: Tuple[int, int, int]
    ) -> Tuple[Tuple[int, int, int], int]:
        sum = 0
        for x in input:
            sum += x
        return (input, sum)

    @torch.jit.script_method
    def optionalIntIsNone(self, input: Optional[int]) -> bool:
        return input is None

    @torch.jit.script_method
    def intEq0None(self, input: int) -> Optional[int]:
        if input == 0:
            return None
        return input

    @torch.jit.script_method
    def str3Concat(self, input: str) -> str:
        return input + input + input

    @torch.jit.script_method
    def newEmptyShapeWithItem(self, input):
        # Returns a zero-dimensional tensor holding input.item().
        return torch.tensor([int(input.item())])[0]

    @torch.jit.script_method
    def testAliasWithOffset(self) -> List[Tensor]:
        # Returns two tensors that alias the same storage at different offsets.
        x = torch.tensor([100, 200])
        a = [x[0], x[1]]
        return a

    @torch.jit.script_method
    def testNonContiguous(self):
        # Returns a strided (non-contiguous) view of a 1-D tensor.
        x = torch.tensor([100, 200, 300])[::2]
        assert not x.is_contiguous()
        assert x[0] == 100
        assert x[1] == 300
        return x

    @torch.jit.script_method
    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv2d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last_3d)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        return x.contiguous()

    @torch.jit.script_method
    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last)

    @torch.jit.script_method
    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
        return x.contiguous(memory_format=torch.channels_last_3d)
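

# Hedged usage sketch (not part of the original module): one way this
# ScriptModule could be serialized so the Android-side tests can load it with
# org.pytorch.Module.load() or org.pytorch.LiteModuleLoader.load(). The output
# filenames are assumptions for illustration, not paths mandated by this file.
if __name__ == "__main__":
    module = AndroidAPIModule()
    # Standard TorchScript archive.
    module.save("android_api_module.pt")
    # Lite-interpreter flavor used by the mobile runtime.
    module._save_for_lite_interpreter("android_api_module.ptl")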