TabPFN
Document bar distribution: bimodal / full uncertainty regression outputs, and create an example
Example usage and visualization of the full uncertainty mode:
from tabpfn import TabPFNRegressor
import matplotlib.pyplot as plt
import torch

reg = TabPFNRegressor()
reg.fit(x, y_noisy)
preds = reg.predict(x_test, output_type="full")

fig, ax = plt.subplots(1, figsize=(12, 6))
N = 10  # number of samples to visualize
plot_bar_distribution(ax, torch.tensor(x_test)[:N], preds["criterion"].borders, preds["logits"][:N])
ax.set_ylim(-1, 10)
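For completeness, here is a minimal sketch of how the inputs used above (x, y_noisy, x_test) could be generated; the bimodal noise model is an assumption chosen to showcase the full uncertainty output, not part of the original snippet:

import numpy as np

# Toy regression data with bimodal noise: each x has two plausible y-modes,
# so the predicted bar distribution should show two ridges.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=(200, 1))
mode = rng.integers(0, 2, size=200)  # randomly pick one of the two modes per sample
y_noisy = np.sin(x[:, 0]) + 4.0 * mode + rng.normal(0.0, 0.2, size=200)
x_test = np.linspace(-3, 3, 50).reshape(-1, 1)

The visualization helpers used above (heatmap_with_box_sizes and plot_bar_distribution):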
import matplotlib.patches as patches
import seaborn as sns
import torch
import warnings
from matplotlib.collections import PatchCollection
def get_rect(coord, width, height):
    # coord is the lower-left (x, y) corner of the rectangle.
    return patches.Rectangle(coord, width, height)
def heatmap_with_box_sizes(
    ax,
    data: torch.Tensor,
    x_starts,
    x_ends,
    y_starts,
    y_ends,
    palette=None,
    set_lims=True,
    threshold_i=0.0,  # threshold on the normalized intensity (not a probability)
    y_min=None,
    y_max=None,
    transpose=False,
    per_col_normalize=False,
):
    """
    Draw a heatmap whose cells can have individual widths and heights.

    Beware: all x and y arrays must be sorted from small to large, and the data
    appears in that same order, i.e. small indices map to lower x/y-axis values.
    """
    if palette is None:
        # gamma controls how much of the spectrum is saturated (more gamma -> larger saturated part);
        # dark controls how dark the darkest part is (a higher value makes it lighter).
        palette = sns.cubehelix_palette(
            start=2.9,
            rot=0.0,
            dark=0.6,
            light=1,
            gamma=4.0,
            hue=9.0,
            as_cmap=True,
        )
    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        if y_min is None or y_max is None:
            assert (
                len(y_starts.shape) == 1
            ), "If y_min and y_max are not provided, y_starts should be 1D. Please set y_min and y_max manually."
            ax.set_ylim(y_starts[0], y_ends[-1])
        else:
            ax.set_ylim(y_min, y_max)
    if per_col_normalize:
        data = (data - data.min(0, keepdim=True).values) / (
            data.max(0, keepdim=True).values - data.min(0, keepdim=True).values
        )
    else:
        data = (data - data.min()) / (data.max() - data.min())
    rects, colors = [], []
    assert y_ends.shape == y_starts.shape
    if len(y_starts.shape) == 1:
        y_starts = y_starts.unsqueeze(0).expand(len(x_starts), -1)
        y_ends = y_ends.unsqueeze(0).expand(len(x_starts), -1)
    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(
            zip(y_starts[col_i], y_ends[col_i])
        ):
            intensity = data[row_i, col_i].item()
            # Re-scale so the color range starts at the threshold value (smoother visualization).
            intensity = max(0.0, intensity - threshold_i) / (1 - threshold_i)
            if intensity <= 0:
                continue
            if y_max is not None and y_min is not None and (row_start > y_max or row_end < y_min):
                continue
            if row_start >= row_end or col_start >= col_end:
                continue
            if palette(intensity) == (1.0, 1.0, 1.0, 1.0):
                # Fully white cells would be invisible anyway; skip them.
                continue
            if transpose:
                rects += [
                    get_rect(
                        (row_start, col_start), row_end - row_start, col_end - col_start
                    )
                ]
            else:
                rects += [
                    get_rect(
                        (col_start, row_start), col_end - col_start, row_end - row_start
                    )
                ]
            colors += [palette(intensity)]
    rect_collection = PatchCollection(
        rects, facecolors=colors, edgecolor="none", linewidth=1
    )
    ax.add_collection(rect_collection)
    ax.set_rasterized(True)
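heatmap_with_box_sizes can also be used on its own for heatmaps with non-uniform cell sizes; a minimal sketch with made-up toy data (the edges and values below are purely illustrative):

import matplotlib.pyplot as plt
import torch

fig, ax = plt.subplots(figsize=(6, 4))
x_edges = torch.tensor([0.0, 1.0, 3.0, 6.0])       # 3 columns of varying width
y_edges = torch.tensor([0.0, 0.5, 1.0, 2.0, 4.0])  # 4 rows of varying height
data = torch.rand(4, 3)                            # indexed as data[row_i, col_i]
heatmap_with_box_sizes(
    ax,
    data,
    x_starts=x_edges[:-1],
    x_ends=x_edges[1:],
    y_starts=y_edges[:-1],
    y_ends=y_edges[1:],
)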
def plot_bar_distribution(
    ax,
    x: torch.Tensor,
    bar_borders: torch.Tensor,
    logits: torch.Tensor,
    merge_bars=None,
    restrict_to_range=None,
    plot_log_probs=False,
    **kwargs,
):
    """
    :param ax: A matplotlib axis; you can get one with `fig, ax = pyplot.subplots()`.
    :param x: The positions to plot on the x-axis; must be 1D with shape (num_examples,).
    :param bar_borders: The borders of your bar distribution; they can be obtained at transformer_model.criterion.borders.
    :param logits: A tensor of shape (num_examples, len(bar_borders) - 1) that comes straight out of the model.
    :param merge_bars: Number of bars to merge into one. If None, no merging is done. This speeds up the plotting.
    :param restrict_to_range: A tuple (min_y, max_y) that restricts the y-axis to this range. If None, no restriction is done.
    :param plot_log_probs: If True, plot log probabilities instead of probabilities. This is useful if some probabilities are very high.
    :param kwargs: Passed on to heatmap_with_box_sizes.
    :return:
    """
    assert isinstance(x, torch.Tensor)
    x = x.squeeze()
    predictions = logits.squeeze().softmax(-1)
    assert len(x.shape) == 1
    assert len(predictions.shape) == 2
    assert len(predictions) == len(x)
    assert len(bar_borders.shape) == 1
    assert len(bar_borders) - 1 == predictions.shape[1]
    if merge_bars and merge_bars > 1:
        new_borders_inds = torch.arange(0, len(bar_borders), merge_bars)
        if new_borders_inds[-1] != len(bar_borders) - 1:
            new_borders_inds = torch.cat(
                [new_borders_inds, torch.tensor([len(bar_borders) - 1])]
            )
        bar_borders = bar_borders[new_borders_inds]
        pred_cumsum = torch.cat(
            [torch.zeros(len(predictions), 1), predictions.cumsum(-1)], dim=-1
        )
        predictions = (
            pred_cumsum[:, new_borders_inds[1:]] - pred_cumsum[:, new_borders_inds[:-1]]
        )
        assert len(bar_borders) - 1 == predictions.shape[-1]
    if restrict_to_range is not None:
        min_y, max_y = restrict_to_range
        border_mask = (min_y <= bar_borders) & (bar_borders <= max_y)
        # make the mask itself one border broader
        border_mask[:-1] = border_mask[1:] | border_mask[:-1]
        border_mask[1:] = border_mask[1:] | border_mask[:-1]
        logit_mask = border_mask[:-1] & border_mask[1:]
        bar_borders = bar_borders[border_mask]
        predictions = predictions[:, logit_mask]
    y_starts = bar_borders[:-1]
    y_ends = bar_borders[1:]
    x, order = x.sort(0)
    # Convert bin probabilities to densities by dividing by the bin widths.
    predictions = predictions[order] / (bar_borders[1:] - bar_borders[:-1])
    predictions[torch.isinf(predictions)] = 0.0
    predictions[:, (bar_borders[1:] - bar_borders[:-1]) < 1e-10] = 0.0
    if plot_log_probs:
        predictions = predictions.log()
        predictions[predictions.isinf()] = torch.min(predictions[~predictions.isinf()])
    # assume x is sorted; each example's column spans halfway to its neighbors
    x_starts = torch.cat([x[0].unsqueeze(0), (x[1:] + x[:-1]) / 2])
    x_ends = torch.cat(
        [
            (x[1:] + x[:-1]) / 2,
            x[-1].unsqueeze(0),
        ]
    )
    heatmap_with_box_sizes(
        ax, predictions.T, x_starts, x_ends, y_starts, y_ends, **kwargs
    )
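To illustrate the optional parameters documented above, a hedged example that reuses the objects from the snippet at the top (N, x_test, preds are assumed to be defined as there); merging bars and restricting the range mainly speeds up plotting when there are many borders:

fig, ax = plt.subplots(1, figsize=(12, 6))
plot_bar_distribution(
    ax,
    torch.tensor(x_test)[:N],
    preds["criterion"].borders,
    preds["logits"][:N],
    merge_bars=10,                # merge groups of 10 adjacent bars into one
    restrict_to_range=(-1, 10),   # only draw bars overlapping this y-range
    plot_log_probs=True,          # color by log-probability to reveal low-density regions
)
ax.set_ylim(-1, 10)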
We need to document usage and add the visualization code to our repository (tabpfn-extensions?)