pytorch-lightning
Dashboard
Description & Motivation
Lightning incorporates a number of third-party tools to visualize training (e.g., TensorBoard). However, during debugging or manual hyperparameter search, I mostly use the console and would find it useful to be able to quickly judge whether the code is working as expected (typically whether the loss is going down, and also whether model inputs and outputs look reasonable).
Pitch
I want to propose including a console-based dashboard that incorporates user-defined components, such as a progress bar, continuously updated plots, or sample outputs from the model. Below is the code I have been using for a while that provides this functionality. It builds on top of `RichProgressBar`, so hopefully it would not be too hard to open a pull request if people are interested in this feature.
(For full context, see my repo where this code is actually used for training.)
rich_progress.py
Contains `BaseDashboard`, the base class for defining dashboards.
from threading import RLock
from typing import Any, Callable, Dict, Optional, Union
from lightning_utilities.core.imports import RequirementCache
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar, RichProgressBarTheme, CustomProgress, MetricsTextColumn
GetTimeCallable = Callable[[], float]
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")

if _RICH_AVAILABLE:
    from rich import get_console, reconfigure
    from rich.console import Console
    from rich.layout import Layout
    from rich.live import Live
    from rich.progress import ProgressColumn, Task, TaskID

    class ProgressWithLayout(CustomProgress):
        """``CustomProgress`` whose live display shows an arbitrary renderable (e.g. a ``Layout``).

        Overriding ``get_renderable`` does not work for this use case. ``Progress.__rich__`` delegates
        to ``get_renderable``, and a layout that embeds this progress object (as the dashboard does)
        would then render itself recursively.

        Instead, the regular ``Progress`` initialization runs unchanged. Only the ``Live`` display it
        created is swapped for one that shows ``renderable`` directly. Rendering the progress object
        itself (``__rich__``) still yields the usual task table, so it can be placed inside the layout.

        Args:
            *columns: Progress columns, as for ``Progress``.
            renderable: What the live display shows. If ``None``, the default progress rendering
                (the task table) is kept instead of displaying nothing.
            console, auto_refresh, refresh_per_second, speed_estimate_period, transient,
            redirect_stdout, redirect_stderr, get_time, disable, expand:
                Forwarded to ``Progress.__init__``. The display-related ones also configure the
                replacement ``Live``.
        """

        def __init__(
            self,
            *columns: Union[str, ProgressColumn],
            renderable: Any = None,  # new parameter
            console: Optional[Console] = None,
            auto_refresh: bool = True,
            refresh_per_second: float = 10,
            speed_estimate_period: float = 30.0,
            transient: bool = False,
            redirect_stdout: bool = True,
            redirect_stderr: bool = True,
            get_time: Optional[GetTimeCallable] = None,
            disable: bool = False,
            expand: bool = False,
        ) -> None:
            super().__init__(
                *columns,
                console=console,
                auto_refresh=auto_refresh,
                refresh_per_second=refresh_per_second,
                speed_estimate_period=speed_estimate_period,
                transient=transient,
                redirect_stdout=redirect_stdout,
                redirect_stderr=redirect_stderr,
                get_time=get_time,
                disable=disable,
                expand=expand,
            )
            if renderable is not None:
                # Replace the Live built by Progress.__init__ (which renders the task table) with one
                # that shows `renderable` as-is. It reuses the same console, so `self.console`,
                # `self.print`, `self.log` and `self.get_time` set up by the parent stay valid.
                self.live = Live(
                    renderable=renderable,
                    console=self.live.console,
                    auto_refresh=auto_refresh,
                    refresh_per_second=refresh_per_second,
                    transient=transient,
                    redirect_stdout=redirect_stdout,
                    redirect_stderr=redirect_stderr,
                )
class BaseDashboard(RichProgressBar):
    """Create a dashboard with a progress bar with `rich text formatting <https://github.com/Textualize/rich>`_ and any other rich components.

    This is a base class to inherit from when composing your own dashboard. A subclass should set:

    - ``self.layout``: a ``rich.layout.Layout`` split into sub-panels with unique names.
    - ``self.components``: a dictionary mapping each of these names to a component instance, for
      example a ``rich.jupyter.JupyterMixin`` subclass that implements ``__rich_console__``.

    Every component must also provide an ``update(trainer)`` method. It is called whenever the
    progress-bar metrics are updated.

    If ``self.layout`` is left as ``None``, a plain progress bar is shown instead.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
        theme: Contains styles used to stylize the progress bar. If ``None`` (default), a new
            ``RichProgressBarTheme()`` is created for this instance.
        console_kwargs: Args for constructing a `Console`

    Raises:
        ModuleNotFoundError:
            If required `rich` package is not installed on the device.

    Note:
        PyCharm users will need to enable “emulate terminal” in output console option in
        run/debug configuration to see styled output.
        Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
    """

    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: Optional[RichProgressBarTheme] = None,
        console_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            refresh_rate=refresh_rate,
            leave=leave,
            # A fresh theme per instance: a default argument object would be shared (and any
            # mutation of it seen) by every dashboard created with the default.
            theme=theme if theme is not None else RichProgressBarTheme(),
            console_kwargs=console_kwargs,
        )
        # Filled in by subclasses; see the class docstring.
        self.layout: Optional[Layout] = None
        self.components: Dict[str, Any] = {}

    def _init_progress(self, trainer: "pl.Trainer") -> None:
        """Create and start the live display, unless it is disabled or already running."""
        if not self.is_enabled or (self.progress is not None and not self._progress_stopped):
            return
        self._reset_progress_bar_ids()
        reconfigure(**self._console_kwargs)
        self._console = get_console()
        self._console.clear_live()
        self._metric_component = MetricsTextColumn(
            trainer,
            self.theme.metrics,
            self.theme.metrics_text_delimiter,
            self.theme.metrics_format,
        )
        columns = (*self.configure_columns(trainer), self._metric_component)
        if self.layout is None:
            # No layout to show: use the stock progress bar rather than handing
            # ProgressWithLayout an empty renderable.
            self.progress = CustomProgress(
                *columns,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
        else:
            self.progress = ProgressWithLayout(
                *columns,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
                renderable=self.layout,
            )
        self.progress.start()
        # progress has started
        self._progress_stopped = False

    def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics text column, then let every dashboard component pull new data."""
        metrics = self.get_metrics(trainer, pl_module)
        if self._metric_component:
            self._metric_component.update(metrics)
        for component in self.components.values():
            component.update(trainer)
dashboard.py
Contains a user-defined `Dashboard` with a progress bar, a loss plot, and text-sample output. `Dashboard` is passed to `Trainer` like any other callback.
from typing import Any, Dict, Optional
from collections import deque, defaultdict
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme, _detect_light_colab_theme
from rich.ansi import AnsiDecoder
from rich.console import Group
from rich.jupyter import JupyterMixin
from rich.layout import Layout
from rich.progress import TaskID
from rich.table import Table
import plotext as plt
import rich_progress
class Dashboard(rich_progress.BaseDashboard):
    """Dashboard with a live metrics plot, the progress bar(s) and the latest text samples.

    Layout, top to bottom:

    - ``plot``: takes the remaining height.
    - ``progress_bar``: one row per visible progress task.
    - ``samples``: 6 rows tall.

    Args:
        plot_window: Passed to ``Plot`` as its ``window``. ``None`` keeps and plots every recorded point.
        refresh_rate: Forwarded to ``BaseDashboard``.
        leave: Forwarded to ``BaseDashboard``.
        theme: Forwarded to ``BaseDashboard``. ``None`` (default) creates a fresh ``RichProgressBarTheme``
            for this instance instead of sharing one default object between all dashboards.
        console_kwargs: Forwarded to ``BaseDashboard``.
    """

    def __init__(self,
                 plot_window: Optional[int] = None,
                 refresh_rate: int = 1,
                 leave: bool = False,
                 theme: Optional[RichProgressBarTheme] = None,
                 console_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        super().__init__(refresh_rate=refresh_rate,
                         leave=leave,
                         theme=theme if theme is not None else RichProgressBarTheme(),
                         console_kwargs=console_kwargs
                         )
        # Keys must match the names of the layout panels defined below.
        self.components = {
            'plot': Plot(window=plot_window),
            'samples': TextSamples()
        }
        self.layout: Layout = Layout(name='root')
        self.layout.split(
            Layout(name='plot', ratio=1),
            Layout(name='progress_bar', size=1),
            Layout(name='samples', size=6),
        )

    def _update(self,
                progress_bar_id: Optional['TaskID'],
                current: int,
                visible: bool = True
                ) -> None:
        """Advance the progress bar, then re-attach the progress and components to their panels."""
        super()._update(
            progress_bar_id=progress_bar_id,
            current=current,
            visible=visible
        )
        # `self.progress` stays None until `_init_progress` creates it, which it only does when the
        # bar is enabled. The layout bookkeeping below needs it, so skip it otherwise.
        if self.progress is None:
            return
        # Make room for multiple progress bars: one row per visible task.
        self.layout['progress_bar'].size = sum(task.visible for task in self.progress.tasks)
        self.layout['progress_bar'].update(self.progress)
        for name, component in self.components.items():
            self.layout[name].update(component)
        self.refresh()
class Plot(JupyterMixin):
    """Terminal line plot (drawn with plotext) of every metric in ``trainer.progress_bar_metrics``.

    Each metric is plotted against ``trainer.global_step``.

    Args:
        window: If given, only the last ``window`` recorded points of each metric are kept. The x-axis
            is then limited to a span of ``window`` global steps that ends at the most recent
            recorded step.
    """

    def __init__(self, window: Optional[int] = None):
        self.window = window
        # metric name -> [global steps, values]; both deques are bounded by `window` when it is set.
        self._metrics = defaultdict(lambda: [
            deque(maxlen=window), deque(maxlen=window)])
        # Turns plotext's ANSI-escaped output into rich text lines.
        self.decoder = AnsiDecoder()

    def __rich_console__(self, console, options):
        # Fill whatever space the enclosing layout panel offers.
        self.width = options.max_width or console.width
        self.height = options.height or console.height
        canvas = self.make_plot()
        self.rich_canvas = Group(*self.decoder.decode(canvas))
        yield self.rich_canvas

    def update(self, trainer) -> None:
        """Record the current value of every progress-bar metric at the current global step."""
        step = trainer.global_step
        for name, value in trainer.progress_bar_metrics.items():
            steps, values = self._metrics[name]
            steps.append(step)
            values.append(value)

    def make_plot(self):
        """Build the plot for the current ``self.width`` x ``self.height`` and return it as a string."""
        plt.clear_data()
        plt.clear_figure()
        for name, (steps, values) in self._metrics.items():
            plt.plot(list(steps), list(values), label=name)
        if not _detect_light_colab_theme():
            plt.theme('dark')
        plt.plotsize(self.width, self.height)
        if self.window is not None:
            latest_step = max(
                (step for steps, _ in self._metrics.values() for step in steps),
                default=0,
            )
            start = max(0, latest_step - self.window)
            plt.xlim(start, start + self.window)
        return plt.build()
class TextSamples(JupyterMixin):
    """Grid of the most recent text samples, one row per sample.

    Samples are read from the first trainer logger that exposes a list/tuple ``samples`` attribute,
    such as ``TextLogger``.
    """

    def __init__(self):
        self.samples = []

    def __rich_console__(self, console, options):
        grid = Table.grid()
        for sample in self.samples:
            grid.add_row(sample)
        yield grid

    def update(self, trainer) -> None:
        """Pull the latest samples. If no logger provides any, the previous samples are kept."""
        for logger in trainer.loggers:
            samples = getattr(logger, "samples", None)
            # Check the type rather than using `hasattr`: Lightning's DummyLogger answers unknown
            # attribute lookups with a no-op callable, which must not be mistaken for samples.
            if isinstance(samples, (list, tuple)):
                self.samples = samples
                return
loggers.py
Since the dashboard displays sample outputs from the model, I additionally define a `TextLogger` class that stores the latest batch of samples so that they can be printed in the dashboard.
from lightning.pytorch.loggers.logger import DummyLogger
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning.fabric.utilities.types import _PATH
class TextLogger(DummyLogger):
    """``DummyLogger`` that additionally keeps the most recently logged text samples.

    The dashboard's text-sample panel reads ``samples`` to display them.

    Args:
        save_dir: Directory used only to derive ``version``.
    """

    def __init__(self, save_dir: _PATH):
        # Run DummyLogger's initializer so the state it sets up (in Lightning, its dummy
        # `experiment`) exists.
        super().__init__()
        # HACK: borrow the version number a FabricCSVLogger rooted at `save_dir` would report,
        # instead of implementing versioning here.
        self._version = FabricCSVLogger(root_dir=save_dir).version
        # Latest samples passed to `log_samples`; empty until the first call.
        self.samples = []

    @property
    def version(self):
        """Version number taken from ``FabricCSVLogger`` at construction time."""
        return self._version

    def log_samples(self, *samples):
        """Replace the stored samples with ``samples`` (the positional arguments, kept as a tuple)."""
        self.samples = samples
Alternatives
No response
Additional context
This is what it looks like in practice. Train and validation losses are continuously updated, and sample outputs are printed. In this case, I was training nanoGPT to remove citations from academic papers, and you can see that the output does not contain "(Plyaskine et al., 1991)".
cc @borda