1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Avatar metrics trace."""
16
17import atexit
18import time
19import types
20
21from avatar.metrics.trace_pb2 import DebugAnnotation
22from avatar.metrics.trace_pb2 import ProcessDescriptor
23from avatar.metrics.trace_pb2 import ThreadDescriptor
24from avatar.metrics.trace_pb2 import Trace
25from avatar.metrics.trace_pb2 import TracePacket
26from avatar.metrics.trace_pb2 import TrackDescriptor
27from avatar.metrics.trace_pb2 import TrackEvent
28from google.protobuf import any_pb2
29from google.protobuf import message
30from mobly.base_test import BaseTestClass
31from pathlib import Path
32from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Tuple, Union
33
34if TYPE_CHECKING:
35    from avatar import PandoraDevices
36    from avatar.pandora_client import PandoraClient
37else:
38    PandoraClient = object
39    PandoraDevices = object
40
# Trace track uuid of each device, assigned when a hooked test is set up.
devices_id: Dict[PandoraClient, int] = dict()
# Id of the process track (one per test case) each device currently belongs to.
devices_process_id: Dict[PandoraClient, int] = dict()
# Every packet recorded so far; serialized by `dump_trace` at interpreter exit.
packets: List[TracePacket] = list()
# Time origin of the trace timestamps, in monotonic nanoseconds.
genesis: int = time.monotonic_ns()
# Directory receiving `avatar.trace`; stays None until `hook_test` sets it.
output_path: Optional[Path] = None
# Last identifier handed out by `next_id`.
id: int = 0
47
48
def next_id() -> int:
    """Advance the module-wide `id` counter by one and return its new value."""
    global id
    id = id + 1
    return id
53
54
@atexit.register
def dump_trace() -> None:
    """Write all recorded packets to `<output_path>/avatar.trace`.

    Registered with `atexit`, so it runs when the interpreter exits. Nothing
    is written while `output_path` is still unset.
    """
    if output_path is None:
        return
    trace = Trace(packet=packets)
    trace_file = output_path / "avatar.trace"
    with trace_file.open("wb") as out:
        out.write(trace.SerializeToString())
63
64
def hook_test(test: BaseTestClass, devices: PandoraDevices) -> None:
    """Instrument `test` so that each test case declares its trace tracks.

    The first call derives the trace output directory from the mobly output
    path of the current test. `test.setup_test` is then replaced by a wrapper
    which, before delegating to the original `setup_test`:

    * resets the trace time origin (`genesis`),
    * records a process track named `<TestClass>.<test name>`,
    * records one thread track per device and remembers its ids.

    Args:
        test: mobly test class instance to instrument.
        devices: devices to create tracks for; iterated at every test setup.
    """
    global output_path

    if output_path is None:
        mobly_output_path: str = test.current_test_info.output_path  # type: ignore
        # Go two levels up to skip test class and method name.
        output_path = Path(mobly_output_path, '..', '..').resolve()

    wrapped_setup_test = test.setup_test

    def setup_test(self: BaseTestClass) -> None:
        global genesis
        genesis = time.monotonic_ns()

        # One process track per test case.
        process_id = next_id()
        test_case_name = f"{self.__class__.__name__}.{self.current_test_info.name}"
        process_track = TrackDescriptor(
            uuid=process_id,
            process=ProcessDescriptor(pid=process_id, process_name=test_case_name),
        )
        packets.append(TracePacket(track_descriptor=process_track))

        # One thread track per device, attached to the process track above.
        for device in devices:
            devices_process_id[device] = process_id
            thread_id = next_id()
            devices_id[device] = thread_id
            thread_track = TrackDescriptor(
                uuid=thread_id,
                parent_uuid=process_id,
                thread=ThreadDescriptor(thread_name=device.name, pid=process_id, tid=thread_id),
            )
            packets.append(TracePacket(track_descriptor=thread_track))

        wrapped_setup_test()

    test.setup_test = types.MethodType(setup_test, test)
102
103
class AsTrace(Protocol):
    """Interface of objects that can be converted into a trace packet."""

    def as_trace(self) -> TracePacket:
        """Return the `TracePacket` describing this object."""
        ...
107
108
class Callsite(AsTrace):
    """Beginning of a call made on a device, plus the events that follow it.

    Creating a callsite timestamps it relative to `genesis`, gives it a
    class-wide sequential id and logs it through the device logger.
    """

    # Last id handed out by `Callsite.next_id`.
    id_counter = 0

    @classmethod
    def next_id(cls) -> int:
        """Increment the callsite id counter and return the new value."""
        cls.id_counter = cls.id_counter + 1
        return cls.id_counter

    def __init__(self, device: PandoraClient, name: Union[bytes, str], message: Any) -> None:
        """Record and log the start of a call.

        Args:
            device: device the call is made on; its `log` receives the entry.
            name: call name; anything that is not a `str` is decoded as UTF-8.
            message: request message of the call, or None.
        """
        self.at = time.monotonic_ns() - genesis
        if not isinstance(name, str):
            name = name.decode('utf-8')
        self.name = name
        self.device = device
        self.message = message
        self.events: List[CallEvent] = []
        self.id = Callsite.next_id()

        device.log.info(f"{self}")

    def pretty(self) -> str:
        """Return a short rendering of the call: name, then id or message."""
        # Drop the first character, keep what follows the last '.', and show
        # the remaining '/' separators as '.'.
        short_name = self.name[1:].rsplit('.', 1)[-1].replace('/', '.')
        if self.message is not None:
            message_pretty, _ = debug_message(self.message)
            return f"{short_name}({message_pretty})"
        return f"%{self.id} {short_name}"

    def __str__(self) -> str:
        opening = str2color('╭──', self.id)
        return f"{opening} {self.pretty()}"

    def output(self, message: Any) -> None:
        """Attach (and log) a `CallOutput` event carrying `message`."""
        event = CallOutput(self, message)
        self.events.append(event)

    def input(self, message: Any) -> None:
        """Attach (and log) a `CallInput` event carrying `message`."""
        event = CallInput(self, message)
        self.events.append(event)

    def end(self, message: Any) -> None:
        """Close the call and flush it, with all its events, into `packets`.

        Does nothing when the device has no entry in `devices_id`.
        """
        if self.device not in devices_id:
            return
        self.events.append(CallEnd(self, message))
        packets.append(self.as_trace())
        packets.extend(event.as_trace() for event in self.events)

    def as_trace(self) -> TracePacket:
        """Return the slice-begin packet of this call on the device track."""
        track_uuid = devices_id[self.device]
        if self.message is None:
            annotations = None
        else:
            _, entries = debug_message(self.message)
            annotations = [DebugAnnotation(name=self.message.__class__.__name__, dict_entries=entries)]
        event = TrackEvent(
            name=self.name,
            type=TrackEvent.Type.TYPE_SLICE_BEGIN,
            track_uuid=track_uuid,
            debug_annotations=annotations,
        )
        return TracePacket(
            timestamp=self.at,
            track_event=event,
            trusted_packet_sequence_id=devices_process_id[self.device],
        )
167
168
class CallEvent(AsTrace):
    """Timestamped event attached to a `Callsite`.

    Creating an event timestamps it relative to `genesis` and logs it through
    the logger of the callsite's device.
    """

    def __init__(self, callsite: Callsite, message: Any) -> None:
        """Record and log an event of `callsite` carrying `message` (or None)."""
        self.at = time.monotonic_ns() - genesis
        self.callsite = callsite
        self.message = message

        callsite.device.log.info(f"{self}")

    def __str__(self) -> str:
        branch = str2color('╰──', self.callsite.id)
        return f"{branch} {self.stringify('⟶ ')}"

    def as_trace(self) -> TracePacket:
        """Return an instant packet for this event on the device track."""
        device = self.callsite.device
        track_uuid = devices_id[device]
        if self.message is None:
            annotations = None
        else:
            _, entries = debug_message(self.message)
            annotations = [DebugAnnotation(name=self.message.__class__.__name__, dict_entries=entries)]
        event = TrackEvent(
            name=self.callsite.name,
            type=TrackEvent.Type.TYPE_INSTANT,
            track_uuid=track_uuid,
            debug_annotations=annotations,
        )
        return TracePacket(
            timestamp=self.at,
            track_event=event,
            trusted_packet_sequence_id=devices_process_id[device],
        )

    def stringify(self, direction: str) -> str:
        """Render the event as `[elapsed] call direction (message)`.

        The elapsed time is measured from the callsite and shown in seconds;
        both it and `direction` are colored after the callsite id.
        """
        callsite = self.callsite
        message_pretty = debug_message(self.message)[0] if self.message is not None else ""
        # Timestamps are in nanoseconds.
        elapsed_seconds = (self.at - callsite.at) / 1_000_000_000
        elapsed = str2color(f"[{elapsed_seconds:.3f}s]", callsite.id)
        return elapsed + f" {callsite.pretty()} {str2color(direction, callsite.id)} ({message_pretty})"
202
203
class CallOutput(CallEvent):
    """Event created by `Callsite.output`; logged with a rightward arrow."""

    def __str__(self) -> str:
        branch = str2color('├──', self.callsite.id)
        return f"{branch} {self.stringify('⟶ ')}"

    def as_trace(self) -> TracePacket:
        """Return the packet built by the base class implementation."""
        packet = super().as_trace()
        return packet
210
211
class CallInput(CallEvent):
    """Event created by `Callsite.input`; logged with a leftward arrow."""

    def __str__(self) -> str:
        branch = str2color('├──', self.callsite.id)
        return f"{branch} {self.stringify('⟵ ')}"

    def as_trace(self) -> TracePacket:
        """Return the packet built by the base class implementation."""
        packet = super().as_trace()
        return packet
218
219
class CallEnd(CallEvent):
    """Event created by `Callsite.end`; emitted as the slice-end packet."""

    def __str__(self) -> str:
        closing = str2color('╰──', self.callsite.id)
        return f"{closing} {self.stringify('⟶ ')}"

    def as_trace(self) -> TracePacket:
        """Return the slice-end packet of the call on the device track."""
        device = self.callsite.device
        track_uuid = devices_id[device]
        if self.message is None:
            annotations = None
        else:
            _, entries = debug_message(self.message)
            annotations = [DebugAnnotation(name=self.message.__class__.__name__, dict_entries=entries)]
        event = TrackEvent(
            name=self.callsite.name,
            type=TrackEvent.Type.TYPE_SLICE_END,
            track_uuid=track_uuid,
            debug_annotations=annotations,
        )
        return TracePacket(
            timestamp=self.at,
            track_event=event,
            trusted_packet_sequence_id=devices_process_id[device],
        )
239
240
def debug_value(v: Any) -> Tuple[Any, Dict[str, Any]]:
    """Convert a single field value for logging and for trace annotations.

    Args:
        v: value to convert, typically a protobuf field value.

    Returns:
        A pair `(pretty, kwargs)`: `pretty` is the value to show in log
        output, `kwargs` holds the single `DebugAnnotation` keyword argument
        carrying the value (`string_value`, `dict_entries`, `bool_value`,
        `int_value`, `double_value` or `array_values`).
    """
    if isinstance(v, any_pb2.Any):
        return '...', {'string_value': f'{v}'}
    elif isinstance(v, message.Message):
        json, entries = debug_message(v)
        return json, {'dict_entries': entries}
    elif isinstance(v, bytes):
        # Byte strings of 16 bytes or more are elided from the log output only.
        return (v if len(v) < 16 else '...'), {'string_value': f'{v!r}'}
    elif isinstance(v, bool):
        # Must come before the int check: bool is a subclass of int.
        return v, {'bool_value': v}
    elif isinstance(v, int):
        return v, {'int_value': v}
    elif isinstance(v, float):
        return v, {'double_value': v}
    elif isinstance(v, str):
        return v, {'string_value': v}
    # Anything else is first tried as an iterable of values and, failing
    # that, rendered through its string form.
    try:
        return v, {'array_values': [DebugAnnotation(**debug_value(x)[1]) for x in v]}  # type: ignore
    except Exception:
        # Not a bare `except:` so that KeyboardInterrupt and SystemExit are
        # never swallowed by this best-effort fallback.
        return v, {'string_value': f'{v}'}
261
262
def debug_message(msg: message.Message) -> Tuple[Dict[str, Any], List[DebugAnnotation]]:
    """Convert every set field of `msg` for logging and for trace annotations.

    Returns:
        A pair `(printable, annotations)`: `printable` maps each field name to
        its printable value, `annotations` holds one `DebugAnnotation` per
        field, in `msg.ListFields()` order.
    """
    printable: Dict[str, Any] = {}
    annotations: List[DebugAnnotation] = []
    for field, value in msg.ListFields():
        # 6-byte values of fields (or oneofs) named after an address are shown
        # as colon-separated hex -- presumably device addresses, to confirm.
        is_address = (
            isinstance(value, bytes)
            and len(value) == 6
            and ('address' in field.name or (field.containing_oneof and 'address' in field.containing_oneof.name))
        )
        if not is_address:
            pretty, annotation_kwargs = debug_value(value)
            printable[field.name] = pretty
            annotations.append(DebugAnnotation(name=field.name, **annotation_kwargs))
            continue
        formatted = ':'.join(f'{octet:02X}' for octet in value)
        printable[field.name] = formatted
        annotations.append(DebugAnnotation(name=field.name, string_value=formatted))
    return printable, annotations
280
281
def str2color(s: str, id: int) -> str:
    """Wrap `s` in ANSI escape sequences giving it a bold color picked by `id`.

    The color index is `17 + (id * 10) % 213`, i.e. always within 17..229, so
    the same `id` always yields the same color. The attributes are reset
    right after `s`.
    """
    first_color, color_limit = 17, 230
    color = first_color + (id * 10) % (color_limit - first_color)
    select_color = "\x1b[1;38;5;%dm" % color
    return "".join((select_color, "\x1b[1m", s, "\x1b[0m"))
288