xref: /aosp_15_r20/external/pigweed/pw_console/py/pw_console/progress_bar/progress_bar_state.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Pigweed Console progress bar task state."""
15
16from contextvars import ContextVar
17import copy
18from dataclasses import dataclass, field
19import signal
20
21from prompt_toolkit.application import get_app_or_none
22from prompt_toolkit.shortcuts import ProgressBar
23from prompt_toolkit.shortcuts.progress_bar import formatters
24
25from pw_console.progress_bar.progress_bar_impl import (
26    IterationsPerSecondIfNotHidden,
27    ProgressBarImpl,
28    TextIfNotHidden,
29    TimeLeftIfNotHidden,
30)
31from pw_console.progress_bar.progress_bar_task_counter import (
32    ProgressBarTaskCounter,
33)
34from pw_console.style import generate_styles
35
# Formatters shared by every pw_console progress bar.
#
# The first group renders the task label, a rainbow bar framed by '|Pigw' and
# '|' (filled with 'e', headed by 'd!'), the progress count, the percentage and
# the elapsed time. The second group renders the rate and ETA using the
# *IfNotHidden formatters from progress_bar_impl (presumably omitted for tasks
# that hide their ETA -- see progress_bar_impl).
CUSTOM_FORMATTERS = [
    formatters.Label(suffix=': '),
    formatters.Rainbow(
        formatters.Bar(
            start='|Pigw',
            end='|',
            sym_a='e',
            sym_b='d!',
            sym_c=' ',
        )
    ),
    formatters.Text(' '),
    formatters.Progress(),
    formatters.Text(' ['),
    formatters.Percentage(),
    formatters.Text('] in '),
    formatters.TimeElapsed(),
] + [
    TextIfNotHidden(' ('),
    IterationsPerSecondIfNotHidden(),
    TextIfNotHidden('/s, eta: '),
    TimeLeftIfNotHidden(),
    TextIfNotHidden(')'),
]
53
54
def prompt_toolkit_app_running() -> bool:
    """Report whether a prompt_toolkit application is currently active.

    Returns:
        True if get_app_or_none() yields an application (as it does when
        running inside pw_console), False when it yields None.
    """
    return bool(get_app_or_none())
60
61
@dataclass
class ProgressBarState:
    """Pigweed Console wide state for all repl progress bars.

    An instance of this class is intended to be a global variable.

    Attributes:
        tasks: Tracked progress bar tasks, keyed by task name.
        instance: The progress bar that renders every task. It is a
            ProgressBarImpl when a prompt_toolkit application is already
            running, a standalone ProgressBar otherwise, and None until
            startup_progress_bar_impl() is first called.
    """

    tasks: dict[str, ProgressBarTaskCounter] = field(default_factory=dict)
    instance: ProgressBar | ProgressBarImpl | None = None

    def _install_sigint_handler(self) -> None:
        """Add ctrl-c handling if not running inside pw_console.

        The installed handler shuts down the current progress bar (when it
        supports __exit__) and then raises KeyboardInterrupt. It replaces any
        previously installed SIGINT handler.

        Installing the handler is best-effort. signal.signal() raises
        ValueError when called outside the main thread of the main
        interpreter. In that case no handler is installed and ctrl-c keeps its
        existing behavior, instead of the error aborting progress bar startup.
        """

        def handle_sigint(_signum, _frame):
            # Shut down the ProgressBar prompt_toolkit application before
            # propagating the interrupt.
            prog_bar = self.instance
            if prog_bar is not None and hasattr(prog_bar, '__exit__'):
                prog_bar.__exit__()  # pylint: disable=unnecessary-dunder-call
            raise KeyboardInterrupt

        try:
            signal.signal(signal.SIGINT, handle_sigint)
        except ValueError:
            # Not on the main thread: signal handlers cannot be set here.
            pass

    def startup_progress_bar_impl(self):
        """Return the shared progress bar, creating it on first use.

        If a prompt_toolkit application is running, a ProgressBarImpl styled
        with that application's style is created. Otherwise a SIGINT handler
        is installed, a standalone ProgressBar is created, and its
        prompt_toolkit application is started.

        Returns:
            The progress bar stored in self.instance.
        """
        if self.instance is not None:
            return self.instance

        # Fetch the running application once. Checking for it and then
        # calling get_app_or_none() again for `.style` could see None on the
        # second call if the application exits in between.
        existing_app = get_app_or_none()
        if existing_app:
            prog_bar = ProgressBarImpl(
                style=existing_app.style, formatters=CUSTOM_FORMATTERS
            )
        else:
            self._install_sigint_handler()
            prog_bar = ProgressBar(
                style=generate_styles(), formatters=CUSTOM_FORMATTERS
            )
            # Start the ProgressBar prompt_toolkit application in a separate
            # thread.
            prog_bar.__enter__()  # pylint: disable=unnecessary-dunder-call
        self.instance = prog_bar
        return self.instance

    def cleanup_finished_tasks(self) -> None:
        """Forget completed or canceled tasks and drop their bar counters.

        Each finished task is removed from self.tasks. If the progress bar's
        counters list still holds that task's prompt_toolkit counter, the
        counter is removed from the list as well.
        """
        # Iterate over a snapshot because self.tasks is mutated in the loop.
        for task_name, task in list(self.tasks.items()):
            if not (task.completed or task.canceled):
                continue
            ptc = task.prompt_toolkit_counter
            self.tasks.pop(task_name, None)
            if (
                self.instance
                and self.instance.counters
                and ptc in self.instance.counters
            ):
                self.instance.counters.remove(ptc)

    @property
    def all_tasks_complete(self) -> bool:
        """Whether every tracked task is completed or canceled.

        This is also True when no tasks are tracked. As a side effect,
        finished tasks are pruned with cleanup_finished_tasks().
        """
        tasks_complete = [
            task.completed or task.canceled for task in self.tasks.values()
        ]
        self.cleanup_finished_tasks()
        return all(tasks_complete)

    def cancel_all_tasks(self):
        """Drop every tracked task and clear the progress bar's counters."""
        self.tasks = {}
        if self.instance is not None:
            self.instance.counters = []

    def get_container(self):
        """Return the progress bar's prompt_toolkit container, if any.

        Returns:
            self.instance.__pt_container__(), or None when there is no
            progress bar or it does not define __pt_container__.
        """
        prog_bar = self.instance
        if prog_bar is not None and hasattr(prog_bar, '__pt_container__'):
            return prog_bar.__pt_container__()
        return None
133
134
# Context variable holding the console-wide ProgressBarState. The default is a
# single ProgressBarState created at import time, so every context that has not
# set its own value shares that same state object.
TASKS_CONTEXTVAR = ContextVar(
    'pw_console_progress_bar_tasks',
    default=ProgressBarState(),
)
138