xref: /aosp_15_r20/external/pigweed/pw_console/py/pw_console/progress_bar/progress_bar_impl.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Pigweed Console ProgressBar implementation.
15
16Designed to be embedded in an existing prompt_toolkit full screen
17application."""
18
19from __future__ import annotations
20
21from datetime import datetime, timedelta
22import functools
23from typing import (
24    Iterable,
25    Sequence,
26)
27
28from prompt_toolkit.filters import Condition
29from prompt_toolkit.formatted_text import AnyFormattedText
30from prompt_toolkit.layout import (
31    ConditionalContainer,
32    FormattedTextControl,
33    HSplit,
34    VSplit,
35    Window,
36)
37from prompt_toolkit.layout.dimension import AnyDimension, D
38from prompt_toolkit.styles import BaseStyle
39
40from prompt_toolkit.shortcuts.progress_bar import (
41    ProgressBar,
42    ProgressBarCounter,
43)
44from prompt_toolkit.shortcuts.progress_bar.base import _ProgressControl
45from prompt_toolkit.shortcuts.progress_bar.formatters import (
46    Formatter,
47    IterationsPerSecond,
48    Text,
49    TimeLeft,
50    create_default_formatters,
51)
52
53
class TextIfNotHidden(Text):
    """Static text column that turns into blank space for hidden-ETA counters.

    A counter may carry an optional ``hide_eta`` attribute. When it is truthy,
    this formatter emits ``width`` unstyled spaces instead of its text. A
    counter without the attribute is treated as not hidden.
    """

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        """Return the configured text, or ``width`` spaces when hidden."""
        rendered = super().format(progress_bar, progress, width)
        # ``hide_eta`` is optional on the counter; absent means "show".
        if not getattr(progress, 'hide_eta', False):
            return rendered
        return [('', ' ' * width)]
65
66
class IterationsPerSecondIfNotHidden(IterationsPerSecond):
    """Iterations-per-second column that blanks itself for hidden-ETA counters.

    When the counter's optional ``hide_eta`` attribute is truthy, the column
    is filled with spaces. The ``class:iterations-per-second`` style is kept,
    so the blank area is styled like the normal output.
    """

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        """Return the rate text, or ``width`` styled spaces when hidden."""
        rate_text = super().format(progress_bar, progress, width)
        hidden = getattr(progress, 'hide_eta', False)
        return (
            [('class:iterations-per-second', ' ' * width)]
            if hidden
            else rate_text
        )
78
79
class TimeLeftIfNotHidden(TimeLeft):
    """Time-left (ETA) column that blanks itself for hidden-ETA counters.

    When the counter's optional ``hide_eta`` attribute is truthy, the column
    becomes ``width`` spaces styled with ``class:time-left``. Otherwise the
    normal ETA text is returned.
    """

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        """Return the ETA text, or ``width`` styled spaces when hidden."""
        eta_text = super().format(progress_bar, progress, width)
        if getattr(progress, 'hide_eta', False):
            eta_text = [('class:time-left', ' ' * width)]
        return eta_text
91
92
class ProgressBarImpl:
    """ProgressBar for rendering in an existing prompt_toolkit application.

    Unlike prompt_toolkit's ``ProgressBar`` shortcut, this class does not build
    an ``Application`` of its own. It only builds a layout container, exposed
    through ``__pt_container__``, which can be placed inside another layout.
    Instances are handed to prompt_toolkit code that expects a ``ProgressBar``
    (hence the ``type: ignore`` comments). That code includes
    ``ProgressBarCounter``, the formatters and ``_ProgressControl``.
    """

    def __init__(
        self,
        title: AnyFormattedText = None,
        formatters: Sequence[Formatter] | None = None,
        style: BaseStyle | None = None,
    ) -> None:
        """Build the progress bar layout.

        Args:
            title: Text for a one-line toolbar above the bars. The toolbar
                is hidden while ``self.title`` is ``None``.
            formatters: Column formatters, one ``Window`` per formatter. An
                empty or ``None`` value falls back to prompt_toolkit's
                ``create_default_formatters()``.
            style: Optional style, stored on the instance.
        """
        self.title = title
        self.formatters = formatters or create_default_formatters()
        self.counters: list[ProgressBarCounter[object]] = []
        self.style = style

        # Title toolbar, shown only while a title is set. Both lambdas read
        # self.title lazily, so later changes to the title are picked up.
        title_toolbar = ConditionalContainer(
            Window(
                FormattedTextControl(lambda: self.title),
                height=1,
                style='class:progressbar,title',
            ),
            filter=Condition(lambda: self.title is not None),
        )

        def width_for_formatter(formatter: Formatter) -> AnyDimension:
            # Needs to be passed as callable (partial) to the 'width'
            # parameter, because we want to call it on every resize.
            return formatter.get_width(progress_bar=self)  # type: ignore

        # One column (Window) per formatter, laid out side by side below.
        progress_controls = [
            Window(
                content=_ProgressControl(self, f, None),  # type: ignore
                width=functools.partial(width_for_formatter, f),
            )
            for f in self.formatters
        ]

        self.container = HSplit(
            [
                title_toolbar,
                VSplit(
                    progress_controls,
                    # One row per active counter, recomputed on each render.
                    height=lambda: D(
                        min=len(self.counters), max=len(self.counters)
                    ),
                ),
            ]
        )

    def __pt_container__(self):
        """Return the root container so this object can be used in a layout."""
        return self.container

    def invalidate(self) -> None:
        """Request a redraw of the running prompt_toolkit application, if any.

        ``ProgressBarCounter.item_completed()`` runs for every item when a
        counter created with ``data`` is iterated, and it calls
        ``progress_bar.invalidate()``. So this method must exist.
        """
        # Local import: only needed when a redraw is actually requested.
        from prompt_toolkit.application import get_app_or_none

        app = get_app_or_none()
        if app is not None:
            app.invalidate()

    def __enter__(self) -> ProgressBarImpl:
        """Return ``self`` so the bar can be used in a ``with`` statement."""
        return self

    def __exit__(self, *a: object) -> None:
        """No-op: this object starts nothing that needs tearing down."""

    def __call__(
        self,
        data: Iterable | None = None,
        label: AnyFormattedText = '',
        remove_when_done: bool = False,
        total: int | None = None,
    ) -> ProgressBarCounter:
        """
        Start a new counter.

        :param data: Optional iterable to track. Iterating the returned counter
            yields its items and advances the progress.
        :param label: Title text or description for this progress. (This can be
            formatted text as well).
        :param remove_when_done: When `True`, remove this counter (and its row)
            from the bar once it is marked done.
        :param total: Specify the maximum value if it can't be calculated by
            calling ``len``.
        :returns: The new counter, already appended to ``self.counters``.
        """
        counter = ProgressBarCounter(
            self,  # type: ignore
            data,
            label=label,
            remove_when_done=remove_when_done,
            total=total,
        )
        # Ensure the elapsed time for the progress counter isn't ever
        # zero by making the start time one second in the past.
        counter.start_time = datetime.now() - timedelta(seconds=1)
        self.counters.append(counter)
        return counter
176