scvi-tools icon indicating copy to clipboard operation
scvi-tools copied to clipboard

BatchStats KeyError with LayerNorm in JaxPEAKVI

Open RK900 opened this issue 3 years ago • 0 comments
trafficstars

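Training JaxPEAKVI raises `KeyError: 'batch_stats'` when Flax `LayerNorm` is used in the decoder in place of `BatchNorm`. The failure occurs in the `on_train_start` callback (`scvi/train/_callbacks.py:164`), which indexes `module_init["batch_stats"]`, a key that is absent from the initialized variables shown in the traceback below (they contain only `params`).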

Code

class FlaxDecoder(nn.Module):
    """Decoder that maps a latent code and a batch input to values in (0, 1).

    Two hidden stages are applied in sequence. Each stage adds a dense
    projection of ``batch`` to the running activation, then applies
    LayerNorm, leaky ReLU and dropout. A final dense layer of width
    ``n_input`` followed by a sigmoid produces the output.
    """

    # Width of the final dense layer (size of the decoder output).
    n_input: int
    # Rate passed to both dropout layers.
    dropout_rate: float
    # Width of every hidden dense layer.
    n_hidden: int

    def setup(self):
        """Create the dense, normalization and dropout submodules."""
        hidden_width = self.n_hidden

        # dense1/dense3 act on the main path; dense2/dense4 project `batch`.
        self.dense1 = Dense(hidden_width)
        self.dense2 = Dense(hidden_width)
        self.dense3 = Dense(hidden_width)
        self.dense4 = Dense(hidden_width)
        # Output projection back to `n_input` units.
        self.dense5 = Dense(self.n_input)

        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()
        self.dropout1 = nn.Dropout(self.dropout_rate)
        self.dropout2 = nn.Dropout(self.dropout_rate)

    def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: bool = False):
        """Run the decoder.

        Parameters
        ----------
        z
            Latent input fed to the first dense layer.
        batch
            Batch input; projected and added to the activation in both
            hidden stages (presumably an encoded batch covariate -- confirm
            against the caller).
        training
            When ``False`` (the default) dropout is deterministic, i.e.
            disabled.

        Returns
        -------
        The sigmoid of the final dense layer's output.
        """
        deterministic = not training

        # Stage 1: combine the latent and batch projections.
        hidden = self.dense1(z)
        hidden = hidden + self.dense2(batch)
        hidden = self.dropout1(
            nn.leaky_relu(self.layernorm1(hidden)),
            deterministic=deterministic,
        )

        # Stage 2: the batch projection is injected again (skip connection).
        hidden = self.dense3(hidden)
        hidden = hidden + self.dense4(batch)
        hidden = self.dropout2(
            nn.leaky_relu(self.layernorm2(hidden)),
            deterministic=deterministic,
        )

        # Squash the output projection into (0, 1).
        return nn.sigmoid(self.dense5(hidden))

Error

_______________________________________________________ test_jax_peakvi _______________________________________________________

    def test_jax_peakvi():
        n_latent = 5

        adata = synthetic_iid()
        JaxPEAKVI.setup_anndata(
            adata,
            batch_key="batch",
        )

        model = JaxPEAKVI(adata, n_latent=n_latent)
>       model.train(2, train_size=0.5, check_val_every_n_epoch=1)

tests/models/test_models.py:101:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scvi/model/base/_jaxmixin.py:97: in train
    runner()
scvi/train/_trainrunner.py:74: in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
scvi/train/_trainer.py:188: in fit
    super().fit(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:770: in fit
    self._call_and_handle_interrupt(
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:723: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:811: in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1236: in _run
    results = self._run_stage()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1323: in _run_stage
    return self._run_train()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1353: in _run_train
    self.fit_loop.run()
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:199: in run
    self.on_run_start(*args, **kwargs)
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:217: in on_run_start
    self.trainer._call_callback_hooks("on_train_start")
../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1636: in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
scvi/train/_callbacks.py:164: in on_train_start
    batch_stats = module_init["batch_stats"]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = FrozenDict({
    params: {
        encoder: {
            dense1: {
                kernel: DeviceArray([[-0.01504633,...                           0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
            },
        },
    },
})
key = 'batch_stats'

    def __getitem__(self, key):
>     v = self._dict[key]
E     KeyError: 'batch_stats'

../../../anaconda3/envs/scvi_test1/lib/python3.8/site-packages/flax/core/frozen_dict.py:66: KeyError

RK900 avatar Jul 19 '22 23:07 RK900