
Labeled Samples object for less verbose plotting

Open janfb opened this issue 4 months ago • 0 comments

Context

  • pairplot refactoring in #1631 introduces typed options and upper/lower/diag API; offdiag is deprecated.
  • Pain point remains: users still pass many repeated args (labels, limits, ticks) in tutorials (see docs/advanced_tutorials/17_plotting_functionality.ipynb).
  • Proposal below reduces boilerplate and error-prone dimension bookkeeping by encapsulating per-dataset metadata.

Proposal

LabeledSamples dataclass (or NamedTuple) in sbi.analysis

  • Fields:
    • data: np.ndarray | torch.Tensor, shape (N, D)
    • dim_labels: list[str] length D (optional; defaults to ["θ1", "θ2", ...])
    • Optional: name: Optional[str], ticks: Optional[List[Tuple[float, float]]], limits: Optional[List[Tuple[float, float]]]
  • Contract:
    • pairplot and marginal_plot accept either raw arrays or LabeledSamples. When provided:
      • Use the container's dim_labels if the labels arg is not supplied.
      • Use the container's limits if the limits arg is not supplied.
      • Use the container's ticks if the ticks arg is not supplied.
    • Support a list of LabeledSamples to overlay multiple sources; FigOptions.samples_labels defaults to each container's .name if set, otherwise to generated labels.
  • Interop:
    • A from_xarray(dataarray) constructor could leverage xarray’s named dimensions and coordinates (optional; follow-up).
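
The legend-label default for overlaid sources could look roughly like the sketch below. The helper name default_samples_labels and the "samples i" fallback format are assumptions made for illustration and are not part of sbi; the dataclass is a reduced stand-in for the proposed container.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence


@dataclass(frozen=True)
class LabeledSamples:
    """Stand-in for the proposed container; only the fields used here."""

    data: Any  # array-like of shape (N, D)
    name: Optional[str] = None


def default_samples_labels(samples: Sequence[Any]) -> List[str]:
    """Hypothetical helper: one legend label per overlaid sample set.

    Uses the container's `.name` when it is set. Raw arrays and unnamed
    containers get a generated label (the format is an assumption).
    """
    labels = []
    for i, s in enumerate(samples):
        name = getattr(s, "name", None)  # raw arrays/lists have no .name
        labels.append(name if name is not None else f"samples {i + 1}")
    return labels


posterior = LabeledSamples(data=[[0.1, 0.2]], name="posterior")
prior = LabeledSamples(data=[[0.5, 0.5]])  # no name given
raw = [[0.3, 0.4]]  # plain nested list standing in for a raw array
print(default_samples_labels([posterior, prior, raw]))
```

Generated labels are numbered by position in the list, so mixing raw arrays and containers still yields stable, distinct entries.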

Migration

  • Backward compatible: continue supporting raw arrays; adopt container to reduce passing repeated labels/ticks in every call.
  • Precedence: explicit function args override container metadata.
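
The precedence rule can be sketched as a small resolver. resolve_option is a hypothetical helper name, not an existing sbi function; the sketch only illustrates the intended order: explicit argument, then container metadata, then a default.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def resolve_option(
    explicit: Optional[T],
    from_container: Optional[T],
    default: Optional[T] = None,
) -> Optional[T]:
    """Hypothetical resolver: explicit arg > container field > default."""
    # Compare against None rather than truthiness, so that an explicitly
    # passed empty list still counts as "supplied".
    if explicit is not None:
        return explicit
    if from_container is not None:
        return from_container
    return default


container_labels = ["mu", "sigma"]
print(resolve_option(None, container_labels))            # container value used
print(resolve_option(["a", "b"], container_labels))      # explicit arg wins
print(resolve_option(None, None, default=["θ1", "θ2"]))  # falls back to default
```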

Benefits

  • Encapsulates dimension labeling and axis metadata in one place.
  • Simplifies notebook and pipeline code; fewer chances for label/order mismatches.
  • Autocomplete-friendly, aligns with stronger typing direction from this branch.

Minimal sketch

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch

ArrayLike = Union[np.ndarray, torch.Tensor]

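@dataclass(frozen=True)
class LabeledSamples:
    # Required field comes first so it cannot silently default to None.
    data: ArrayLike  # shape (N, D)
    name: Optional[str] = None  # legend label when overlaying datasets
    dim_labels: Optional[List[str]] = None  # length D
    limits: Optional[List[Tuple[float, float]]] = None  # one (low, high) per dimension
    ticks: Optional[List[Tuple[float, float]]] = None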

# prepare_for_plot() detects LabeledSamples, unpacks metadata,
# and prefers explicit function args over container fields.
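
As a rough illustration of the unpacking step described in the comment above, the following sketch accepts either a raw array or a container and returns the data together with resolved dimension labels. The function name unpack_samples is hypothetical and this is not sbi's actual prepare_for_plot; torch tensors would need an explicit conversion at this boundary, which is left out here.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import numpy as np


@dataclass(frozen=True)
class LabeledSamples:
    """Stand-in for the proposed container (subset of fields)."""

    data: Any  # array-like, shape (N, D)
    dim_labels: Optional[List[str]] = None


def unpack_samples(
    samples: Any, labels: Optional[List[str]] = None
) -> Tuple[np.ndarray, List[str]]:
    """Hypothetical unpacking step for raw arrays or LabeledSamples."""
    container_labels = None
    if isinstance(samples, LabeledSamples):
        container_labels = samples.dim_labels
        samples = samples.data

    data = np.asarray(samples)
    if data.ndim != 2:
        raise ValueError(f"expected shape (N, D), got {data.shape}")
    dim = data.shape[1]

    # Explicit argument first, then container metadata, then generated labels.
    if labels is None:
        labels = container_labels
    if labels is None:
        labels = [f"θ{i + 1}" for i in range(dim)]
    if len(labels) != dim:
        raise ValueError(f"got {len(labels)} labels for {dim} dimensions")
    return data, list(labels)


theta = LabeledSamples(data=np.zeros((100, 2)), dim_labels=["mu", "sigma"])
print(unpack_samples(theta)[1])                       # container labels
print(unpack_samples(theta, labels=["a", "b"])[1])    # explicit labels win
print(unpack_samples(np.zeros((100, 3)))[1])          # generated defaults
```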

More context:

  • Open questions

    • Defaults for samples-level labels/colors when mixing raw arrays and LabeledSamples.
    • Whether limits are per-sample-set or global; if mixed, choose intersection/union or prefer first.
    • Validation surface: check D consistency across overlaid datasets and metadata lengths.
    • Immutability (frozen=True) and torch/numpy conversion policy at boundaries.
    • Export location and public API: sbi.analysis export for discoverability.
  • Impact on docs/tests

    • Update 17_plotting_functionality.ipynb to show both raw arrays and LabeledSamples flows.
    • Add smoke tests for container + precedence rules; validation tests for dim_labels length and limits shape.
    • Re-export LabeledSamples from sbi.analysis in __init__.py.
  • Suggested steps (one or several PRs depending on size)

    • Introduce LabeledSamples, wire into prepare_for_plot, add precedence + validation.
    • Update tutorials and add unit tests; add from_xarray helper (optional).
    • Consider deprecation notice in docs encouraging container for repeated plotting scenarios.
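
For the open question on per-sample-set versus global limits, the three candidate policies (union, intersection, prefer first) can be compared with a small sketch. merge_limits and its mode names are hypothetical and only meant to make the trade-off concrete; the length check also illustrates the D-consistency validation mentioned above.

```python
from typing import List, Sequence, Tuple

Limits = List[Tuple[float, float]]


def merge_limits(per_set: Sequence[Limits], mode: str = "union") -> Limits:
    """Hypothetical helper: combine per-dataset limits into global ones."""
    if not per_set:
        raise ValueError("need at least one set of limits")
    n_dims = len(per_set[0])
    # All overlaid datasets must describe the same number of dimensions.
    if any(len(lims) != n_dims for lims in per_set):
        raise ValueError("all limits must have the same number of dimensions")

    merged: Limits = []
    for d in range(n_dims):
        lows = [lims[d][0] for lims in per_set]
        highs = [lims[d][1] for lims in per_set]
        if mode == "union":  # widest range covering every dataset
            low, high = min(lows), max(highs)
        elif mode == "intersection":  # range shared by all datasets
            low, high = max(lows), min(highs)
            if low >= high:
                raise ValueError(f"empty intersection in dimension {d}")
        elif mode == "first":  # first dataset decides
            low, high = per_set[0][d]
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        merged.append((low, high))
    return merged


a = [(0.0, 1.0), (-1.0, 1.0)]
b = [(0.5, 2.0), (-2.0, 0.0)]
print(merge_limits([a, b], mode="union"))
print(merge_limits([a, b], mode="intersection"))
print(merge_limits([a, b], mode="first"))
```

Union never clips any dataset, while intersection can fail when ranges do not overlap, which is one argument for union or "first" as the default.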

janfb • Sep 18 '25 06:09