TabPFN icon indicating copy to clipboard operation
TabPFN copied to clipboard

Document bar distribution (bimodal / full-uncertainty regression outputs) and create an example

Open noahho opened this issue 11 months ago • 0 comments

Example usage and visualization of the full uncertainty mode:

# Fit the regressor and request the "full" output, which is indexed below for
# the bar-distribution logits and for the criterion object holding the borders.
reg = TabPFNRegressor()
reg.fit(x, y_noisy)
preds = reg.predict(x_test, output_type="full")

fig, ax = plt.subplots(1, figsize=(12, 6))

N = 10  # number of samples to visualize

# The logits are predictions for x_test, so the x-axis positions must be taken
# from x_test as well (not from the training inputs x); otherwise each
# predicted distribution is drawn at an unrelated x position.
# NOTE(review): assumes x_test has a single feature, because
# plot_bar_distribution requires x to be 1-d after squeezing -- confirm.
plot_bar_distribution(
    ax,
    torch.tensor(x_test)[0:N],
    preds["criterion"].borders,
    preds["logits"][0:N],
)
ax.set_ylim(-1, 10)
import matplotlib.patches as patches
import seaborn as sns
import torch
import warnings
from matplotlib.collections import PatchCollection


def get_rect(coord, height, width):
    """Build a matplotlib ``Rectangle`` patch anchored at ``coord``.

    The second and third arguments are forwarded positionally, in the order
    given, to ``patches.Rectangle``.

    NOTE(review): matplotlib documents that constructor as
    ``Rectangle(xy, width, height)``, so the argument named ``height`` here
    becomes the rectangle's horizontal extent and ``width`` its vertical
    extent. Pass the horizontal extent first when calling positionally.

    :param coord: ``(x, y)`` anchor point of the rectangle.
    :param height: Forwarded as the second positional argument of ``Rectangle``.
    :param width: Forwarded as the third positional argument of ``Rectangle``.
    :return: The newly created ``patches.Rectangle``.
    """
    return patches.Rectangle(coord, height, width)


def heatmap_with_box_sizes(
    ax,
    data: torch.Tensor,
    x_starts,
    x_ends,
    y_starts,
    y_ends,
    palette=None,
    set_lims=True,
    threshold_i=0.0,  # Threshold intensity (not probability)
    y_min=None,
    y_max=None,
    transpose=False,
    per_col_normalize=False,
):
    """
    Draw a heatmap on `ax` whose cells are rectangles of individually given extent.

    Beware: all x and y arrays should be sorted from small to large, and the
    data will appear in that same order (small indexes map to lower x/y-axis
    values).

    :param ax: Matplotlib axis that receives the rectangles.
    :param data: 2-d tensor read as `data[row_i, col_i]`, where `row_i` runs
        over the y intervals and `col_i` over the x intervals. It is min-max
        normalized before plotting.
    :param x_starts: Lower x border of every column.
    :param x_ends: Upper x border of every column.
    :param y_starts: Lower y border of every row; either 1-d (shared by all
        columns) or 2-d and indexed as `y_starts[col_i]`.
    :param y_ends: Upper y border of every row; same shape as `y_starts`.
    :param palette: Callable mapping an intensity to an RGBA tuple. Defaults
        to a seaborn cubehelix colormap.
    :param set_lims: If True, set the axis limits from the borders (or from
        `y_min`/`y_max` when both are given).
    :param threshold_i: Normalized intensities at or below this value are not
        drawn; the remaining ones are rescaled by `1 - threshold_i`.
    :param y_min: Optional lower y limit. Cells entirely outside
        `[y_min, y_max]` are skipped when both limits are given.
    :param y_max: Optional upper y limit, see `y_min`.
    :param transpose: If True, x and y are swapped when building the
        rectangles. The axis limits set via `set_lims` are not swapped.
    :param per_col_normalize: If True, normalize along dimension 0 of `data`
        (i.e. separately for every column) instead of over the whole tensor.
    :return: None; a `PatchCollection` is added to `ax` and the axis is
        marked as rasterized.
    """
    if palette is None:
        palette = sns.cubehelix_palette(
            start=2.9,
            rot=0.0,
            dark=0.6,
            light=1,
            gamma=4.0,
            hue=9.0,
            as_cmap=True
            # use gamma to control how much of the spectrum is saturated, more gamma -> bigger part that is saturated
            # use dark to control how dark the darkest part is, a higher value will make the darkest part lighter
        )

    # Compare against None explicitly: 0 is a legitimate axis limit and must
    # not be mistaken for "limit not provided".
    has_y_range = y_min is not None and y_max is not None

    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        if has_y_range:
            ax.set_ylim(y_min, y_max)
        else:
            assert (
                len(y_starts.shape) == 1
            ), "If y_min and y_max are not provided, y_starts should be 1D. Please set y_min and y_max manually."
            ax.set_ylim(y_starts[0], y_ends[-1])

    # Min-max normalize the intensities, per column or globally.
    if per_col_normalize:
        col_min = data.min(0, keepdim=True).values
        col_max = data.max(0, keepdim=True).values
        data = (data - col_min) / (col_max - col_min)
    else:
        data = (data - data.min()) / (data.max() - data.min())
    rects, colors = [], []

    assert y_ends.shape == y_starts.shape
    if len(y_starts.shape) == 1:
        # Share the same y borders across all columns.
        y_starts = y_starts.unsqueeze(0).expand(len(x_starts), -1)
        y_ends = y_ends.unsqueeze(0).expand(len(x_starts), -1)

    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(
            zip(y_starts[col_i], y_ends[col_i])
        ):
            intensity = data[row_i, col_i].item()
            intensity = max(0.0, (intensity - threshold_i)) / (
                1 - threshold_i
            )  # Start with intensity at the threshold value (smoother visualization)

            if intensity <= 0:
                continue

            # Skip cells that lie completely outside the requested y range.
            if has_y_range and (row_start > y_max or row_end < y_min):
                continue

            # Skip degenerate (empty or inverted) cells.
            if row_start >= row_end or col_start >= col_end:
                continue

            color = palette(intensity)
            # Pure white cells would be invisible, so they are not drawn.
            if color == (1.0, 1.0, 1.0, 1.0):
                continue

            if transpose:
                rects.append(
                    get_rect(
                        (row_start, col_start), row_end - row_start, col_end - col_start
                    )
                )
            else:
                rects.append(
                    get_rect(
                        (col_start, row_start), col_end - col_start, row_end - row_start
                    )
                )
            colors.append(color)
    rect_collection = PatchCollection(
        rects, facecolors=colors, edgecolor="none", linewidth=1
    )
    ax.add_collection(rect_collection)
    ax.set_rasterized(True)


def plot_bar_distribution(
    ax,
    x: torch.Tensor,
    bar_borders: torch.Tensor,
    logits: torch.Tensor,
    merge_bars=None,
    restrict_to_range=None,
    plot_log_probs=False,
    **kwargs,
):
    """
    Plot a bar distribution per x position as a density heatmap on `ax`.

    :param ax: A matplotlib axis, you can get one with: `fig, ax = pyplot.subplots()`
    :param x: The positions to plot on the x-axis, this is your x, but it has to be 1d with shape (num_examples,)
    :param bar_borders: The borders of your bar distribution, they can be obtained at transformer_model.criterion.borders
    :param logits: A tensor of shape (num_examples, len(bar_borders)-1) that comes straight out of the model
    :param merge_bars: Number of bars to merge into one. If None, no merging is done. This speeds up the plotting.
    :param restrict_to_range: A tuple of (min_y, max_y) that restricts the y-axis to this range. If None, no restriction is done.
    :param plot_log_probs: If True, the log probabilities are plotted instead of the probabilities. This is useful if some probabilities are really high.
    :param kwargs: Forwarded unchanged to `heatmap_with_box_sizes`.
    :return: None; the heatmap is drawn onto `ax`.
    """
    # Check the type before any tensor method is called on x.
    assert isinstance(x, torch.Tensor)
    x = x.squeeze()
    predictions = logits.squeeze().softmax(-1)
    assert len(x.shape) == 1
    assert len(predictions.shape) == 2
    assert len(predictions) == len(x)
    assert len(bar_borders.shape) == 1
    assert len(bar_borders) - 1 == predictions.shape[1]

    if merge_bars and merge_bars > 1:
        # Keep every `merge_bars`-th border and always keep the last one.
        new_borders_inds = torch.arange(0, len(bar_borders), merge_bars)
        if new_borders_inds[-1] != len(bar_borders) - 1:
            new_borders_inds = torch.cat(
                [new_borders_inds, torch.tensor([len(bar_borders) - 1])]
            )
        bar_borders = bar_borders[new_borders_inds]
        # new_zeros keeps the padding column on the same device and dtype as
        # the predictions, so the concatenation also works off-CPU.
        pred_cumsum = torch.cat(
            [predictions.new_zeros(len(predictions), 1), predictions.cumsum(-1)],
            dim=-1,
        )

        # Mass of a merged bar = difference of the cumulative mass at its borders.
        predictions = (
            pred_cumsum[:, new_borders_inds[1:]] - pred_cumsum[:, new_borders_inds[:-1]]
        )
        assert len(bar_borders) - 1 == predictions.shape[-1]

    if restrict_to_range is not None:
        min_y, max_y = restrict_to_range
        border_mask = (min_y <= bar_borders) & (bar_borders <= max_y)
        # make the mask itself one border broader
        border_mask[:-1] = border_mask[1:] | border_mask[:-1]
        border_mask[1:] = border_mask[1:] | border_mask[:-1]
        # A bar is kept only if both of its borders are kept.
        logit_mask = border_mask[:-1] & border_mask[1:]
        bar_borders = bar_borders[border_mask]
        predictions = predictions[:, logit_mask]

    y_starts = bar_borders[:-1]
    y_ends = bar_borders[1:]
    bar_widths = y_ends - y_starts

    x, order = x.sort(0)

    # Divide the mass of every bar by its width to obtain a density; bars of
    # (numerically) zero width are set to zero instead of inf/nan.
    predictions = predictions[order] / bar_widths
    predictions[torch.isinf(predictions)] = 0.0
    predictions[:, bar_widths < 1e-10] = 0.0

    if plot_log_probs:
        predictions = predictions.log()
        # log(0) = -inf: replace infinite entries by the smallest finite value.
        predictions[predictions.isinf()] = torch.min(predictions[~predictions.isinf()])

    # x is sorted at this point; every column spans from the midpoint to its
    # left neighbour to the midpoint to its right neighbour.
    midpoints = (x[1:] + x[:-1]) / 2
    x_starts = torch.cat([x[0].unsqueeze(0), midpoints])
    x_ends = torch.cat([midpoints, x[-1].unsqueeze(0)])

    heatmap_with_box_sizes(
        ax, predictions.T, x_starts, x_ends, y_starts, y_ends, **kwargs
    )

We need to document usage and add the visualization code to our repository (tabpfn-extensions?)

noahho avatar Jan 12 '25 11:01 noahho