xref: /aosp_15_r20/frameworks/base/ravenwood/scripts/convert-androidtest.py (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
1#!/usr/bin/python3
2# Copyright (C) 2024 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
# This script converts a legacy test class (one extending AndroidTestCase,
# TestCase or InstrumentationTestCase) to a modern JUnit4-style test class, in
# a best-effort manner.
18#
19# Usage:
20#  convert-androidtest.py TARGET-FILE [TARGET-FILE ...]
21#
22# Caveats:
23#   - It adds all the extra imports, even if they're not needed.
24#   - It won't sort imports.
25#   - It also always adds getContext() and getTestContext().
26#
27
28import sys
29import fileinput
30import re
31import subprocess
32
def log(msg):
    """Report *msg* to the user on standard error, followed by a newline.

    Standard error is used so that messages stay on the console while main()
    has fileinput redirect standard output into the files being rewritten.
    """
    print(msg, end='\n', file=sys.stderr)
36
37
38# Matches `extends AndroidTestCase` (or another similar base class)
39re_extends = re.compile(
40    r''' \b extends \s+ (AndroidTestCase|TestCase|InstrumentationTestCase) \s* ''',
41    re.S + re.X)
42
43
def find_target_files(files):
    """Return the files among *files* whose contents match ``re_extends``.

    Files that cannot be opened or decoded are logged and skipped. This covers
    missing paths, directories, permission errors and undecodable text, so one
    bad path does not abort the whole run.

    Args:
        files: iterable of file paths to inspect.

    Returns:
        A list of the matching paths, in input order.
    """
    ret = []

    for file in files:
        try:
            with open(file, 'r') as f:
                data = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # main() reopens these files with the same default encoding, so an
            # undecodable file could not be converted anyway.
            log(f'Failed to open file {file}: {e}')
            continue

        if re_extends.search(data):
            ret.append(file)

    return ret
60
61
# Java lines printed just before the first `import` line of a converted file.
EXTRA_IMPORTS = """\
import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertSame;
import static junit.framework.TestCase.assertNotSame;
import static junit.framework.TestCase.assertTrue;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertNull;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.Test;

import androidx.test.ext.junit.runners.AndroidJUnit4;
"""

# Java methods printed right after the line that opens the class body.
CONTEXT_HELPERS = """\
    private Context getContext() {
        return InstrumentationRegistry.getInstrumentation().getTargetContext();
    }

    private Context getTestContext() {
        return InstrumentationRegistry.getInstrumentation().getContext();
    }
"""

# A line consisting of nothing but an `@Override` annotation.
re_override_only = re.compile(r'^\s*@Override\s*$')
# `@Override` on a setUp()/tearDown() declaration: either carried over from
# the previous line (remove the whole annotation line) or on the same line.
re_override = re.compile(r'\s*@Override\s*\n|@Override\s+')
re_class = re.compile(r'\bclass\b')
re_import = re.compile(r'^import\b')
re_test_method = re.compile(r'^\s*public\s*void\s*test')
# setUp()/tearDown() declaration, optionally preceded by a carried-over
# `@Override`. Group 2 is the access modifier, group 3 the method name.
re_fixture = re.compile(
    r'^\s+(@Override\s+)?(public|protected)\s+void\s+(setUp|tearDown)')
re_super_fixture_call = re.compile(r'\bsuper\.(setUp|tearDown)\b')
re_mcontext = re.compile(r'\bmContext\b')

# Line prefixes that start (`//`, `/*`) or continue (`*`) a Java comment.
COMMENT_PREFIXES = ('//', '/*', '*')


def main(args):
    """Convert the legacy test classes among *args* to JUnit4 style, in place.

    Only files selected by find_target_files() are touched. Each one is
    rewritten line by line, the original is kept with a '.bak' suffix, and a
    unified diff against that backup is shown afterwards.

    Args:
        args: paths of the candidate Java source files.

    Returns:
        0 (the process exit status), whether or not any file was converted.
    """
    # Find the files that should be processed.
    files = find_target_files(args)

    if not files:
        log("No target files found.")
        return 0

    # Process the files. With inplace=True, fileinput redirects stdout into the
    # file currently being read, so every print() below emits converted output.
    with fileinput.input(files=files, inplace=True, backup='.bak') as f:
        import_seen = False
        carry_over = ''
        class_body_started = False
        class_seen = False

        for line in f:
            if f.isfirstline():
                log(f"Processing: {f.filename()}")
                # Reset the per-file state.
                import_seen = False
                carry_over = ''
                class_body_started = False
                class_seen = False

            line = line.rstrip('\n')

            # Hold a bare `@Override` line and merge it into the next line, so
            # a setUp()/tearDown() declaration can see (and drop) it. Only
            # bare annotations are held. Merging e.g. a complete
            # `@Override protected void setUp() {` with a following
            # `super.setUp();` would make the super-call check below delete
            # the declaration as well.
            if re_override_only.search(line):
                carry_over += line + '\n'
                continue

            if carry_over:
                line = carry_over + line
                carry_over = ''

            # Remove the base class from the class definition.
            line = re_extends.sub('', line)

            # Add a @RunWith before the class declaration. Comment lines are
            # skipped so that e.g. " * Test class for Foo." in a Javadoc block
            # is not mistaken for the declaration.
            if (not class_seen
                    and not line.lstrip().startswith(COMMENT_PREFIXES)
                    and re_class.search(line)):
                class_seen = True
                print("@RunWith(AndroidJUnit4.class)")

            # Inject extra imports before the first import.
            if not import_seen and re_import.search(line):
                import_seen = True
                print(EXTRA_IMPORTS)

            # Add @Test to the test methods.
            if re_test_method.search(line):
                print("    @Test")

            # Convert setUp/tearDown to @Before/@After.
            fixture = re_fixture.search(line)
            if fixture:
                if fixture.group(3) == 'setUp':
                    print('    @Before')
                else:
                    print('    @After')
                # JUnit4 @Before/@After methods must be public.
                line = line[:fixture.start(2)] + 'public' + line[fixture.end(2):]
                # The base class is gone, so @Override would no longer compile.
                line = re_override.sub('', line, count=1)

            # Remove the super setUp / tearDown call.
            if re_super_fixture_call.search(line):
                continue

            # Convert mContext to getContext().
            line = re_mcontext.sub('getContext()', line)

            # Print the processed line.
            print(line)

            # Add getContext() / getTestContext() right after the opening
            # brace of the class. Braces before the class declaration (such as
            # `{@link Foo}` in its Javadoc) are ignored.
            if class_seen and not class_body_started and '{' in line:
                class_body_started = True
                print(CONTEXT_HELPERS)

    # Show what changed by diffing each converted file against its backup.
    for file in files:
        subprocess.call(["diff", "-u", "--color=auto", f"{file}.bak", file])

    log(f'{len(files)} file(s) converted.')

    return 0
182
if __name__ == '__main__':
    # Run the converter on the command-line arguments and use main()'s return
    # value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
185