scvi-tools
`SCANVI.loss` fails when `n_samples>1`
Reproducible script:
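```python
import torch
from scvi import REGISTRY_KEYS
from scvi.module import SCANVAE

# Minimal SCANVAE with 100 genes and a single label category
vae = SCANVAE(n_input=100, n_labels=1)

# 10 cells x 100 genes of integer-valued counts, stored as float32
x = torch.randint(0, 100, (10, 100), dtype=torch.float32)
batch = torch.zeros(10, dtype=torch.long)
labels = torch.zeros(10, dtype=torch.long)
tensors = {
    REGISTRY_KEYS.X_KEY: x,
    REGISTRY_KEYS.BATCH_KEY: batch,
    REGISTRY_KEYS.LABELS_KEY: labels,
}

# forward() runs inference and generative steps and, by default, the loss;
# with n_samples=2 the call fails inside the loss computation
_ = vae.forward(
    tensors,
    inference_kwargs={"n_samples": 2},
)
```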
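For context, here is a minimal sketch in plain PyTorch (not scvi-tools code) of what `n_samples` changes. It assumes the module draws latent samples with `Distribution.rsample((n_samples,))`, in which case every sampled tensor gains a leading sample axis, so per-cell tensors such as `labels` (shape `(10,)`) no longer line up with dimension 0. This is an assumption about the likely cause, not a diagnosis of `SCANVAE.loss`; `n_latent` and the flatten-and-repeat step are hypothetical and only illustrate one generic way to realign shapes.

```python
import torch
from torch.distributions import Normal

n_cells, n_latent = 10, 5  # n_latent is a hypothetical value for illustration

# Hypothetical per-cell posterior q(z | x) for 10 cells
qz = Normal(torch.zeros(n_cells, n_latent), torch.ones(n_cells, n_latent))

z_single = qz.rsample()     # one draw per cell: dim 0 indexes cells
z_multi = qz.rsample((2,))  # n_samples=2: dim 0 now indexes samples

labels = torch.zeros(n_cells, dtype=torch.long)

print(z_single.shape)  # torch.Size([10, 5])
print(z_multi.shape)   # torch.Size([2, 10, 5])

# Hypothetical realignment: fold the sample axis into the cell axis and
# tile per-cell tensors in the same order (sample 0 cells, then sample 1 cells)
z_flat = z_multi.reshape(-1, n_latent)  # shape (20, 5)
labels_rep = labels.repeat(2)           # shape (20,)
```

Code that assumes a 2-D `(n_cells, ...)` layout will either raise a shape error or silently treat samples as cells when handed `z_multi` directly, which is consistent with the failure only appearing for `n_samples > 1`.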