# Owner(s): ["oncall: jit"]

from test_jit import JitTestCase
from torch.testing._internal.common_utils import run_tests

from typing import List, Tuple


class TestScript(JitTestCase):
    """TorchScript tests for builtin ``str`` methods and string slicing.

    Each nested helper is handed to ``checkScript`` / ``checkScriptRaisesRegex``
    from ``JitTestCase`` (presumably comparing eager Python against the
    TorchScript-compiled function — behaviour defined in ``test_jit``).
    """

    def test_str_ops(self):
        """Exercise predicates, case conversion, stripping, padding, searching,
        partitioning, splitting, joining, comparison and truthiness of ``str``."""

        def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                s.isupper(), s.islower(), s.isdigit(), s.isspace(),
                s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(),
                s.isidentifier(), s.istitle(), s.isprintable(),
            )

        def test_str_to(s: str) -> Tuple[str, str, str, str, str]:
            return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()

        def test_str_strip(s: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(),
                s.rstrip(),
                s.strip(),
            )

        def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(char_set),
                s.rstrip(char_set),
                s.strip(char_set),
            )

        # Mix of empty, alphabetic, numeric, whitespace-only, mixed-case and
        # non-printable strings so every predicate sees both True and False.
        inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
                  " \t", " \n", "\na", "abc", "123.3", "s a", "b12a ",
                  "more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
                  "spaces at the end ", " begin"]

        def test_str_center(i: int, s: str) -> str:
            return s.center(i)

        def test_str_center_fc(i: int, s: str) -> str:
            return s.center(i, '*')

        def test_str_center_error(s: str) -> str:
            return s.center(10, '**')

        def test_ljust(s: str, i: int) -> str:
            return s.ljust(i)

        def test_ljust_fc(s: str, i: int, fc: str) -> str:
            return s.ljust(i, fc)

        def test_ljust_fc_err(s: str) -> str:
            return s.ljust(10, '**')

        def test_rjust(s: str, i: int) -> str:
            return s.rjust(i)

        def test_rjust_fc(s: str, i: int, fc: str) -> str:
            return s.rjust(i, fc)

        def test_rjust_fc_err(s: str) -> str:
            return s.rjust(10, '**')

        def test_zfill(s: str, i: int) -> str:
            return s.zfill(i)

        for s in inputs:
            self.checkScript(test_str_is, (s,))
            self.checkScript(test_str_to, (s,))
            self.checkScript(test_str_strip, (s,))
            for char_set in ["abc", "123", " ", "\t"]:
                self.checkScript(test_str_strip_char_set, (s, char_set))
            # Widths 0..6 cover "narrower than", "equal to" and "wider than"
            # the short inputs, so both the no-op and padding paths run.
            for width in range(7):
                self.checkScript(test_str_center, (width, s))
                self.checkScript(test_str_center_fc, (width, s))
                self.checkScript(test_ljust, (s, width))
                self.checkScript(test_ljust_fc, (s, width, '*'))
                self.checkScript(test_rjust, (s, width))
                self.checkScript(test_rjust_fc, (s, width, '*'))
                self.checkScript(test_zfill, (s, width))

        # A multi-character fill string must raise both in eager Python and
        # in TorchScript. Each helper is checked on its own so that every
        # call actually executes (one raise cannot mask the next).
        self.checkScriptRaisesRegex(test_str_center_error, ("error",), Exception,
                                    "fill character")
        self.checkScriptRaisesRegex(test_ljust_fc_err, ("error",), Exception,
                                    "fill character")
        self.checkScriptRaisesRegex(test_rjust_fc_err, ("error",), Exception,
                                    "fill character")

        def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
            # Negative and out-of-range start/end bounds are clamped like slices.
            return (
                "hello".count("h"),
                "hello".count("h", 0, 1),
                "hello".count("h", -3),
                "hello".count("h", -10, 1),
                "hello".count("h", 0, -10),
                "hello".count("h", 0, 10),
                "hello".count("ell"),
                "hello".count("ell", 0, 1),
                "hello".count("ell", -3),
                "hello".count("ell", -10, 1),
                "hello".count("ell", 0, -10),
                "hello".count("ell", 0, 10)
            )
        self.checkScript(test_count, ())

        def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".endswith("lo"),
                "hello".endswith("lo", 0),
                "hello".endswith("lo", -2),
                "hello".endswith("lo", -8),
                "hello".endswith("lo", 0, -5),
                "hello".endswith("lo", -2, 3),
                "hello".endswith("lo", -8, 4),
                "hello".endswith("l"),
                "hello".endswith("l", 0),
                "hello".endswith("l", -2),
                "hello".endswith("l", -8),
                "hello".endswith("l", 0, -5),
                "hello".endswith("l", -2, 3),
                "hello".endswith("l", -8, 4)
            )
        self.checkScript(test_endswith, ())

        def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".startswith("lo"),
                "hello".startswith("lo", 0),
                "hello".startswith("lo", -2),
                "hello".startswith("lo", -8),
                "hello".startswith("lo", 0, -5),
                "hello".startswith("lo", -2, 3),
                "hello".startswith("lo", -8, 4),
                "hello".startswith("l"),
                "hello".startswith("l", 0),
                "hello".startswith("l", -2),
                "hello".startswith("l", -8),
                "hello".startswith("l", 0, -5),
                "hello".startswith("l", -2, 3),
                "hello".startswith("l", -8, 4)
            )
        self.checkScript(test_startswith, ())

        def test_expandtabs() -> Tuple[str, str, str, str, str, str]:
            # The last case uses a negative tab size, which removes tabs.
            return (
                'xyz\t82345\tabc'.expandtabs(),
                'xyz\t32345\tabc'.expandtabs(3),
                'xyz\t52345\tabc'.expandtabs(5),
                'xyz\t62345\tabc'.expandtabs(6),
                'xyz\t72345\tabc'.expandtabs(7),
                'xyz\t62345\tabc'.expandtabs(-5),
            )
        self.checkScript(test_expandtabs, ())

        def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".rfind("llo"),
                "hello123abc".rfind("12"),
                "hello123abc".rfind("ab"),
                "hello123abc".rfind("ll", -1),
                "hello123abc".rfind("12", 4),
                "hello123abc".rfind("ab", -7),
                "hello123abc".rfind("ll", -1, 8),
                "hello123abc".rfind("12", 4, -4),
                "hello123abc".rfind("ab", -7, -20),
            )
        self.checkScript(test_rfind, ())

        def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".find("llo"),
                "hello123abc".find("12"),
                "hello123abc".find("ab"),
                "hello123abc".find("ll", -1),
                "hello123abc".find("12", 4),
                "hello123abc".find("ab", -7),
                "hello123abc".find("ll", -1, 8),
                "hello123abc".find("12", 4, -4),
                "hello123abc".find("ab", -7, -20),
            )
        self.checkScript(test_find, ())

        def test_index() -> Tuple[int, int, int, int, int, int]:
            # Only cases where the substring is found: index() raises otherwise.
            return (
                "hello123abc".index("llo"),
                "hello123abc".index("12"),
                "hello123abc".index("ab"),
                "hello123abc".index("12", 4),
                "hello123abc".index("ab", -7),
                "hello123abc".index("12", 4, -4),
            )
        self.checkScript(test_index, ())

        def test_rindex() -> Tuple[int, int, int, int, int, int]:
            # Only cases where the substring is found: rindex() raises otherwise.
            return (
                "hello123abc".rindex("llo"),
                "hello123abc".rindex("12"),
                "hello123abc".rindex("ab"),
                "hello123abc".rindex("12", 4),
                "hello123abc".rindex("ab", -7),
                "hello123abc".rindex("12", 4, -4),
            )
        self.checkScript(test_rindex, ())

        def test_replace() -> Tuple[str, str, str, str, str, str, str]:
            return (
                "hello123abc".replace("llo", "sdf"),
                "ff".replace("f", "ff"),
                "abc123".replace("a", "testing"),
                "aaaaaa".replace("a", "testing", 3),
                "bbb".replace("a", "testing", 3),
                "ccc".replace("c", "ccc", 3),
                "cc".replace("c", "ccc", -3),
            )
        self.checkScript(test_replace, ())

        def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str]]:
            return (
                "hello123abc".partition("llo"),
                "ff".partition("f"),
                "abc123".partition("a"),
                "aaaaaa".partition("testing"),
                "bbb".partition("a"),
                "ccc".partition("ccc"),
                "cc".partition("ccc"),
            )
        self.checkScript(test_partition, ())

        def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str]]:
            return (
                "hello123abc".rpartition("llo"),
                "ff".rpartition("f"),
                "abc123".rpartition("a"),
                "aaaaaa".rpartition("testing"),
                "bbb".rpartition("a"),
                "ccc".rpartition("ccc"),
                "cc".rpartition("ccc"),
            )
        self.checkScript(test_rpartition, ())

        def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                  List[str], List[str], List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".split(),
                # Runs of whitespace collapse when no separator is given.
                "a  a a   a a".split(),
                " a a\ta \v a \v\f\n a \t ".split(),
                " a a a a a ".split(" "),
                "a a a a a ".split(" ", 10),
                "a a a a a ".split(" ", -1),
                "a a a a a ".split(" ", 3),
                " a a a a a ".split("*"),
                " a*a a*a a".split("*"),
                " a*a a*a a ".split("*", -1),
                " a*a a*a a ".split("a*", 10),
            )
        self.checkScript(test_split, ())

        # test raising error for empty separator
        def test_split_empty_separator():
            s = "test"
            return s.split("")

        self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
                                    "empty separator")

        def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                   List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".rsplit(),
                " a a a a a ".rsplit(" "),
                "a a a a a ".rsplit(" ", 10),
                "a a a a a ".rsplit(" ", -1),
                "a a a a a ".rsplit(" ", 3),
                " a a a a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*", -1),
                " a*a a*a a".rsplit("a*", 10),
            )
        self.checkScript(test_rsplit, ())

        def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str],
                                       List[str], List[str]]:
            return (
                "hello\ntest".splitlines(),
                "hello\n\ntest\n".splitlines(),
                "hello\ntest\n\n".splitlines(),
                "hello\vtest".splitlines(),
                "hello\v\f\ntest".splitlines(),
                "hello\ftest".splitlines(),
            )
        self.checkScript(test_splitlines, ())

        def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
            return a != b, a == b, a < b, a > b, a <= b, a >= b

        # Compare each input with its successor.
        for lhs, rhs in zip(inputs, inputs[1:]):
            self.checkScript(test_str_cmp, (lhs, rhs))

        def test_str_join():
            return (
                ",".join(["a"]),
                ",".join(["a", "b", "c"]),
                ",".join(["aa", "bb", "cc"]),
                ",".join(["a,a", "bb", "c,c"]),
                "**a**".join(["b", "c", "d", "e"]),
                "".join(["a", "b", "c"]),
            )
        self.checkScript(test_str_join, ())

        def test_bool_conversion(a: str):
            # An empty string is falsy, a non-empty one truthy.
            if a:
                return a
            else:
                return "default"

        self.checkScript(test_bool_conversion, ("nonempty",))
        self.checkScript(test_bool_conversion, ("",))

    def test_string_slice(self):
        """Check ``str`` slicing with explicit start, stop and step values."""
        def test_slice(a: str) -> Tuple[str, str, str, str, str]:
            return (
                a[0:1:2],
                a[0:6:1],
                a[4:1:2],
                a[0:3:2],
                a[-1:1:3],
            )

        self.checkScript(test_slice, ("hellotest",))


if __name__ == '__main__':
    run_tests()