#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import datetime
import difflib
import os
import subprocess
import sys
import tempfile
from binascii import unhexlify
from dataclasses import dataclass
from typing import List, Tuple, Optional

from google.protobuf import text_format, message_factory, descriptor_pool
from python.generators.diff_tests.testing import TestCase, TestType, BinaryProto
from python.generators.diff_tests.utils import (
    ColorFormatter, create_message_factory, get_env, get_trace_descriptor_path,
    read_all_tests, serialize_python_trace, serialize_textproto_trace,
    modify_trace)

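# Repo root, i.e. four directory levels above this file
# (python/generators/diff_tests/runner.py).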
ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


# Performance result of running the test.
@dataclass
class PerfResult:
  test: TestCase
  ingest_time_ns: int
  real_time_ns: int

  def __init__(self, test: TestCase, perf_lines: List[str]):
    self.test = test

    assert len(perf_lines) == 1
    perf_numbers = perf_lines[0].split(',')

    assert len(perf_numbers) == 2
    self.ingest_time_ns = int(perf_numbers[0])
    self.real_time_ns = int(perf_numbers[1])
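
  # Illustrative sketch (comment only, not executed): the perf file written
  # by trace_processor's --perf-file flag is expected to contain a single
  # "ingest_ns,real_ns" line, so with hypothetical values:
  #
  #   perf = PerfResult(test, ['123456,789012\n'])
  #   assert perf.ingest_time_ns == 123456
  #   assert perf.real_time_ns == 789012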


# Data gathered from running the test.
@dataclass
class TestResult:
  test: TestCase
  trace: str
  cmd: List[str]
  expected: str
  actual: str
  passed: bool
  stderr: str
  exit_code: int
  perf_result: Optional[PerfResult]

  def __init__(self, test: TestCase, gen_trace_path: str, cmd: List[str],
               expected_text: str, actual_text: str, stderr: str,
               exit_code: int, perf_lines: List[str]) -> None:
    self.test = test
    self.trace = gen_trace_path
    self.cmd = cmd
    self.stderr = stderr
    self.exit_code = exit_code

    # For nicer string formatting in test definitions we often add leading
    # whitespace and newlines, which now have to be stripped before comparing.
    def strip_whitespaces(text: str):
      no_front_new_line_text = text.lstrip('\n')
      return '\n'.join(s.strip() for s in no_front_new_line_text.split('\n'))
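    # For example (hypothetical input):
    #   strip_whitespaces('\n  a  \n   b\n') == 'a\nb\n'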

    self.expected = strip_whitespaces(expected_text)
    self.actual = strip_whitespaces(actual_text)

    expected_content = self.expected.replace('\r\n', '\n')
    actual_content = self.actual.replace('\r\n', '\n')
    self.passed = (expected_content == actual_content)

    if self.exit_code == 0:
      self.perf_result = PerfResult(self.test, perf_lines)
    else:
      self.perf_result = None

  def write_diff(self):
    expected_lines = self.expected.splitlines(True)
    actual_lines = self.actual.splitlines(True)
    diff = difflib.unified_diff(
        expected_lines, actual_lines, fromfile='expected', tofile='actual')
    return "".join(list(diff))
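
  # Illustrative output for a hypothetical mismatch (expected 'a\nb\n',
  # actual 'a\nc\n'):
  #   --- expected
  #   +++ actual
  #   @@ -1,2 +1,2 @@
  #    a
  #   -b
  #   +c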

  def rebase(self, rebase) -> str:
    if not rebase or self.passed:
      return ""
    if not self.test.blueprint.is_out_file():
      return "Can't rebase expected results passed as strings.\n"
    if self.exit_code != 0:
      return f"Rebase failed for {self.test.name} as the query failed\n"

    with open(self.test.expected_path, 'w') as f:
      f.write(self.actual)
    return f"Rebasing {self.test.name}\n"


# Results of running the test suite. Mostly used for printing aggregated
# results.
@dataclass
class TestResults:
  test_failures: List[str]
  perf_data: List[PerfResult]
  rebased: List[str]
  test_time_ms: int

  def str(self, no_colors: bool, tests_no: int):
    c = ColorFormatter(no_colors)
    res = (
        f"[==========] {tests_no} tests ran. ({self.test_time_ms} ms total)\n"
        f"{c.green('[  PASSED  ]')} "
        f"{tests_no - len(self.test_failures)} tests.\n")
    if len(self.test_failures) > 0:
      res += (f"{c.red('[  FAILED  ]')} "
              f"{len(self.test_failures)} tests.\n")
      for failure in self.test_failures:
        res += f"{c.red('[  FAILED  ]')} {failure}\n"
    return res
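
  # With 3 tests run and 1 failure, str() renders roughly as (timing value
  # hypothetical):
  #   [==========] 3 tests ran. (150 ms total)
  #   [  PASSED  ] 2 tests.
  #   [  FAILED  ] 1 tests.
  #   [  FAILED  ] MyFailingTestName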

  def rebase_str(self):
    res = f"\n[  REBASED  ] {len(self.rebased)} tests.\n"
    for name in self.rebased:
      res += f"[  REBASED  ] {name}\n"
    return res


# Responsible for executing a single diff test.
@dataclass
class TestCaseRunner:
  test: TestCase
  trace_processor_path: str
  trace_descriptor_path: str
  colors: ColorFormatter
  override_sql_module_paths: List[str]

  def __output_to_text_proto(self, actual: str, out: BinaryProto) -> str:
    """Deserializes a binary proto and returns its text representation.

    Args:
      actual: (string) HEX encoded serialized proto message.
      out: (BinaryProto) the expected output, carrying the message type and
        an optional post-processing function.

    Returns:
      Text proto.
    """
    try:
      raw_data = unhexlify(actual.splitlines()[-1][1:-1])
      out_path = os.path.dirname(self.trace_processor_path)
      descriptor_paths = [
          f.path
          for f in os.scandir(
              os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'perfetto',
                           'trace_processor'))
          if f.is_file() and os.path.splitext(f.name)[1] == '.descriptor'
      ]
      descriptor_paths.append(
          os.path.join(ROOT_DIR, out_path, 'gen', 'protos', 'third_party',
                       'pprof', 'profile.descriptor'))
      proto = create_message_factory(descriptor_paths, out.message_type)()
      proto.ParseFromString(raw_data)
      try:
        return out.post_processing(proto)
      except Exception:
        return '<Proto post processing failed>'
    except Exception:
      return '<Invalid input for proto deserialization>'
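
  # Illustrative sketch: a query whose single result cell is a binary proto
  # produces stdout whose last line is the quoted hex of the serialized
  # message, e.g. (hypothetical bytes):
  #   "0a044e616d65"
  # The [1:-1] slice drops the surrounding quotes and unhexlify() recovers
  # the raw bytes passed to ParseFromString().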

  def __run_metrics_test(self, trace_path: str,
                         metrics_message_factory) -> TestResult:

    if self.test.blueprint.is_out_file():
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    is_json_output_file = (
        self.test.blueprint.is_out_file() and
        os.path.basename(self.test.expected_path).endswith('.json.out'))
    is_json_output = is_json_output_file or self.test.blueprint.is_out_json()
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '--extra-checks',
        '--run-metrics',
        self.test.blueprint.query.name,
        '--metrics-output=%s' % ('json' if is_json_output else 'binary'),
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    for sql_module_path in self.override_sql_module_paths:
      cmd += ['--override-sql-module', sql_module_path]
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if is_json_output:
      expected_text = expected
      actual_text = stdout.decode('utf8')
    else:
      # Expected will be in text proto format, so parse it into a real proto
      # message.
      expected_message = metrics_message_factory()
      text_format.Merge(expected, expected_message)

      # Actual will be the raw bytes of the proto, so parse it into a message
      # too.
      actual_message = metrics_message_factory()
      actual_message.ParseFromString(stdout)

      # Convert both back to text format for a canonical comparison.
      expected_text = text_format.MessageToString(expected_message)
      actual_text = text_format.MessageToString(actual_message)

    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)
    return TestResult(self.test, trace_path, cmd, expected_text, actual_text,
                      stderr.decode('utf8'), tp.returncode, perf_lines)
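
  # For illustration only, a metrics invocation for a hypothetical
  # 'android_cpu' metric roughly expands to:
  #   trace_processor_shell --analyze-trace-proto-content --crop-track-events \
  #       --extra-checks --run-metrics android_cpu --metrics-output=binary \
  #       --perf-file /tmp/<perf-file> <trace-path>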

  # Run a query-based diff test.
  def __run_query_test(self, trace_path: str, keep_query: bool) -> TestResult:
    # Fetch expected text.
    if self.test.expected_path:
      with open(self.test.expected_path, 'r') as expected_file:
        expected = expected_file.read()
    else:
      expected = self.test.blueprint.out.contents

    # Fetch query.
    if self.test.blueprint.is_query_file():
      query = self.test.query_path
    else:
      tmp_query_file = tempfile.NamedTemporaryFile(delete=False)
      with open(tmp_query_file.name, 'w') as query_file:
        query_file.write(self.test.blueprint.query)
      query = tmp_query_file.name

    tmp_perf_file = tempfile.NamedTemporaryFile(delete=False)
    cmd = [
        self.trace_processor_path,
        '--analyze-trace-proto-content',
        '--crop-track-events',
        '--extra-checks',
        '-q',
        query,
        '--perf-file',
        tmp_perf_file.name,
        trace_path,
    ]
    for sql_module_path in self.override_sql_module_paths:
      cmd += ['--override-sql-module', sql_module_path]
    tp = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=get_env(ROOT_DIR))
    (stdout, stderr) = tp.communicate()

    if not self.test.blueprint.is_query_file() and not keep_query:
      tmp_query_file.close()
      os.remove(tmp_query_file.name)
    perf_lines = [line.decode('utf8') for line in tmp_perf_file.readlines()]
    tmp_perf_file.close()
    os.remove(tmp_perf_file.name)

    actual = stdout.decode('utf8')
    if self.test.blueprint.is_out_binaryproto():
      actual = self.__output_to_text_proto(actual, self.test.blueprint.out)

    return TestResult(self.test, trace_path, cmd, expected, actual,
                      stderr.decode('utf8'), tp.returncode, perf_lines)

  def __run(self, metrics_descriptor_paths: List[str],
            extension_descriptor_paths: List[str], keep_input,
            rebase) -> Tuple[TestResult, str]:
    # We can't use delete=True here. When using that on Windows, the
    # resulting file is opened in exclusive mode (in turn that's a subtle
    # side-effect of the underlying CreateFile(FILE_ATTRIBUTE_TEMPORARY))
    # and TP fails to open the passed path.
    gen_trace_file = None
    if self.test.blueprint.is_trace_file():
      if self.test.trace_path.endswith('.py'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_python_trace(ROOT_DIR, self.trace_descriptor_path,
                               self.test.trace_path, gen_trace_file)

      elif self.test.trace_path.endswith('.textproto'):
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        serialize_textproto_trace(self.trace_descriptor_path,
                                  extension_descriptor_paths,
                                  self.test.trace_path, gen_trace_file)

    elif self.test.blueprint.is_trace_textproto():
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      proto = create_message_factory([self.trace_descriptor_path] +
                                     extension_descriptor_paths,
                                     'perfetto.protos.Trace')()
      text_format.Merge(self.test.blueprint.trace.contents, proto)
      gen_trace_file.write(proto.SerializeToString())
      gen_trace_file.flush()

    else:
      gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
      with open(gen_trace_file.name, 'w') as trace_file:
        trace_file.write(self.test.blueprint.trace.contents)
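
    # At this point |gen_trace_file| (if set) holds a serialized trace. As an
    # illustration, an inline textproto blueprint with hypothetical contents
    # such as:
    #   packet { timestamp: 1 }
    # is merged into a perfetto.protos.Trace message above and written out in
    # binary form.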

    if self.test.blueprint.trace_modifier is not None:
      if gen_trace_file:
        # Overwrite |gen_trace_file|.
        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                     gen_trace_file.name, gen_trace_file.name,
                     self.test.blueprint.trace_modifier)
      else:
        # Create |gen_trace_file| to save the modified trace.
        gen_trace_file = tempfile.NamedTemporaryFile(delete=False)
        modify_trace(self.trace_descriptor_path, extension_descriptor_paths,
                     self.test.trace_path, gen_trace_file.name,
                     self.test.blueprint.trace_modifier)

    if gen_trace_file:
      trace_path = os.path.realpath(gen_trace_file.name)
    else:
      trace_path = self.test.trace_path

    run_str = f"{self.colors.yellow('[ RUN      ]')} {self.test.name}\n"

    if self.test.type == TestType.QUERY:
      result = self.__run_query_test(trace_path, keep_input)
    elif self.test.type == TestType.METRIC:
      result = self.__run_metrics_test(
          trace_path,
          create_message_factory(metrics_descriptor_paths,
                                 'perfetto.protos.TraceMetrics'))
    else:
      assert False

    if gen_trace_file:
      if keep_input:
        run_str += f"Saving generated input trace: {trace_path}\n"
      else:
        gen_trace_file.close()
        os.remove(trace_path)

    def write_cmdlines():
      res = ""
      if self.test.trace_path and (self.test.trace_path.endswith('.textproto')
                                   or self.test.trace_path.endswith('.py')):
        res += 'Command to generate trace:\n'
        res += 'tools/serialize_test_trace.py '
        res += '--descriptor {} {} > {}\n'.format(
            os.path.relpath(self.trace_descriptor_path, ROOT_DIR),
            os.path.relpath(self.test.trace_path, ROOT_DIR),
            os.path.relpath(trace_path, ROOT_DIR))
      res += f"Command line:\n{' '.join(result.cmd)}\n"
      return res

    if result.exit_code != 0 or not result.passed:
      result.passed = False
      run_str += result.stderr

      if result.exit_code == 0:
        run_str += f"Expected did not match actual for test {self.test.name}.\n"
        run_str += write_cmdlines()
        run_str += result.write_diff()
      else:
        run_str += write_cmdlines()

      run_str += f"{self.colors.red('[  FAILED  ]')} {self.test.name}\n"
      run_str += result.rebase(rebase)

      return result, run_str
    else:
      run_str += (
          f"{self.colors.green('[       OK ]')} {self.test.name} "
          f"(ingest: {result.perf_result.ingest_time_ns / 1000000:.2f} ms "
          f"query: {result.perf_result.real_time_ns / 1000000:.2f} ms)\n")
    return result, run_str

  # Run a TestCase.
  def execute(self, extension_descriptor_paths: List[str],
              metrics_descriptor_paths: List[str], keep_input: bool,
              rebase: bool) -> Tuple[str, str, Optional[TestResult]]:
    if not metrics_descriptor_paths:
      out_path = os.path.dirname(self.trace_processor_path)
      metrics_protos_path = os.path.join(out_path, 'gen', 'protos', 'perfetto',
                                         'metrics')
      metrics_descriptor_paths = [
          os.path.join(metrics_protos_path, 'metrics.descriptor'),
          os.path.join(metrics_protos_path, 'chrome',
                       'all_chrome_metrics.descriptor'),
          os.path.join(metrics_protos_path, 'webview',
                       'all_webview_metrics.descriptor')
      ]
    result, run_str = self.__run(metrics_descriptor_paths,
                                 extension_descriptor_paths, keep_input, rebase)
    if not result:
      return self.test.name, run_str, None

    return self.test.name, run_str, result


# Fetches and executes all viable diff tests.
@dataclass
class DiffTestsRunner:
  tests: List[TestCase]
  trace_processor_path: str
  trace_descriptor_path: str
  test_runners: List[TestCaseRunner]
  quiet: bool

  def __init__(self, name_filter: str, trace_processor_path: str,
               trace_descriptor: str, no_colors: bool,
               override_sql_module_paths: List[str], test_dir: str,
               quiet: bool):
    self.tests = read_all_tests(name_filter, test_dir)
    self.trace_processor_path = trace_processor_path
    self.quiet = quiet

    out_path = os.path.dirname(self.trace_processor_path)
    self.trace_descriptor_path = get_trace_descriptor_path(
        out_path, trace_descriptor)
    self.test_runners = []
    color_formatter = ColorFormatter(no_colors)
    for test in self.tests:
      self.test_runners.append(
          TestCaseRunner(test, self.trace_processor_path,
                         self.trace_descriptor_path, color_formatter,
                         override_sql_module_paths))

  def run_all_tests(self, metrics_descriptor_paths: List[str],
                    chrome_extensions: str, test_extensions: str,
                    winscope_extensions: str, keep_input: bool,
                    rebase: bool) -> TestResults:
    perf_results = []
    failures = []
    rebased = []
    test_run_start = datetime.datetime.now()
    completed_tests = 0

    with concurrent.futures.ProcessPoolExecutor() as e:
      fut = [
          e.submit(test.execute,
                   [chrome_extensions, test_extensions, winscope_extensions],
                   metrics_descriptor_paths, keep_input, rebase)
          for test in self.test_runners
      ]
      for res in concurrent.futures.as_completed(fut):
        test_name, res_str, result = res.result()

        if self.quiet:
          completed_tests += 1
          sys.stderr.write(f"\rRan {completed_tests} tests")
          if not result or not result.passed:
            sys.stderr.write("\r")
            sys.stderr.write(res_str)
        else:
          sys.stderr.write(res_str)

        if not result or not result.passed:
          if rebase:
            rebased.append(test_name)
          failures.append(test_name)
        else:
          perf_results.append(result.perf_result)
    test_time_ms = int(
        (datetime.datetime.now() - test_run_start).total_seconds() * 1000)
    if self.quiet:
      sys.stderr.write("\r")
    return TestResults(failures, perf_results, rebased, test_time_ms)
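

# Minimal usage sketch (all paths and arguments below are hypothetical; the
# real entry point that drives this class lives outside this file):
#
#   runner = DiffTestsRunner(
#       name_filter='.*',
#       trace_processor_path='out/linux/trace_processor_shell',
#       trace_descriptor=None,
#       no_colors=False,
#       override_sql_module_paths=[],
#       test_dir='test/trace_processor',
#       quiet=False)
#   results = runner.run_all_tests([], chrome_ext, test_ext, winscope_ext,
#                                  keep_input=False, rebase=False)
#   sys.stderr.write(results.str(no_colors=False, tests_no=len(runner.tests)))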