xref: /aosp_15_r20/external/pytorch/test/test_jit_string.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3from test_jit import JitTestCase
4from torch.testing._internal.common_utils import run_tests
5
6from typing import List, Tuple
7
class TestScript(JitTestCase):
    """Exercise ``str`` methods, comparisons, joins and slicing via the JIT test helpers.

    Every case is a small annotated function handed to ``self.checkScript``
    (or ``self.checkScriptRaisesRegex``), both inherited from ``JitTestCase``.
    NOTE(review): those helpers are presumed to compare the plain-Python
    result with the scripted result -- confirm against ``JitTestCase``.
    """

    def test_str_ops(self):
        """Run the string-method cases over a shared corpus and fixed literals."""

        # Predicates: one bool per ``str.is*`` method.
        def test_str_is(s: str) -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                s.isupper(), s.islower(), s.isdigit(), s.isspace(),
                s.isalnum(), s.isalpha(), s.isdecimal(), s.isnumeric(),
                s.isidentifier(), s.istitle(), s.isprintable(),
            )

        # Case conversions.
        def test_str_to(s: str) -> Tuple[str, str, str, str, str]:
            return s.upper(), s.lower(), s.capitalize(), s.title(), s.swapcase()

        # Whitespace stripping with the default character set.
        def test_str_strip(s: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(),
                s.rstrip(),
                s.strip(),
            )

        # Stripping with an explicit character set.
        def test_str_strip_char_set(s: str, char_set: str) -> Tuple[str, str, str]:
            return (
                s.lstrip(char_set),
                s.rstrip(char_set),
                s.strip(char_set),
            )

        # Shared corpus: empty, mixed case, digits, punctuation, whitespace
        # variants and a string containing a non-printable character.
        inputs = ["", "12a", "!B", "12", "a", "B", "aB", "$12", "B12", "AB ",
                  "  \t", "  \n", "\na", "abc", "123.3", "s a", "b12a ",
                  "more strings with spaces", "Titular Strings", "\x0acan'tprintthis",
                  "spaces at the end ", " begin"]

        def test_str_center(i: int, s: str) -> str:
            return s.center(i)

        def test_str_center_fc(i: int, s: str) -> str:
            return s.center(i, '*')

        # The *_error / *_err helpers pass a two-character fill, which
        # ``str.center`` / ``ljust`` / ``rjust`` reject.
        def test_str_center_error(s: str) -> str:
            return s.center(10, '**')

        def test_ljust(s: str, i: int) -> str:
            return s.ljust(i)

        def test_ljust_fc(s: str, i: int, fc: str) -> str:
            return s.ljust(i, fc)

        def test_ljust_fc_err(s: str) -> str:
            return s.ljust(10, '**')

        def test_rjust(s: str, i: int) -> str:
            return s.rjust(i)

        def test_rjust_fc(s: str, i: int, fc: str) -> str:
            return s.rjust(i, fc)

        def test_rjust_fc_err(s: str) -> str:
            return s.rjust(10, '**')

        def test_zfill(s: str, i: int) -> str:
            return s.zfill(i)

        # ``text`` rather than ``input`` so the builtin is not shadowed.
        for text in inputs:
            self.checkScript(test_str_is, (text,))
            self.checkScript(test_str_to, (text,))
            self.checkScript(test_str_strip, (text,))
            for char_set in ["abc", "123", " ", "\t"]:
                self.checkScript(test_str_strip_char_set, (text, char_set))
            # Widths 0..6 cover both "narrower than" and "wider than" most inputs.
            for width in range(7):
                self.checkScript(test_str_center, (width, text))
                self.checkScript(test_str_center_fc, (width, text))
                self.checkScript(test_ljust, (text, width))
                self.checkScript(test_ljust_fc, (text, width, '*'))
                self.checkScript(test_rjust, (text, width))
                self.checkScript(test_rjust_fc, (text, width, '*'))
                self.checkScript(test_zfill, (text, width))

        # Each failing call needs its own ``assertRaises`` context: the context
        # exits at the first exception, so a second call placed in the same
        # block would never execute.
        for failing_case in (test_str_center_error, test_ljust_fc_err, test_rjust_fc_err):
            with self.assertRaises(Exception):
                failing_case("error")

        def test_count() -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
            return (
                "hello".count("h"),
                "hello".count("h", 0, 1),
                "hello".count("h", -3),
                "hello".count("h", -10, 1),
                "hello".count("h", 0, -10),
                "hello".count("h", 0, 10),
                "hello".count("ell"),
                "hello".count("ell", 0, 1),
                "hello".count("ell", -3),
                "hello".count("ell", -10, 1),
                "hello".count("ell", 0, -10),
                "hello".count("ell", 0, 10)
            )
        self.checkScript(test_count, ())

        def test_endswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".endswith("lo"),
                "hello".endswith("lo", 0),
                "hello".endswith("lo", -2),
                "hello".endswith("lo", -8),
                "hello".endswith("lo", 0, -5),
                "hello".endswith("lo", -2, 3),
                "hello".endswith("lo", -8, 4),
                "hello".endswith("l"),
                "hello".endswith("l", 0),
                "hello".endswith("l", -2),
                "hello".endswith("l", -8),
                "hello".endswith("l", 0, -5),
                "hello".endswith("l", -2, 3),
                "hello".endswith("l", -8, 4)
            )
        self.checkScript(test_endswith, ())

        def test_startswith() -> Tuple[bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool]:
            return (
                "hello".startswith("lo"),
                "hello".startswith("lo", 0),
                "hello".startswith("lo", -2),
                "hello".startswith("lo", -8),
                "hello".startswith("lo", 0, -5),
                "hello".startswith("lo", -2, 3),
                "hello".startswith("lo", -8, 4),
                "hello".startswith("l"),
                "hello".startswith("l", 0),
                "hello".startswith("l", -2),
                "hello".startswith("l", -8),
                "hello".startswith("l", 0, -5),
                "hello".startswith("l", -2, 3),
                "hello".startswith("l", -8, 4)
            )
        self.checkScript(test_startswith, ())

        def test_expandtabs() -> Tuple[str, str, str, str, str, str]:
            return (
                'xyz\t82345\tabc'.expandtabs(),
                'xyz\t32345\tabc'.expandtabs(3),
                'xyz\t52345\tabc'.expandtabs(5),
                'xyz\t62345\tabc'.expandtabs(6),
                'xyz\t72345\tabc'.expandtabs(7),
                'xyz\t62345\tabc'.expandtabs(-5),
            )
        self.checkScript(test_expandtabs, ())

        def test_rfind() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".rfind("llo"),
                "hello123abc".rfind("12"),
                "hello123abc".rfind("ab"),
                "hello123abc".rfind("ll", -1),
                "hello123abc".rfind("12", 4),
                "hello123abc".rfind("ab", -7),
                "hello123abc".rfind("ll", -1, 8),
                "hello123abc".rfind("12", 4, -4),
                "hello123abc".rfind("ab", -7, -20),
            )
        self.checkScript(test_rfind, ())

        def test_find() -> Tuple[int, int, int, int, int, int, int, int, int]:
            return (
                "hello123abc".find("llo"),
                "hello123abc".find("12"),
                "hello123abc".find("ab"),
                "hello123abc".find("ll", -1),
                "hello123abc".find("12", 4),
                "hello123abc".find("ab", -7),
                "hello123abc".find("ll", -1, 8),
                "hello123abc".find("12", 4, -4),
                "hello123abc".find("ab", -7, -20),
            )
        self.checkScript(test_find, ())

        # Only lookups that succeed are listed for ``index`` / ``rindex``.
        def test_index() -> Tuple[int, int, int, int, int, int]:
            return (
                "hello123abc".index("llo"),
                "hello123abc".index("12"),
                "hello123abc".index("ab"),
                "hello123abc".index("12", 4),
                "hello123abc".index("ab", -7),
                "hello123abc".index("12", 4, -4),
            )
        self.checkScript(test_index, ())

        def test_rindex() -> Tuple[int, int, int, int, int, int]:
            return (
                "hello123abc".rindex("llo"),
                "hello123abc".rindex("12"),
                "hello123abc".rindex("ab"),
                "hello123abc".rindex("12", 4),
                "hello123abc".rindex("ab", -7),
                "hello123abc".rindex("12", 4, -4),
            )
        self.checkScript(test_rindex, ())

        def test_replace() -> Tuple[str, str, str, str, str, str, str]:
            return (
                "hello123abc".replace("llo", "sdf"),
                "ff".replace("f", "ff"),
                "abc123".replace("a", "testing"),
                "aaaaaa".replace("a", "testing", 3),
                "bbb".replace("a", "testing", 3),
                "ccc".replace("c", "ccc", 3),
                "cc".replace("c", "ccc", -3),
            )
        self.checkScript(test_replace, ())

        def test_partition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                      Tuple[str, str, str]]:
            return (
                "hello123abc".partition("llo"),
                "ff".partition("f"),
                "abc123".partition("a"),
                "aaaaaa".partition("testing"),
                "bbb".partition("a"),
                "ccc".partition("ccc"),
                "cc".partition("ccc"),
            )
        self.checkScript(test_partition, ())

        def test_rpartition() -> Tuple[Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str], Tuple[str, str, str], Tuple[str, str, str],
                                       Tuple[str, str, str]]:
            return (
                "hello123abc".rpartition("llo"),
                "ff".rpartition("f"),
                "abc123".rpartition("a"),
                "aaaaaa".rpartition("testing"),
                "bbb".rpartition("a"),
                "ccc".rpartition("ccc"),
                "cc".rpartition("ccc"),
            )
        self.checkScript(test_rpartition, ())

        def test_split() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                  List[str], List[str], List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".split(),
                "a  a a   a a".split(),
                "   a a\ta \v a \v\f\n a \t   ".split(),
                " a a a a a ".split(" "),
                "a a a a a ".split(" ", 10),
                "a a a a a ".split(" ", -1),
                "a a a a a ".split(" ", 3),
                " a a a a a ".split("*"),
                " a*a a*a a".split("*"),
                " a*a a*a a ".split("*", -1),
                " a*a a*a a ".split("a*", 10),
            )
        self.checkScript(test_split, ())

        # test raising error for empty separator
        def test_split_empty_separator():
            s = "test"
            return s.split("")

        self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
                                    "empty separator")

        def test_rsplit() -> Tuple[List[str], List[str], List[str], List[str], List[str],
                                   List[str], List[str], List[str], List[str]]:
            return (
                "a a a a a".rsplit(),
                " a a a a a ".rsplit(" "),
                "a a a a a ".rsplit(" ", 10),
                "a a a a a ".rsplit(" ", -1),
                "a a a a a ".rsplit(" ", 3),
                " a a a a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*"),
                " a*a a*a a ".rsplit("*", -1),
                " a*a a*a a".rsplit("a*", 10),
            )
        self.checkScript(test_rsplit, ())

        def test_splitlines() -> Tuple[List[str], List[str], List[str], List[str],
                                       List[str], List[str]]:
            return (
                "hello\ntest".splitlines(),
                "hello\n\ntest\n".splitlines(),
                "hello\ntest\n\n".splitlines(),
                "hello\vtest".splitlines(),
                "hello\v\f\ntest".splitlines(),
                "hello\ftest".splitlines(),
            )
        self.checkScript(test_splitlines, ())

        def test_str_cmp(a: str, b: str) -> Tuple[bool, bool, bool, bool, bool, bool]:
            return a != b, a == b, a < b, a > b, a <= b, a >= b

        # Compare every adjacent pair of the shared corpus.
        for left, right in zip(inputs, inputs[1:]):
            self.checkScript(test_str_cmp, (left, right))

        def test_str_join():
            return (
                ",".join(["a"]),
                ",".join(["a", "b", "c"]),
                ",".join(["aa", "bb", "cc"]),
                ",".join(["a,a", "bb", "c,c"]),
                "**a**".join(["b", "c", "d", "e"]),
                "".join(["a", "b", "c"]),
            )
        self.checkScript(test_str_join, ())

        # Truthiness of a string: empty selects the fallback branch.
        def test_bool_conversion(a: str):
            if a:
                return a
            else:
                return "default"

        self.checkScript(test_bool_conversion, ("nonempty",))
        self.checkScript(test_bool_conversion, ("",))

    def test_string_slice(self):
        """Check stepped slices, including ones whose start lies past the stop."""

        def test_slice(a: str) -> Tuple[str, str, str, str, str]:
            return (
                a[0:1:2],
                a[0:6:1],
                a[4:1:2],
                a[0:3:2],
                a[-1:1:3],
            )

        self.checkScript(test_slice, ("hellotest",))
331
# When executed directly (not imported), hand control to ``run_tests``, the
# runner imported above from ``torch.testing._internal.common_utils``.
if __name__ == '__main__':
    run_tests()
334