
Dashboard

Open qbilius opened this issue 6 months ago • 3 comments

Description & Motivation

Lightning incorporates a number of third-party tools to visualize training (e.g., TensorBoard). However, during debugging or manual hyperparameter search, I mostly work in the console and would find it useful to be able to quickly judge whether the code is working as expected (typically whether the loss is going down, and also whether model inputs and outputs appear reasonable).

Pitch

I want to propose including a console-based dashboard that would incorporate user-defined components, such as a progress bar, continuous plotting, or sample outputs from the model. Below is the code I've been using myself for a while now that provides the described functionality. It attempts to build on top of RichProgressBar, so hopefully it wouldn't be too hard to open a pull request if people are interested in this feature.

(For full context, see my repo where this code is actually used for training.)

rich_progress.py

Contains the base class for defining dashboards, BaseDashboard

from threading import RLock
from typing import Any, Callable, Dict, Optional, Union

from lightning_utilities.core.imports import RequirementCache

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar, RichProgressBarTheme, CustomProgress, MetricsTextColumn


GetTimeCallable = Callable[[], float]
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")


if _RICH_AVAILABLE:
    from rich import get_console, reconfigure
    from rich.console import Console
    from rich.layout import Layout
    from rich.live import Live
    from rich.progress import ProgressColumn, Task, TaskID

    class ProgressWithLayout(CustomProgress):
        """
        Overrides ``CustomProgress`` to support passing layout as a renderable.

        Presumably `get_renderable` would provide the same functionality without having to copy-paste the entire `__init__` method, but I cannot get it to work.
        """

        def __init__(
            self,
            *columns: Union[str, ProgressColumn],
            renderable=None,  # new parameter
            console: Optional[Console] = None,
            auto_refresh: bool = True,
            refresh_per_second: float = 10,
            speed_estimate_period: float = 30.0,
            transient: bool = False,
            redirect_stdout: bool = True,
            redirect_stderr: bool = True,
            get_time: Optional[GetTimeCallable] = None,
            disable: bool = False,
            expand: bool = False,
        ) -> None:
            assert refresh_per_second > 0, "refresh_per_second must be > 0"
            self._lock = RLock()
            self.columns = columns or self.get_default_columns()
            self.speed_estimate_period = speed_estimate_period

            self.disable = disable
            self.expand = expand
            self._tasks: Dict[TaskID, Task] = {}
            self._task_index: TaskID = TaskID(0)
            self.live = Live(
                renderable=renderable,  # pass it here
                console=console or get_console(),
                auto_refresh=auto_refresh,
                refresh_per_second=refresh_per_second,
                transient=transient,
                redirect_stdout=redirect_stdout,
                redirect_stderr=redirect_stderr,
            )
            self.get_time = get_time or self.console.get_time
            self.print = self.console.print
            self.log = self.console.log

    class BaseDashboard(RichProgressBar):

        """Create a dashboard with a progress bar with `rich text formatting <https://github.com/Textualize/rich>`_ and any other rich components.

        This is a base class that should be inherited when composing your own dashboard. Your own subclass must define `layout` (`rich.layout.Layout`) that is split into subpanels (with unique names) and a `components` dictionary that maps each component's panel name to an instance of a `rich.jupyter.JupyterMixin` subclass implementing the component's behavior (an `update(trainer)` method and `__rich_console__`).

        Args:
            refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
                Set it to ``0`` to disable the display.
            leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
            theme: Contains styles used to stylize the progress bar.
            console_kwargs: Args for constructing a `Console`

        Raises:
            ModuleNotFoundError:
                If required `rich` package is not installed on the device.

        Note:
            PyCharm users will need to enable “emulate terminal” in output console option in
            run/debug configuration to see styled output.
            Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements

        """

        def __init__(
            self,
            refresh_rate: int = 1,
            leave: bool = False,
            theme: RichProgressBarTheme = RichProgressBarTheme(),
            console_kwargs: Optional[Dict[str, Any]] = None,
        ) -> None:
            super().__init__(refresh_rate=refresh_rate,
                             leave=leave,
                             theme=theme,
                             console_kwargs=console_kwargs
                             )
            self.layout: Optional[Layout] = None
            self.components: Dict[str, Any] = {}

        def _init_progress(self, trainer: "pl.Trainer") -> None:
            if self.is_enabled and (self.progress is None or self._progress_stopped):
                self._reset_progress_bar_ids()
                reconfigure(**self._console_kwargs)
                self._console = get_console()
                self._console.clear_live()
                self._metric_component = MetricsTextColumn(
                    trainer,
                    self.theme.metrics,
                    self.theme.metrics_text_delimiter,
                    self.theme.metrics_format,
                )
                self.progress = ProgressWithLayout(
                    *self.configure_columns(trainer),
                    self._metric_component,
                    auto_refresh=False,
                    disable=self.is_disabled,
                    console=self._console,
                    renderable=self.layout
                )
                self.progress.start()
                # progress has started
                self._progress_stopped = False

        def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
            metrics = self.get_metrics(trainer, pl_module)
            if self._metric_component:
                self._metric_component.update(metrics)

            for component in self.components.values():
                component.update(trainer)
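
To make the component contract concrete, here is a minimal sketch of what `BaseDashboard` expects from each entry in `components`: an `update(trainer)` method, which `_update_metrics` calls, and a `__rich_console__` method that yields renderables. `LastMetric` is a hypothetical component, not part of the proposal, and the trainer is a stand-in object so the sketch runs without Lightning or rich. In a real dashboard the class would presumably subclass `rich.jupyter.JupyterMixin`, as the components in `dashboard.py` do.

```python
from types import SimpleNamespace


class LastMetric:
    """Hypothetical component: shows the latest value of one metric."""

    def __init__(self, name):
        self.name = name
        self.value = None

    def update(self, trainer):
        # Called once per component from BaseDashboard._update_metrics.
        self.value = trainer.progress_bar_metrics.get(self.name, self.value)

    def __rich_console__(self, console, options):
        # rich accepts plain strings as renderables.
        yield f"{self.name}: {self.value}"


# Stand-in for a Trainer; only the attribute used above is provided.
trainer = SimpleNamespace(progress_bar_metrics={"train_loss": 0.42})

component = LastMetric("train_loss")
component.update(trainer)
rendered = list(component.__rich_console__(console=None, options=None))
```

Registering such a component would then amount to adding it to `components` under the same name as a `Layout` panel.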

dashboard.py

Contains a user-defined Dashboard with a progress bar, loss plot, and text samples output. Dashboard is passed to Trainer like any other callback.

from typing import Any, Dict, Optional
from collections import deque, defaultdict

from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme, _detect_light_colab_theme

from rich.ansi import AnsiDecoder
from rich.console import Group
from rich.jupyter import JupyterMixin
from rich.layout import Layout
from rich.progress import TaskID
from rich.table import Table

import plotext as plt

import rich_progress


class Dashboard(rich_progress.BaseDashboard):

    def __init__(self,
                 plot_window: Optional[int] = None,
                 refresh_rate: int = 1,
                 leave: bool = False,
                 theme: RichProgressBarTheme = RichProgressBarTheme(),
                 console_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        super().__init__(refresh_rate=refresh_rate,
                         leave=leave,
                         theme=theme,
                         console_kwargs=console_kwargs
                         )

        self.components = {
            'plot': Plot(window=plot_window),
            'samples': TextSamples()
        }
        self.layout: Layout = Layout(name='root')
        self.layout.split(
            Layout(name='plot', ratio=1),
            Layout(name='progress_bar', size=1),
            Layout(name='samples', size=6),
        )

    def _update(self,
                progress_bar_id: Optional['TaskID'],
                current: int,
                visible: bool = True
                ) -> None:

        super()._update(
            progress_bar_id=progress_bar_id,
            current=current,
            visible=visible
        )

        # make room for multiple progress bars
        self.layout['progress_bar'].size = sum([t.visible for t in self.progress.tasks])
        self.layout['progress_bar'].update(self.progress)

        for name, component in self.components.items():
            self.layout[name].update(component)

        self.refresh()


class Plot(JupyterMixin):

    def __init__(self, window: Optional[int] = None):
        self.window = window

        self._metrics = defaultdict(lambda: [
            deque(maxlen=window), deque(maxlen=window)])
        self.decoder = AnsiDecoder()

    def __rich_console__(self, console, options):
        self.width = options.max_width or console.width
        self.height = options.height or console.height
        canvas = self.make_plot()
        self.rich_canvas = Group(*self.decoder.decode(canvas))
        yield self.rich_canvas

    def update(self, trainer) -> None:
        for name, value in trainer.progress_bar_metrics.items():
            self._metrics[name][0].append(trainer.global_step)
            self._metrics[name][1].append(value)

    def make_plot(self):
        plt.clear_data()
        plt.clear_figure()

        xs = []
        ys = []
        for name, (x, y) in self._metrics.items():
            plt.plot(x, y, label=name)
            xs.extend(list(x))
            if name.startswith('val'):
                ys.extend(list(y))

        if not _detect_light_colab_theme():
            plt.theme('dark')
        plt.plotsize(self.width, self.height)

        if self.window is not None:
            start = max(0, max(xs, default=0) - self.window)
            plt.xlim(start, start + self.window)

        return plt.build()


class TextSamples(JupyterMixin):

    def __init__(self):
        self.samples = []

    def __rich_console__(self, console, options):
        grid = Table.grid()
        for sample in self.samples:
            grid.add_row(sample)
        yield grid

    def update(self, trainer) -> None:
        self.samples = trainer.loggers[0].samples
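
The sliding-window bookkeeping in `Plot` can be tried in isolation. The sketch below reproduces only the `deque`-based storage and the x-axis limit computation from `update` and `make_plot`, using made-up loss values and no plotting, so it runs with the standard library alone.

```python
from collections import defaultdict, deque

window = 3

# One pair of bounded buffers (steps, values) per metric name, as in Plot.
metrics = defaultdict(lambda: [deque(maxlen=window), deque(maxlen=window)])

# Made-up loss values, one per global step.
for step, loss in enumerate([1.0, 0.8, 0.6, 0.5, 0.4]):
    metrics["train_loss"][0].append(step)
    metrics["train_loss"][1].append(loss)

xs = list(metrics["train_loss"][0])
ys = list(metrics["train_loss"][1])

# Same x-axis limits as in make_plot: the last `window` steps.
start = max(0, max(xs, default=0) - window)
x_limits = (start, start + window)
```

Note that `maxlen` bounds the number of stored points per metric, while the axis limits are expressed in global steps. The two line up when a metric is recorded every step; a metric logged less often (validation loss, for example) would keep points that fall outside the visible range.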

loggers.py

Since the dashboard displays sample output from the model, I additionally define a TextLogger class that stores the latest batch of samples, so that it can be printed in the dashboard.

from lightning.pytorch.loggers.logger import DummyLogger
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger

from lightning.fabric.utilities.types import _PATH


class TextLogger(DummyLogger):

    def __init__(self, save_dir: _PATH):
        # Initialize DummyLogger so that its own attributes are set up
        super().__init__()
        # A quick hack to get the version number
        self._version = FabricCSVLogger(root_dir=save_dir).version
        self.samples = []

    @property
    def version(self):
        return self._version

    def log_samples(self, *samples):
        self.samples = samples
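
The hand-off between the logger and the `TextSamples` component can be sketched without Lightning. `SampleStore` below is a stand-in for `TextLogger` with the same `log_samples` behavior, and the trainer is a plain namespace. Where `log_samples` gets called during training is not shown in this issue, so the calls here are an assumption for illustration only.

```python
from types import SimpleNamespace


class SampleStore:
    """Stand-in for TextLogger: keeps only the most recent samples."""

    def __init__(self):
        self.samples = []

    def log_samples(self, *samples):
        # Each call replaces the previous samples rather than appending.
        self.samples = samples


logger = SampleStore()
trainer = SimpleNamespace(loggers=[logger])  # stand-in for a Trainer

# Assumed usage: called with a few strings per batch.
logger.log_samples("first batch output")
logger.log_samples("source text", "model output")

# What TextSamples.update(trainer) would read.
shown = trainer.loggers[0].samples
```

Because `TextSamples.update` reads `trainer.loggers[0]`, this design relies on the text logger being the first logger passed to the Trainer.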

Alternatives

No response

Additional context

This is what it looks like in practice. Train and validation losses are continuously updated and sample outputs are printed (in this case, I was training nanoGPT to remove citations from academic papers, and you can see that the output does not contain (Plyaskine et al., 1991)).

[Screenshot of the dashboard, 2024-08-25]

cc @borda

Posted by qbilius, Aug 25 '24 15:08