`SCANVI.loss` fails when `n_samples>1`

Open martinkim0 opened this issue 1 year ago • 0 comments

Reproducible script:

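import torch
from scvi import REGISTRY_KEYS
from scvi.module import SCANVAE

# Minimal SCANVI module: 100 input genes, a single label category
vae = SCANVAE(n_input=100, n_labels=1)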
x = torch.randint(0, 100, (10, 100), dtype=torch.float32)
batch = torch.zeros(10, dtype=torch.long)
labels = torch.zeros(10, dtype=torch.long)
tensors = {
    REGISTRY_KEYS.X_KEY: x,
    REGISTRY_KEYS.BATCH_KEY: batch,
    REGISTRY_KEYS.LABELS_KEY: labels,
}

_ = vae.forward(
    tensors,
    inference_kwargs={"n_samples": 2}
)

martinkim0, Jan 23 '24 22:01