scvi-tools
BatchStats KeyError with LayerNorm in JaxPEAKVI
A KeyError: 'batch_stats' is raised during training when Flax LayerNorm is used in place of BatchNorm in the JaxPEAKVI decoder.
Code
import jax.numpy as jnp
import flax.linen as nn

# `Dense` is assumed to be scvi-tools' Dense layer; plain nn.Dense works
# as a stand-in for reproducing the error.
Dense = nn.Dense


class FlaxDecoder(nn.Module):
    n_input: int
    dropout_rate: float
    n_hidden: int

    def setup(self):
        self.dense1 = Dense(self.n_hidden)
        self.dense2 = Dense(self.n_hidden)
        self.dense3 = Dense(self.n_hidden)
        self.dense4 = Dense(self.n_hidden)
        self.dense5 = Dense(self.n_input)
        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()
        self.dropout1 = nn.Dropout(self.dropout_rate)
        self.dropout2 = nn.Dropout(self.dropout_rate)

    def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: bool = False):
        is_eval = not training
        h = self.dense1(z)
        h += self.dense2(batch)
        h = self.layernorm1(h)
        h = nn.leaky_relu(h)
        h = self.dropout1(h, deterministic=is_eval)
        h = self.dense3(h)
        # skip connection: inject the batch covariate again
        h += self.dense4(batch)
        h = self.layernorm2(h)
        h = nn.leaky_relu(h)
        h = self.dropout2(h, deterministic=is_eval)
        h = self.dense5(h)
        h = nn.sigmoid(h)
        return h
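For context (not part of the original report): Flax's nn.BatchNorm stores its running mean and variance in a separate "batch_stats" variable collection, while nn.LayerNorm has no running statistics and only creates entries in "params". The on_train_start callback in scvi/train/_callbacks.py indexes module_init["batch_stats"] unconditionally, so a module with no BatchNorm layers fails. A minimal sketch demonstrating the difference in the init variables:

import jax
import jax.numpy as jnp
import flax.linen as nn

class LayerNormBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.LayerNorm()(x)

class BatchNormBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        # use_running_average=False so init creates the running statistics
        return nn.BatchNorm(use_running_average=False)(x)

x = jnp.ones((4, 8))
print(LayerNormBlock().init(jax.random.PRNGKey(0), x).keys())
# -> only 'params'; no 'batch_stats' collection is created
print(BatchNormBlock().init(jax.random.PRNGKey(0), x).keys())
# -> 'params' and 'batch_stats'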
Error
_______________________________________________________ test_jax_peakvi _______________________________________________________

    def test_jax_peakvi():
        n_latent = 5
        adata = synthetic_iid()
        JaxPEAKVI.setup_anndata(
            adata,
            batch_key="batch",
        )
        model = JaxPEAKVI(adata, n_latent=n_latent)
>       model.train(2, train_size=0.5, check_val_every_n_epoch=1)

tests/models/test_models.py:101:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scvi/model/base/_jaxmixin.py:97: in train
    runner()
scvi/train/_trainrunner.py:74: in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
scvi/train/_trainer.py:188: in fit
    super().fit(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:770: in fit
    self._call_and_handle_interrupt(
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:723: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:811: in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1236: in _run
    results = self._run_stage()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1323: in _run_stage
    return self._run_train()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1353: in _run_train
    self.fit_loop.run()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:199: in run
    self.on_run_start(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:217: in on_run_start
    self.trainer._call_callback_hooks("on_train_start")
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1636: in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
scvi/train/_callbacks.py:164: in on_train_start
    batch_stats = module_init["batch_stats"]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = FrozenDict({
    params: {
        encoder: {
            dense1: {
                kernel: DeviceArray([[-0.01504633,... 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            },
        },
    },
})
key = 'batch_stats'

    def __getitem__(self, key):
>       v = self._dict[key]
E       KeyError: 'batch_stats'
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/flax/core/frozen_dict.py:66: KeyError
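One possible workaround, shown as a minimal sketch rather than the maintainers' actual patch: treat "batch_stats" as optional at the failing lookup, under the assumption that downstream code can tolerate missing running statistics. FrozenDict implements the Mapping protocol, so a guarded .get() avoids the KeyError:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        # LayerNorm only, mirroring the decoder above: no 'batch_stats'
        return nn.LayerNorm()(nn.Dense(8)(x))

module_init = Model().init(jax.random.PRNGKey(0), jnp.ones((4, 8)))

# The failing line indexes the collection unconditionally:
#     batch_stats = module_init["batch_stats"]  # KeyError with LayerNorm only
# A guarded lookup instead:
batch_stats = module_init.get("batch_stats", None)
print(batch_stats)  # None -- callers must then treat running stats as optional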