
PGAME replay buffer addition fail for large batch-size

Open • manon-but-yes opened this issue 2 years ago • 2 comments

Hello :)

The current implementation of the PGAME emitter fails for large batch sizes (for example, 32768) with the following error:

Traceback (most recent call last):
  File "main.py", line 415, in <module>
    repertoire, emitter_state, random_key = map_elites.init(
  File "/git/exp/QDax/qdax/core/map_elites.py", line 80, in init
    emitter_state = self._emitter.state_update(
  File "/git/exp/QDax/qdax/core/emitters/pga_me_emitter.py", line 262, in state_update
    replay_buffer = emitter_state.replay_buffer.insert(transitions)
  File "/git/exp/QDax/qdax/core/neuroevolution/buffers/buffer.py", line 351, in insert
    new_data = jax.lax.dynamic_update_slice_in_dim(
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (3276800, 53) for operand shape (1000000, 53).

Cause: The error comes from the replay buffer insertion. When the batch size is too large, the emitter tries to insert more transitions than the maximum size of the buffer (here 3276800 transitions into a buffer of size 1000000).
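The shapes in the traceback make the overflow concrete. Dividing the first dimension of the update by the batch size gives 3276800 / 32768 = 100, which suggests (an inference from the shapes, not something stated above) that each individual contributes T = 100 transitions, e.g. one 100-step episode. With batch size B and buffer size N:

```latex
\begin{aligned}
B \cdot T &= 32768 \times 100 = 3\,276\,800 > N = 1\,000\,000,\\
B \cdot T \le N &\iff B \le N / T = 10\,000.
\end{aligned}
```

Under that assumption, any batch size above 10000 would trigger this error with a buffer of size 1000000.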

Possible fixes: Here are some of the options I could think of:

  • Increase the buffer size proportionally to the batch size. However, this might have unpredictable side effects on the convergence of PGAME.
  • Insert only the first buffer-size transitions. This would work as a quick fix. However, since transitions are ordered per actor, it would not guarantee that the inserted transitions are uncorrelated, and it might affect convergence, especially for small buffer sizes.
  • Randomly sample buffer-size transitions for insertion. This seems to be the best option on this list so far. However, neither the current state_update method of the PGEmitter nor the current insert method of the ReplayBuffer takes a random key as input. One solution might be to store this key in the emitter state, but it would probably make more sense to do this transition selection inside the ReplayBuffer class rather than inside the PGEmitter class.
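The two selection-based options can be sketched as follows. This is a framework-agnostic pure-Python illustration, not QDax code: transitions are a plain list (for example hypothetical `(actor, step)` tuples), and `select_first` and `select_random` are names introduced here. Passing an explicit `random.Random` instance stands in for the random key discussed above; in JAX, the indices would presumably be drawn with `jax.random.choice(key, n, shape=(max_size,), replace=False)` instead.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_first(transitions: Sequence[T], max_size: int) -> List[T]:
    """Option 2: keep only the first `max_size` transitions.

    Cheap and deterministic, but since transitions are grouped per actor,
    the kept slice only covers the first few actors.
    """
    return list(transitions[:max_size])


def select_random(
    transitions: Sequence[T], max_size: int, rng: random.Random
) -> List[T]:
    """Option 3: keep a uniform random subset of at most `max_size` transitions.

    The randomness is passed in explicitly, mirroring the idea of giving a
    random key to the buffer's insert method.
    """
    if len(transitions) <= max_size:
        return list(transitions)
    # Draw distinct indices without replacement, then sort them so the kept
    # transitions stay in their original relative order.
    indices = sorted(rng.sample(range(len(transitions)), max_size))
    return [transitions[i] for i in indices]
```

With 4 actors of 5 transitions each (`[(a, s) for a in range(4) for s in range(5)]`) and a maximum size of 8, `select_first` keeps only actors 0 and 1, which illustrates the correlation concern, while `select_random` draws its 8 transitions uniformly from the whole batch.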

manon-but-yes, Jul 19 '22 15:07

Can confirm this issue! I observed the same bug a couple of months ago when testing. Raised this internally but then lost track of it once we started focusing on other parts of the library.

limbryan, Jul 20 '22 09:07

Hi!

Have you had any thoughts about your preferred way to handle this?

As this sounds like a rather specific use case, I would suggest the following:

  • add a check to the current buffer implementation that raises a clear error message when the new batch is larger than the maximum buffer size
  • add a new buffer (sharing most methods with the current one through inheritance) that implements one of the solutions you proposed for this use case
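The two suggestions can be sketched with a toy buffer. Everything here is hypothetical and written in plain Python for readability: `ReplayBuffer` below is a minimal circular buffer standing in for QDax's JAX-based `ReplayBuffer`, and `SubsamplingReplayBuffer` and its `seed` argument are illustrative names, not part of QDax. Keeping the random generator on the buffer follows the idea of doing the selection inside the buffer class; in a functional JAX implementation, the key would presumably live in the buffer's state and be split on each insert.

```python
import random
from typing import Any, List, Optional, Sequence


class ReplayBuffer:
    """Toy fixed-size circular buffer (a stand-in, not QDax's ReplayBuffer)."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size
        self.data: List[Optional[Any]] = [None] * buffer_size
        self.current_position = 0
        self.current_size = 0

    def insert(self, transitions: Sequence[Any]) -> None:
        # Suggestion 1: fail early with an explicit message instead of a
        # low-level shape error from the slice update.
        if len(transitions) > self.buffer_size:
            raise ValueError(
                f"Trying to insert {len(transitions)} transitions into a replay "
                f"buffer of size {self.buffer_size}; the batch must not be "
                "larger than the buffer."
            )
        self._write(transitions)

    def _write(self, transitions: Sequence[Any]) -> None:
        # Write each transition at the next slot, wrapping around when full.
        for transition in transitions:
            self.data[self.current_position] = transition
            self.current_position = (self.current_position + 1) % self.buffer_size
        self.current_size = min(
            self.current_size + len(transitions), self.buffer_size
        )


class SubsamplingReplayBuffer(ReplayBuffer):
    """Suggestion 2: same storage, but oversized batches are randomly subsampled."""

    def __init__(self, buffer_size: int, seed: int = 0) -> None:
        super().__init__(buffer_size)
        # The randomness lives in the buffer, so the emitter's state_update
        # does not need to provide a random key.
        self.rng = random.Random(seed)

    def insert(self, transitions: Sequence[Any]) -> None:
        if len(transitions) > self.buffer_size:
            # Keep a uniform random subset of exactly buffer_size transitions,
            # preserving their original relative order.
            indices = sorted(
                self.rng.sample(range(len(transitions)), self.buffer_size)
            )
            transitions = [transitions[i] for i in indices]
        self._write(transitions)
```

With `ReplayBuffer(10)`, inserting 25 transitions raises a `ValueError` naming both sizes, whereas `SubsamplingReplayBuffer(10)` keeps 10 of them. Because the subclass reuses `_write`, the storage logic is shared and only the handling of oversized batches differs.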

What do you think about this?

felixchalumeau, Aug 19 '22 12:08