PGAME replay buffer addition fails for large batch sizes
Hello :)
The current implementation of the PGAME emitter fails for large batch sizes (for example 32768 here) with the following error:
```
Traceback (most recent call last):
  File "main.py", line 415, in <module>
    repertoire, emitter_state, random_key = map_elites.init(
  File "/git/exp/QDax/qdax/core/map_elites.py", line 80, in init
    emitter_state = self._emitter.state_update(
  File "/git/exp/QDax/qdax/core/emitters/pga_me_emitter.py", line 262, in state_update
    replay_buffer = emitter_state.replay_buffer.insert(transitions)
  File "/git/exp/QDax/qdax/core/neuroevolution/buffers/buffer.py", line 351, in insert
    new_data = jax.lax.dynamic_update_slice_in_dim(
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (3276800, 53) for operand shape (1000000, 53).
```
Cause: the error comes from the replay buffer addition. When the batch size is too large, the emitter tries to insert more transitions than the buffer can hold (here 3,276,800 transitions into a buffer of maximum size 1,000,000).
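For reference, the failure can be reproduced outside of QDax with a scaled-down snippet. The sizes and variable names below are hypothetical; only the `jax.lax.dynamic_update_slice_in_dim` call corresponds to the one reported in the traceback:

```python
import jax
import jax.numpy as jnp

# Hypothetical, scaled-down sizes mirroring the shapes in the traceback
# (update of shape (3276800, 53) vs operand of shape (1000000, 53)).
buffer_data = jnp.zeros((100, 5))     # stands in for the buffer storage
new_transitions = jnp.ones((320, 5))  # more transitions than the buffer can hold

# Fails because the update slice must fit entirely inside the operand:
# TypeError: dynamic_update_slice update shape must be smaller than operand shape
jax.lax.dynamic_update_slice_in_dim(buffer_data, new_transitions, 0, axis=0)
```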
Possible fixes: here are some of the fixes I could think of:
- Increase the buffer size proportionally to the batch size. However, this might have unpredictable side effects on the convergence of PGAME.
- Select only the first buffer-size transitions for addition. This would work as a quick fix. However, since transitions are ordered per actor, it would not guarantee that the inserted transitions are uncorrelated, and might impact convergence, especially for small buffer sizes.
- Randomly sample buffer-size transitions for addition. This seems to be the best choice of this list so far. However, the current `state_update` method of the `PGEmitter` and the current `insert` method of the `ReplayBuffer` do not take any random key as input. One solution might be to store this key in the emitter state, but it would probably be more relevant to do this transition selection inside the `ReplayBuffer` class rather than inside the `PGEmitter` class (a rough sketch of this option is given after this list).
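To make the third option concrete, here is a minimal sketch. The helper name `subsample_transitions` and its call site are assumptions, not part of the current QDax API, and in QDax the transitions would be a pytree rather than a plain array (hence the `tree_map`):

```python
import jax
import jax.numpy as jnp


def subsample_transitions(random_key, transitions, max_insert_size):
    """Hypothetical helper: keep at most max_insert_size transitions, picked
    uniformly at random without replacement, to break the per-actor ordering
    of the batch before insertion into the replay buffer."""
    # The leading axis of every leaf is assumed to count transitions.
    num_transitions = jax.tree_util.tree_leaves(transitions)[0].shape[0]
    if num_transitions <= max_insert_size:
        return transitions
    indices = jax.random.choice(
        random_key, num_transitions, shape=(max_insert_size,), replace=False
    )
    return jax.tree_util.tree_map(lambda x: x[indices], transitions)


# Toy usage with a plain array standing in for the transition pytree:
key = jax.random.PRNGKey(0)
transitions = jnp.arange(12.0).reshape(6, 2)  # 6 transitions, 2 features
kept = subsample_transitions(key, transitions, max_insert_size=4)
print(kept.shape)  # (4, 2)
```

The random key could then be split and passed down from wherever it is already available (e.g. the emitter state), just before the call to `insert`.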
Can confirm this issue! I observed the same bug a couple of months ago when testing. Raised this internally but then lost track of it once we started focusing on other parts of the library.
Hi!
Have you had any thoughts about your preferred way to handle this?
As it sounds like a particular use case, I would suggest the following solution:
- add a check in the current buffer implementation so that a clear error message is raised when the new batch is bigger than the buffer's maximum size
- add a new buffer (sharing most methods with the current one through inheritance) that would implement one of the solutions you proposed to handle this use case (a rough sketch follows below)
What do you think about this?
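Below is a rough sketch of what these two points could look like. Everything here is an assumption rather than existing QDax code: the class names, the `buffer_size` attribute, the `insert_subsampled` method, and the extra `random_key` argument are all hypothetical, and the exact way `ReplayBuffer` stores its capacity may differ:

```python
import jax

from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer


class CheckedReplayBuffer(ReplayBuffer):
    """Hypothetical: identical to ReplayBuffer, but fails with a clear message
    instead of the cryptic dynamic_update_slice error. Assumes the parent
    exposes its capacity as `buffer_size`."""

    def insert(self, transitions):
        # The leading axis of every leaf is assumed to count transitions.
        num = jax.tree_util.tree_leaves(transitions)[0].shape[0]
        if num > self.buffer_size:
            raise ValueError(
                f"Cannot insert {num} transitions into a buffer of maximum "
                f"size {self.buffer_size}. Reduce the batch size or use a "
                f"buffer that subsamples oversized batches."
            )
        return super().insert(transitions)


class SubsamplingReplayBuffer(CheckedReplayBuffer):
    """Hypothetical: shares all other methods through inheritance and only
    adds an insertion path that randomly subsamples oversized batches."""

    def insert_subsampled(self, transitions, random_key):
        num = jax.tree_util.tree_leaves(transitions)[0].shape[0]
        if num > self.buffer_size:
            indices = jax.random.choice(
                random_key, num, shape=(self.buffer_size,), replace=False
            )
            transitions = jax.tree_util.tree_map(lambda x: x[indices], transitions)
        # Bypass the strict check and reuse the parent insertion logic.
        return ReplayBuffer.insert(self, transitions)
```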