xref: /aosp_15_r20/external/perfetto/python/generators/diff_tests/testing.py (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1#!/usr/bin/env python3
2# Copyright (C) 2023 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import inspect
17import os
18from dataclasses import dataclass
19from typing import Any, Dict, List, Union, Callable
20from enum import Enum
21import re
22
23from google.protobuf import text_format
24
# Type alias: a test name is represented as a plain `str`.
TestName = str
26
27
@dataclass
class Path:
  """Reference to a file, identified by `filename`.

  A plain `Path` (as opposed to a `DataPath`) is resolved by `TestCase`
  relative to the test's index directory.
  """
  # File name, joined onto a base directory by `TestCase`.
  filename: str
31
32
@dataclass
class DataPath(Path):
  """A `Path` that `TestCase` resolves relative to the test data directory.

  Apart from that different base directory it carries the same single
  `filename` field as `Path`.
  """
  # File name, joined onto the test data directory by `TestCase`.
  filename: str
36
37
@dataclass
class Metric:
  """Names a metric; usable as the `query` of a `DiffTestBlueprint`.

  A blueprint whose query is a `Metric` makes `TestCase` classify the test as
  `TestType.METRIC`.
  """
  # Name of the metric.
  name: str
41
42
@dataclass
class Json:
  """Inline JSON text; accepted as a blueprint's `trace` or `out`."""
  # The JSON document, as a string.
  contents: str
46
47
@dataclass
class Csv:
  """Inline CSV text; accepted as a blueprint's expected `out`."""
  # The CSV document, as a string.
  contents: str
51
52
@dataclass
class TextProto:
  """Inline text-format proto; accepted as a blueprint's `trace` or `out`."""
  # The text proto, as a string.
  contents: str
56
57
@dataclass
class BinaryProto:
  """Expected output described as a proto message type plus its contents.

  Comparing protos is tricky: repeated fields, for example, might be written
  in any order. To help with that, a `post_processing` function can be
  supplied. It is called with the actual proto message object before that
  message is converted to its text representation and compared against
  `contents`, which gives a chance to e.g. sort the messages of a repeated
  field first. It defaults to `text_format.MessageToString`.
  """
  # Name of the proto message type of the output.
  message_type: str
  # Expected contents that the actual output is compared with.
  contents: str
  # Hook applied to the actual message before the comparison (see above).
  post_processing: Callable = text_format.MessageToString
68
69
70@dataclass
71class Systrace:
72  contents: str
73
74
class TraceInjector:
  '''Injects fields into trace packets before test starts.

  TraceInjector can be used within a DiffTestBlueprint to selectively inject
  fields to trace packets containing specific data types. For example:

    DiffTestBlueprint(
        trace=...,
        trace_modifier=TraceInjector(
            ['ftrace_events', 'sys_stats', 'process_tree'],
            {'machine_id': 1001, 'trusted_uid': 123}),
        query=...,
        out=...)

  packet_data_types: Data types to target for injection, as a list of packet
  field names (['ftrace_events', 'sys_stats', 'process_tree']).
  injected_fields: Fields and their values to inject into matching packets
  ({'machine_id': 1001, 'trusted_uid': 123}).
  '''

  def __init__(self, packet_data_types: List[str], injected_fields: Dict[str,
                                                                         Any]):
    self.packet_data_types = packet_data_types
    self.injected_fields = injected_fields

  def inject(self, proto):
    '''Sets every injected field on each packet of `proto` that matches.

    A packet matches when `HasField` is true for at least one of
    `packet_data_types`. Packets are modified in place; nothing is returned.
    '''
    for packet in proto.packet:
      # The injected values do not depend on which data type matched, so stop
      # probing at the first hit and inject exactly once per packet.
      if any(packet.HasField(data_type)
             for data_type in self.packet_data_types):
        for field, value in self.injected_fields.items():
          setattr(packet, field, value)
109
110
class TestType(Enum):
  """Kind of a diff test, as assigned by `TestCase`.

  `TestCase` picks METRIC when the blueprint's query is a `Metric` and QUERY
  otherwise.
  """
  QUERY = 1
  METRIC = 2
114
115
# Blueprint for running the diff test. 'query' is run over data from the
# 'trace' and the result is compared to 'out'. Each test (function in a class
# inheriting from TestSuite) returns a DiffTestBlueprint.
@dataclass
class DiffTestBlueprint:
  """Declarative description of one diff test: trace, query, expected output.

  The `is_*` helpers only inspect the type of the corresponding field. Because
  `DataPath` derives from `Path`, the `*_file` helpers are true for both.
  """

  # Input trace: a file reference or inline contents.
  trace: Union[Path, DataPath, Json, Systrace, TextProto]
  # What to run: an inline string, a file reference or a metric.
  query: Union[str, Path, DataPath, Metric]
  # Expected output: a file reference or inline contents.
  out: Union[Path, DataPath, Json, Csv, TextProto, BinaryProto]
  # Optional injector of fields into trace packets; None when unused.
  trace_modifier: Union[TraceInjector, None] = None

  # --- trace ---------------------------------------------------------------

  def is_trace_file(self):
    """True when the trace is a `Path` (including `DataPath`)."""
    return isinstance(self.trace, Path)

  def is_trace_textproto(self):
    """True when the trace is inline `TextProto`."""
    return isinstance(self.trace, TextProto)

  def is_trace_json(self):
    """True when the trace is inline `Json`."""
    return isinstance(self.trace, Json)

  def is_trace_systrace(self):
    """True when the trace is inline `Systrace`."""
    return isinstance(self.trace, Systrace)

  # --- query ---------------------------------------------------------------

  def is_query_file(self):
    """True when the query is a `Path` (including `DataPath`)."""
    return isinstance(self.query, Path)

  def is_metric(self):
    """True when the query is a `Metric`."""
    return isinstance(self.query, Metric)

  # --- expected output -----------------------------------------------------

  def is_out_file(self):
    """True when the expected output is a `Path` (including `DataPath`)."""
    return isinstance(self.out, Path)

  def is_out_json(self):
    """True when the expected output is inline `Json`."""
    return isinstance(self.out, Json)

  # NOTE(review): the name is spelled 'texproto' (sic); it is kept unchanged
  # here since callers outside this file may rely on it.
  def is_out_texproto(self):
    """True when the expected output is inline `TextProto`."""
    return isinstance(self.out, TextProto)

  def is_out_binaryproto(self):
    """True when the expected output is a `BinaryProto`."""
    return isinstance(self.out, BinaryProto)

  def is_out_csv(self):
    """True when the expected output is inline `Csv`."""
    return isinstance(self.out, Csv)
159
160
# Description of a diff test. Created in `fetch()` in TestSuite: each test
# (function starting with `test_`) returns a DiffTestBlueprint, and the
# function name (without the `test_` prefix, prefixed with the suite's class
# name) becomes the TestCase name. Used by the diff test script.
class TestCase:
  """A named `DiffTestBlueprint` together with its resolved file paths.

  On construction the query, trace and expected-output paths are resolved
  (and checked for existence) when the blueprint refers to files. Each path
  attribute is None when the corresponding blueprint field is not a file.
  """

  def __resolve_path(self, item, kind: str) -> str:
    """Returns the on-disk path for the file reference `item`.

    A `DataPath` is joined onto `test_data_dir` as-is. Any other file
    reference is joined onto `index_dir` and made absolute.

    Args:
      item: Object with a `filename` attribute (a `Path` or `DataPath`).
      kind: Label used in the error message ('Query', 'Trace' or 'Out').

    Raises:
      AssertionError: If the resolved path does not exist.
    """
    if isinstance(item, DataPath):
      path = os.path.join(self.test_data_dir, item.filename)
    else:
      path = os.path.abspath(os.path.join(self.index_dir, item.filename))

    if not os.path.exists(path):
      raise AssertionError(
          f"{kind} file ({path}) for test '{self.name}' does not exist.")
    return path

  def __get_query_path(self) -> Union[str, None]:
    """Path of the query file, or None if the query is not a file."""
    if not self.blueprint.is_query_file():
      return None
    return self.__resolve_path(self.blueprint.query, 'Query')

  def __get_trace_path(self) -> Union[str, None]:
    """Path of the trace file, or None if the trace is not a file."""
    if not self.blueprint.is_trace_file():
      return None
    return self.__resolve_path(self.blueprint.trace, 'Trace')

  def __get_out_path(self) -> Union[str, None]:
    """Path of the expected-output file, or None if it is not a file."""
    if not self.blueprint.is_out_file():
      return None
    return self.__resolve_path(self.blueprint.out, 'Out')

  def __init__(self, name: str, blueprint: 'DiffTestBlueprint', index_dir: str,
               test_data_dir: str) -> None:
    """Stores the test description and resolves its file paths.

    Args:
      name: Name of the test.
      blueprint: Blueprint describing trace, query and expected output.
      index_dir: Base directory for plain `Path` references.
      test_data_dir: Base directory for `DataPath` references.

    Raises:
      AssertionError: If a referenced query/trace/out file does not exist.
    """
    self.name = name
    self.blueprint = blueprint
    self.index_dir = index_dir
    self.test_data_dir = test_data_dir

    if blueprint.is_metric():
      self.type = TestType.METRIC
    else:
      self.type = TestType.QUERY

    self.query_path = self.__get_query_path()
    self.trace_path = self.__get_trace_path()
    self.expected_path = self.__get_out_path()

  # Verifies that the test should be in test suite. If False, test will not be
  # executed.
  def validate(self, name_filter: str):
    """Returns True if `name_filter` matches the test name.

    `name_filter` is a regular expression applied with `re.match` semantics
    (anchored at the start) to the basename of `self.name`.
    """
    query_metric_pattern = re.compile(name_filter)
    return bool(query_metric_pattern.match(os.path.basename(self.name)))
233
234
# Virtual class responsible for fetching diff tests.
# All functions with a name starting with `test_` have to return a
# DiffTestBlueprint, and the function name is the test name. All
# DiffTestModules have to be included in
# `test/diff_tests/trace_processor/include_index.py`.
# The `fetch` function should not be overridden.
class TestSuite:
  """Collects the diff tests declared as `test_*` methods on a subclass."""

  def __init__(
      self,
      include_index_dir: str,
      dir_name: str,
      class_name: str,
      test_data_dir: str = os.path.abspath(
          os.path.join(__file__, '../../../../test/data'))
  ) -> None:
    """Records where the suite lives and how its tests are to be named.

    Args:
      include_index_dir: Directory that `dir_name` is joined onto to form
        `index_dir`.
      dir_name: Directory name of this suite.
      class_name: Prefix used for the names of the suite's tests.
      test_data_dir: Base directory handed to every `TestCase`; by default a
        `test/data` directory located relative to this file.
    """
    self.class_name = class_name
    self.dir_name = dir_name
    self.test_data_dir = test_data_dir
    self.index_dir = os.path.join(include_index_dir, dir_name)

  def __test_name(self, method_name):
    """Builds '<class_name>:<suffix>' from a 'test_<suffix>' method name."""
    suffix = method_name.split('test_', 1)[1]
    return f"{self.class_name}:{suffix}"

  def fetch(self) -> List['TestCase']:
    """Calls every `test_*` method and wraps each result in a `TestCase`.

    Methods are visited in `dir()` order. Each selected method is invoked with
    no arguments and its return value is used as the blueprint.
    """
    # First gather all bound methods, then invoke only the tests.
    bound_methods = []
    for attr_name in dir(self):
      candidate = getattr(self, attr_name)
      if inspect.ismethod(candidate):
        bound_methods.append(candidate)

    cases = []
    for method in bound_methods:
      if not method.__name__.startswith('test_'):
        continue
      cases.append(
          TestCase(
              self.__test_name(method.__name__), method(), self.index_dir,
              self.test_data_dir))
    return cases
268
269
def PrintProfileProto(profile):
  """Renders the samples of `profile` as sorted, human-readable text.

  For every sample, one stack line is produced per function of each
  referenced location ("<name> (<hex address>)"), or a bare "(<hex address>)"
  line for a location without line info. The rendered samples are sorted and
  joined by blank lines, so the output does not depend on sample order.

  Args:
    profile: Message with `location`, `function`, `sample` and `string_table`
      fields (presumably a pprof `Profile` proto -- confirm against callers).

  Returns:
    The rendered text, always terminated by a newline.
  """
  locations_by_id = {loc.id: loc for loc in profile.location}
  functions_by_id = {fn.id: fn for fn in profile.function}
  # Strips trailing annotations like (.__uniq.1657) from the function name.
  # Raw string: '\[' and '\]' are not valid escapes in a normal literal.
  uniq_suffix = re.compile(r' [(\[].*?uniq.*?[)\]]$')

  samples = []
  for sample in profile.sample:
    stack = []
    for location_id in sample.location_id:
      location = locations_by_id[location_id]
      address = hex(location.address)
      for line in location.line:
        function = functions_by_id[line.function_id]
        name = uniq_suffix.sub('', profile.string_table[function.name])
        stack.append(f'{name} ({address})')
      if not location.line:
        # No symbolization available for this frame: print the address only.
        stack.append(f'({address})')
    values = ', '.join(map(str, sample.value))
    frames = '\n'.join(stack)
    samples.append(f'Sample:\nValues: {values}\nStack:\n{frames}')
  return '\n\n'.join(sorted(samples)) + '\n'
288