Mava
[BUG][Jax] Default executor flattens observations
Describe the bug
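The default executor flattens observations before passing them to the neural networks. This is undesirable for environments whose observations have spatial structure (for example grid or image observations), because the networks can no longer exploit that structure.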
To Reproduce
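Inside `action_selection.py`, the executor does:
observation = executor.store.observation.observation.reshape((1, -1))
This merges every observation dimension into a single feature axis, producing shape `(1, N)`.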
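A minimal JAX sketch of what that reshape does to an observation. The 8x8, 3-channel grid is a hypothetical example, not taken from any particular Mava environment:

```python
import jax.numpy as jnp

# Hypothetical 8x8 grid observation with 3 channels, laid out as (height, width, channels).
obs = jnp.arange(8 * 8 * 3, dtype=jnp.float32).reshape(8, 8, 3)

# The default executor's reshape: every dimension is merged into one feature axis.
flat = obs.reshape((1, -1))
print(flat.shape)  # (1, 192)

# Cells that are vertical neighbours in the grid, (0, 0) and (1, 0), now sit
# 8 * 3 = 24 positions apart in the flat vector; the grid layout is only
# recoverable if the network is told the original shape.
print(flat[0, 24] == obs[1, 0, 0])  # True
```

A convolutional network receiving `flat` would first have to undo the reshape, which requires knowing the original shape.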
Expected behavior
Observations should reach the neural networks with their original shape intact, apart from a leading batch dimension.
Possible Solution
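Replace the reshape with a call that only prepends a batch dimension and leaves the remaining shape untouched:
observation = utils.add_batch_dim(executor.store.observation.observation)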
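A sketch of the behaviour the proposed fix relies on. The local `add_batch_dim` below is a hypothetical stand-in, assumed to behave like the `utils.add_batch_dim` helper referenced above (prepend an axis of size 1 to every array in a possibly nested observation). The convolution kernel is likewise an illustrative assumption, included to show that the batched observation can go straight into a spatial layer:

```python
import jax
import jax.numpy as jnp


def add_batch_dim(values):
    """Hypothetical stand-in: prepend a size-1 batch axis to every leaf array."""
    return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), values)


# Hypothetical 8x8 grid observation with 3 channels (height, width, channels).
obs = jnp.zeros((8, 8, 3))
batched = add_batch_dim(obs)
print(batched.shape)  # (1, 8, 8, 3)

# Spatial structure is preserved, so a convolution can consume it directly.
# Illustrative 3x3 kernel: 3 input channels, 16 output channels (HWIO layout).
kernel = jnp.ones((3, 3, 3, 16))
out = jax.lax.conv_general_dilated(
    batched,
    kernel,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(out.shape)  # (1, 8, 8, 16)

# Nested observations keep each leaf's shape, with only the batch axis added.
nested = add_batch_dim({"grid": jnp.zeros((8, 8, 3)), "agent_id": jnp.zeros((4,))})
print(nested["grid"].shape, nested["agent_id"].shape)  # (1, 8, 8, 3) (1, 4)
```

Because only a leading axis is added, networks that expect flat vectors still work if they flatten internally, while spatial networks receive the layout they need.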