1# Copyright 2021 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Pigweed Console ProgressBar implementation. 15 16Designed to be embedded in an existing prompt_toolkit full screen 17application.""" 18 19from __future__ import annotations 20 21from datetime import datetime, timedelta 22import functools 23from typing import ( 24 Iterable, 25 Sequence, 26) 27 28from prompt_toolkit.filters import Condition 29from prompt_toolkit.formatted_text import AnyFormattedText 30from prompt_toolkit.layout import ( 31 ConditionalContainer, 32 FormattedTextControl, 33 HSplit, 34 VSplit, 35 Window, 36) 37from prompt_toolkit.layout.dimension import AnyDimension, D 38from prompt_toolkit.styles import BaseStyle 39 40from prompt_toolkit.shortcuts.progress_bar import ( 41 ProgressBar, 42 ProgressBarCounter, 43) 44from prompt_toolkit.shortcuts.progress_bar.base import _ProgressControl 45from prompt_toolkit.shortcuts.progress_bar.formatters import ( 46 Formatter, 47 IterationsPerSecond, 48 Text, 49 TimeLeft, 50 create_default_formatters, 51) 52 53 54class TextIfNotHidden(Text): 55 def format( 56 self, 57 progress_bar: ProgressBar, 58 progress: ProgressBarCounter[object], 59 width: int, 60 ) -> AnyFormattedText: 61 formatted_text = super().format(progress_bar, progress, width) 62 if hasattr(progress, 'hide_eta') and progress.hide_eta: # type: ignore 63 formatted_text = [('', ' ' * width)] 64 return formatted_text 65 66 67class IterationsPerSecondIfNotHidden(IterationsPerSecond): 68 def format( 69 self, 70 progress_bar: ProgressBar, 71 progress: ProgressBarCounter[object], 72 width: int, 73 ) -> AnyFormattedText: 74 formatted_text = super().format(progress_bar, progress, width) 75 if hasattr(progress, 'hide_eta') and progress.hide_eta: # type: ignore 76 formatted_text = [('class:iterations-per-second', ' ' * width)] 77 return formatted_text 78 79 80class TimeLeftIfNotHidden(TimeLeft): 81 def format( 82 self, 83 progress_bar: ProgressBar, 84 progress: ProgressBarCounter[object], 85 width: int, 86 ) -> AnyFormattedText: 87 formatted_text = super().format(progress_bar, progress, width) 88 if hasattr(progress, 'hide_eta') and progress.hide_eta: # type: ignore 89 formatted_text = [('class:time-left', ' ' * width)] 90 return formatted_text 91 92 93class ProgressBarImpl: 94 """ProgressBar for rendering in an existing prompt_toolkit application.""" 95 96 def __init__( 97 self, 98 title: AnyFormattedText = None, 99 formatters: Sequence[Formatter] | None = None, 100 style: BaseStyle | None = None, 101 ) -> None: 102 self.title = title 103 self.formatters = formatters or create_default_formatters() 104 self.counters: list[ProgressBarCounter[object]] = [] 105 self.style = style 106 107 # Create UI Application. 108 title_toolbar = ConditionalContainer( 109 Window( 110 FormattedTextControl(lambda: self.title), 111 height=1, 112 style='class:progressbar,title', 113 ), 114 filter=Condition(lambda: self.title is not None), 115 ) 116 117 def width_for_formatter(formatter: Formatter) -> AnyDimension: 118 # Needs to be passed as callable (partial) to the 'width' 119 # parameter, because we want to call it on every resize. 120 return formatter.get_width(progress_bar=self) # type: ignore 121 122 progress_controls = [ 123 Window( 124 content=_ProgressControl(self, f, None), # type: ignore 125 width=functools.partial(width_for_formatter, f), 126 ) 127 for f in self.formatters 128 ] 129 130 self.container = HSplit( 131 [ 132 title_toolbar, 133 VSplit( 134 progress_controls, 135 height=lambda: D( 136 min=len(self.counters), max=len(self.counters) 137 ), 138 ), 139 ] 140 ) 141 142 def __pt_container__(self): 143 return self.container 144 145 def __exit__(self, *a: object) -> None: 146 pass 147 148 def __call__( 149 self, 150 data: Iterable | None = None, 151 label: AnyFormattedText = '', 152 remove_when_done: bool = False, 153 total: int | None = None, 154 ) -> ProgressBarCounter: 155 """ 156 Start a new counter. 157 158 :param label: Title text or description for this progress. (This can be 159 formatted text as well). 160 :param remove_when_done: When `True`, hide this progress bar. 161 :param total: Specify the maximum value if it can't be calculated by 162 calling ``len``. 163 """ 164 counter = ProgressBarCounter( 165 self, # type: ignore 166 data, 167 label=label, 168 remove_when_done=remove_when_done, 169 total=total, 170 ) 171 # Ensure the elapsed time for the progress counter isn't ever 172 # zero by making the start time one second in the past. 173 counter.start_time = datetime.now() + timedelta(seconds=-1) 174 self.counters.append(counter) 175 return counter 176