#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import datetime
import difflib
import os
import subprocess
import sys
import tempfile
from binascii import unhexlify
from dataclasses import dataclass
from typing import List, Tuple, Optional

from google.protobuf import text_format, message_factory, descriptor_pool
from python.generators.diff_tests.testing import TestCase, TestType, BinaryProto
from python.generators.diff_tests.utils import (
    ColorFormatter, create_message_factory, get_env, get_trace_descriptor_path,
    read_all_tests, serialize_python_trace, serialize_textproto_trace,
    modify_trace)

ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


# Performance result of running the test.
@dataclass
class PerfResult:
  """Timing numbers parsed from the file passed to trace_processor's
  --perf-file flag.

  Attributes:
    test: the test case these numbers belong to.
    ingest_time_ns: first comma-separated value of the perf line.
    real_time_ns: second comma-separated value of the perf line (printed as
      the "query" time in the test log).
  """
  test: TestCase
  ingest_time_ns: int
  real_time_ns: int

  def __init__(self, test: TestCase, perf_lines: List[str]):
    """Parses |perf_lines|, which must hold exactly one "<int>,<int>" line.

    Raises:
      ValueError: if the perf output does not have the expected shape or the
        values are not integers.
    """
    self.test = test

    # Explicit checks rather than asserts: asserts are stripped under -O, which
    # would turn a malformed perf file into a confusing IndexError or silently
    # ignore unexpected extra values.
    if len(perf_lines) != 1:
      raise ValueError(f"Expected exactly one perf line for {test.name}, "
                       f"got {len(perf_lines)}: {perf_lines!r}")
    perf_numbers = perf_lines[0].split(',')

    if len(perf_numbers) != 2:
      raise ValueError(f"Malformed perf line for {test.name}: "
                       f"{perf_lines[0]!r}")
    self.ingest_time_ns = int(perf_numbers[0])
    self.real_time_ns = int(perf_numbers[1])


# Data gathered from running the test.
@dataclass
class TestResult:
  """Everything observed while running a single diff test.

  Expected and actual outputs are stored in normalized form (see
  |_normalize| in __init__), and |passed| is their exact equality.
  |perf_result| is only populated when trace_processor exited with 0.
  """
  test: TestCase
  trace: str
  cmd: List[str]
  expected: str
  actual: str
  passed: bool
  stderr: str
  exit_code: int
  perf_result: Optional[PerfResult]

  def __init__(self, test: TestCase, gen_trace_path: str, cmd: List[str],
               expected_text: str, actual_text: str, stderr: str,
               exit_code: int, perf_lines: List[str]) -> None:
    self.test = test
    self.trace = gen_trace_path
    self.cmd = cmd
    self.stderr = stderr
    self.exit_code = exit_code

    def _normalize(raw: str) -> str:
      # Expected outputs are usually indented inside test definitions, so
      # drop leading newlines and strip every line on both sides.
      body = raw.lstrip('\n')
      return '\n'.join(line.strip() for line in body.split('\n'))

    self.expected = _normalize(expected_text)
    self.actual = _normalize(actual_text)

    # Per-line strip() already removed any '\r', so CRLF and LF outputs
    # compare equal without a separate replace step.
    self.passed = self.expected == self.actual

    self.perf_result = (
        PerfResult(self.test, perf_lines) if self.exit_code == 0 else None)

  def write_diff(self):
    """Returns a unified diff from the expected to the actual output."""
    diff_lines = difflib.unified_diff(
        self.expected.splitlines(True),
        self.actual.splitlines(True),
        fromfile='expected',
        tofile='actual')
    return ''.join(diff_lines)

  def rebase(self, rebase) -> str:
    """Overwrites the expected output file with the actual output.

    Only acts when |rebase| is set and the test failed. Returns a log line
    describing what happened, or "" when nothing was attempted.
    """
    if self.passed or not rebase:
      return ""
    if not self.test.blueprint.is_out_file():
      return "Can't rebase expected results passed as strings.\n"
    if self.exit_code != 0:
      return f"Rebase failed for {self.test.name} as query failed\n"

    with open(self.test.expected_path, 'w') as expected_out:
      expected_out.write(self.actual)
    return f"Rebasing {self.test.name}\n"


# Results of running the test suite. Mostly used for printing aggregated
# results.
@dataclass
class TestResults:
  """Aggregated results of a whole diff test suite run."""
  # Names of the tests which failed (rebased tests are counted here too).
  test_failures: List[str]
  # Perf numbers of the tests which passed.
  perf_data: List[PerfResult]
  # Names of the failed tests for which a rebase was requested.
  rebased: List[str]
  # Wall-clock duration of the whole run, in milliseconds.
  test_time_ms: int

  def str(self, no_colors: bool, tests_no: int):
    """Returns the PASSED/FAILED summary for a run of |tests_no| tests."""
    c = ColorFormatter(no_colors)
    res = (
        f"[==========] {tests_no} tests ran. ({self.test_time_ms} ms total)\n"
        f"{c.green('[ PASSED ]')} "
        f"{tests_no - len(self.test_failures)} tests.\n")
    if len(self.test_failures) > 0:
      res += (f"{c.red('[ FAILED ]')} "
              f"{len(self.test_failures)} tests.\n")
      for failure in self.test_failures:
        res += f"{c.red('[ FAILED ]')} {failure}\n"
    return res

  def rebase_str(self):
    """Returns the summary listing every rebased test."""
    res = f"\n[ REBASED ] {len(self.rebased)} tests.\n"
    for name in self.rebased:
      res += f"[ REBASED ] {name}\n"
    return res


# Responsible for executing singular diff test.
@dataclass
class TestCaseRunner:
  """Runs one diff test through trace_processor and formats its log."""
  test: TestCase
  trace_processor_path: str
  trace_descriptor_path: str
  colors: ColorFormatter
  # Each entry is forwarded as an --override-sql-module flag.
  override_sql_module_paths: List[str]

  def __output_to_text_proto(self, actual: str, out: BinaryProto) -> str:
    """Deserializes a binary proto and returns its post-processed text.

    Args:
      actual: trace_processor stdout. Its last line holds the HEX encoded
        serialized proto; the first and last characters of that line are
        dropped before decoding (NOTE(review): presumably surrounding
        quotes - confirm).
      out: the test's BinaryProto expectation; provides the message type to
        parse and the post-processing applied to the parsed message.

    Returns:
      The output of |out.post_processing|, or a placeholder string when
      deserialization or post-processing fails.
    """
    try:
      raw_data = unhexlify(actual.splitlines()[-1][1:-1])
      out_path = os.path.dirname(self.trace_processor_path)
      tp_descriptors_dir = os.path.join(ROOT_DIR, out_path, 'gen', 'protos',
                                        'perfetto', 'trace_processor')
      descriptor_paths = [
          f.path
          for f in os.scandir(tp_descriptors_dir)
          if f.is_file() and os.path.splitext(f.name)[1] == '.descriptor'
      ]
      descriptor_paths.append(
          os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'third_party',
                       'pprof', 'profile.descriptor'))
      proto = create_message_factory(descriptor_paths, out.message_type)()
      proto.ParseFromString(raw_data)
    except Exception:
      return '<Invalid input for proto deserializaiton>'

    # Post-processing failures get their own message so they are not
    # mistaken for malformed trace_processor output.
    try:
      return out.post_processing(proto)
    except Exception:
      return '<Proto post processing failed>'

  def __run_metrics_test(self, trace_path: str,
                         metrics_message_factory) -> TestResult:
    """Runs trace_processor with --run-metrics and compares the output.

    Args:
      trace_path: trace file passed to trace_processor.
      metrics_message_factory: callable returning an empty metrics message;
        used to normalize binary metric output and the text proto
        expectation to the same text format.

    Returns:
      The TestResult of this run.
    """
    blueprint = self.test.blueprint
    if blueprint.is_out_file():
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = blueprint.out.contents

    is_json_output_file = blueprint.is_out_file() and os.path.basename(
        self.test.expected_path).endswith('.json.out')
    is_json_output = is_json_output_file or blueprint.is_out_json()

    # delete=False so trace_processor can open the path on Windows; the file
    # is removed explicitly in the finally clause.
    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    try:
      cmd = [
          self.trace_processor_path,
          '--analyze-trace-proto-content',
          '--crop-track-events',
          '--extra-checks',
          '--run-metrics',
          blueprint.query.name,
          '--metrics-output=%s' % ('json' if is_json_output else 'binary'),
          '--perf-file',
          tmp_perf_file.name,
          trace_path,
      ]
      for sql_module_path in self.override_sql_module_paths:
        cmd += ['--override-sql-module', sql_module_path]
      tp = subprocess.Popen(
          cmd,
          stdout=subprocess.PIPE,
          stderr=subprocess.PIPE,
          env=get_env(ROOT_DIR))
      (stdout, stderr) = tp.communicate()
      # trace_processor wrote the file by path; our handle was never written
      # through, so reading it returns what trace_processor produced.
      perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    finally:
      # Clean up even if launching trace_processor fails.
      tmp_perf_file.close()
      os.remove(tmp_perf_file.name)

    if is_json_output:
      expected_text = expected
      actual_text = stdout.decode('utf8')
    else:
      # Expected is a text proto; actual is the raw serialized message.
      # Round-trip both through the message type to compare them canonically.
      expected_message = metrics_message_factory()
      text_format.Merge(expected, expected_message)

      actual_message = metrics_message_factory()
      actual_message.ParseFromString(stdout)

      expected_text = text_format.MessageToString(expected_message)
      actual_text = text_format.MessageToString(actual_message)

    return TestResult(self.test, trace_path, cmd, expected_text, actual_text,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  def __run_query_test(self, trace_path: str, keep_query: bool) -> TestResult:
    """Runs a query based diff test.

    Args:
      trace_path: trace file passed to trace_processor.
      keep_query: when the query is an inline string, keep the temporary
        file it was written to instead of deleting it.

    Returns:
      The TestResult of this run.
    """
    blueprint = self.test.blueprint
    if self.test.expected_path:
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = blueprint.out.contents

    tmp_query_file = None
    # delete=False so trace_processor can open the path on Windows; the file
    # is removed explicitly in the finally clause.
    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    try:
      if blueprint.is_query_file():
        query = self.test.query_path
      else:
        tmp_query_file = tempfile.NamedTemporaryFile(delete=False)
        # Only the path is needed. Closing right away avoids leaking the
        # handle when |keep_query| keeps the file; with delete=False closing
        # does not remove it.
        tmp_query_file.close()
        with open(tmp_query_file.name, 'w') as query_file:
          query_file.write(blueprint.query)
        query = tmp_query_file.name

      cmd = [
          self.trace_processor_path,
          '--analyze-trace-proto-content',
          '--crop-track-events',
          '--extra-checks',
          '-q',
          query,
          '--perf-file',
          tmp_perf_file.name,
          trace_path,
      ]
      for sql_module_path in self.override_sql_module_paths:
        cmd += ['--override-sql-module', sql_module_path]
      tp = subprocess.Popen(
          cmd,
          stdout=subprocess.PIPE,
          stderr=subprocess.PIPE,
          env=get_env(ROOT_DIR))
      (stdout, stderr) = tp.communicate()
      perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    finally:
      tmp_perf_file.close()
      os.remove(tmp_perf_file.name)
      if tmp_query_file is not None and not keep_query:
        os.remove(tmp_query_file.name)

    actual = stdout.decode('utf8')
    if blueprint.is_out_binaryproto():
      actual = self.__output_to_text_proto(actual, blueprint.out)

    return TestResult(self.test, trace_path, cmd, expected, actual,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  def __run(self, metrics_descriptor_paths: List[str],
            extension_descriptor_paths: List[str], keep_input,
            rebase) -> Tuple[TestResult, str]:
    """Generates the input trace if needed, runs the test and builds its log.

    Returns:
      A (result, log) tuple; log is the colored, human readable output for
      this test.

    Raises:
      ValueError: if the test type is neither QUERY nor METRIC.
    """
    blueprint = self.test.blueprint
    # We can't use delete=True here. When using that on Windows, the
    # resulting file is opened in exclusive mode (in turn that's a subtle
    # side-effect of the underlying CreateFile(FILE_ATTRIBUTE_TEMPORARY))
    # and TP fails to open the passed path.
    gen_trace_file = None
    try:
      if blueprint.is_trace_file():
        if self.test.trace_path.endswith('.py'):
          gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
          serialize_python_trace(ROOT_DIR, self.trace_descriptor_path,
                                 self.test.trace_path, gen_trace_file)
        elif self.test.trace_path.endswith('.textproto'):
          gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
          serialize_textproto_trace(self.trace_descriptor_path,
                                    extension_descriptor_paths,
                                    self.test.trace_path, gen_trace_file)
        # Any other trace file is passed to trace_processor unchanged.

      elif blueprint.is_trace_textproto():
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        proto = create_message_factory([self.trace_descriptor_path] +
                                       extension_descriptor_paths,
                                       'perfetto.protos.Trace')()
        text_format.Merge(blueprint.trace.contents, proto)
        gen_trace_file.write(proto.SerializeToString())
        gen_trace_file.flush()

      else:
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        with open(gen_trace_file.name, 'w') as trace_file:
          trace_file.write(blueprint.trace.contents)

      if blueprint.trace_modifier is not None:
        if gen_trace_file is not None:
          # Overwrite |gen_trace_file| in place.
          modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                       gen_trace_file.name, gen_trace_file.name,
                       blueprint.trace_modifier)
        else:
          # Create |gen_trace_file| to save the modified trace.
          gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
          modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                       self.test.trace_path, gen_trace_file.name,
                       blueprint.trace_modifier)

      if gen_trace_file is not None:
        trace_path = os.path.realpath(gen_trace_file.name)
      else:
        trace_path = self.test.trace_path

      output = f"{self.colors.yellow('[ RUN ]')} {self.test.name}\n"

      if self.test.type == TestType.QUERY:
        result = self.__run_query_test(trace_path, keep_input)
      elif self.test.type == TestType.METRIC:
        result = self.__run_metrics_test(
            trace_path,
            create_message_factory(metrics_descriptor_paths,
                                   'perfetto.protos.TraceMetrics'))
      else:
        # An explicit raise instead of `assert False`, which is stripped
        # under -O and would leave |result| unbound.
        raise ValueError(
            f"Unsupported test type {self.test.type} for {self.test.name}")
    finally:
      # Remove the generated trace even if generation or the run raised,
      # unless the caller asked to keep it.
      if gen_trace_file is not None and not keep_input:
        gen_trace_file.close()
        os.remove(gen_trace_file.name)

    if gen_trace_file is not None and keep_input:
      output += f"Saving generated input trace: {trace_path}\n"

    def write_cmdlines():
      res = ""
      if self.test.trace_path and (self.test.trace_path.endswith('.textproto')
                                   or self.test.trace_path.endswith('.py')):
        res += 'Command to generate trace:\n'
        res += 'tools/serialize_test_trace.py '
        res += '--descriptor {} {} > {}\n'.format(
            os.path.relpath(self.trace_descriptor_path, ROOT_DIR),
            os.path.relpath(self.test.trace_path, ROOT_DIR),
            os.path.relpath(trace_path, ROOT_DIR))
      res += f"Command line:\n{' '.join(result.cmd)}\n"
      return res

    if result.exit_code != 0 or not result.passed:
      # A non-zero exit code fails the test even if the output matched.
      result.passed = False
      output += result.stderr

      if result.exit_code == 0:
        output += f"Expected did not match actual for test {self.test.name}.\n"
        output += write_cmdlines()
        output += result.write_diff()
      else:
        output += write_cmdlines()

      output += (f"{self.colors.red('[ FAILED ]')} {self.test.name}\n")
      output += result.rebase(rebase)
      return result, output

    # exit_code == 0 here, so TestResult populated |perf_result|.
    output += (f"{self.colors.green('[ OK ]')} {self.test.name} "
               f"(ingest: {result.perf_result.ingest_time_ns / 1000000:.2f} ms "
               f"query: {result.perf_result.real_time_ns / 1000000:.2f} ms)\n")
    return result, output

  def execute(self, extension_descriptor_paths: List[str],
              metrics_descriptor_paths: List[str], keep_input: bool,
              rebase: bool) -> Tuple[str, str, TestResult]:
    """Runs this TestCase.

    Args:
      extension_descriptor_paths: descriptor files combined with the trace
        descriptor when building trace protos from text.
      metrics_descriptor_paths: descriptor files for metric protos. When
        empty, the descriptors under <out>/gen/protos/perfetto/metrics (next
        to the trace_processor binary) are used.
      keep_input: keep generated traces and query files instead of deleting
        them.
      rebase: on failure, overwrite the expected output file with the actual
        output.

    Returns:
      A (test name, log output, result) tuple.
    """
    if not metrics_descriptor_paths:
      out_path = os.path.dirname(self.trace_processor_path)
      metrics_protos_path = os.path.join(out_path, 'gen', 'protos', 'perfetto',
                                         'metrics')
      metrics_descriptor_paths = [
          os.path.join(metrics_protos_path, 'metrics.descriptor'),
          os.path.join(metrics_protos_path, 'chrome',
                       'all_chrome_metrics.descriptor'),
          os.path.join(metrics_protos_path, 'webview',
                       'all_webview_metrics.descriptor')
      ]
    result, run_str = self.__run(metrics_descriptor_paths,
                                 extension_descriptor_paths, keep_input, rebase)
    return self.test.name, run_str, result


# Fetches and executes all diff viable tests.
@dataclass
class DiffTestsRunner:
  """Collects the diff tests matching a filter and runs them in parallel."""
  tests: List[TestCase]
  trace_processor_path: str
  trace_descriptor_path: str
  test_runners: List[TestCaseRunner]
  # When set, only a running counter and failing tests' logs are printed.
  quiet: bool

  def __init__(self, name_filter: str, trace_processor_path: str,
               trace_descriptor: str, no_colors: bool,
               override_sql_module_paths: List[str], test_dir: str,
               quiet: bool):
    self.tests = read_all_tests(name_filter, test_dir)
    self.trace_processor_path = trace_processor_path
    self.quiet = quiet

    out_path = os.path.dirname(self.trace_processor_path)
    self.trace_descriptor_path = get_trace_descriptor_path(
        out_path, trace_descriptor)
    color_formatter = ColorFormatter(no_colors)
    self.test_runners = [
        TestCaseRunner(test, self.trace_processor_path,
                       self.trace_descriptor_path, color_formatter,
                       override_sql_module_paths) for test in self.tests
    ]

  def run_all_tests(self, metrics_descriptor_paths: List[str],
                    chrome_extensions: str, test_extensions: str,
                    winscope_extensions: str, keep_input: bool,
                    rebase: bool) -> TestResults:
    """Runs every test in a process pool and aggregates the results.

    Each test's log is written to stderr as it completes (only failing
    tests' logs in quiet mode).

    Returns:
      TestResults with failures, perf numbers of passing tests, rebased test
      names and the total wall-clock time.
    """
    perf_results = []
    failures = []
    rebased = []
    test_run_start = datetime.datetime.now()
    completed_tests = 0

    with concurrent.futures.ProcessPoolExecutor() as e:
      fut = [
          e.submit(test.execute,
                   [chrome_extensions, test_extensions, winscope_extensions],
                   metrics_descriptor_paths, keep_input, rebase)
          for test in self.test_runners
      ]
      for res in concurrent.futures.as_completed(fut):
        test_name, res_str, result = res.result()
        # A missing result counts as a failure. Compute this once so the
        # quiet-mode branch does not dereference a None result.
        passed = result is not None and result.passed

        if self.quiet:
          completed_tests += 1
          sys.stderr.write(f"\rRan {completed_tests} tests")
          if not passed:
            sys.stderr.write("\r")
            sys.stderr.write(res_str)
        else:
          sys.stderr.write(res_str)

        if passed:
          perf_results.append(result.perf_result)
        else:
          if rebase:
            rebased.append(test_name)
          failures.append(test_name)

    test_time_ms = int(
        (datetime.datetime.now() - test_run_start).total_seconds() * 1000)
    if self.quiet:
      # Return to column 0 so the next output overwrites the progress counter.
      sys.stderr.write("\r")
    return TestResults(failures, perf_results, rebased, test_time_ms)