dejax
More efficient add_batch_fn implementation
I really appreciate this repository. However, for non-trivial batch sizes, add_batch_fn caused significant slowdowns because it adds every batch element to the buffer sequentially.
In case anyone is interested, I worked around this with the implementation below. First, a new utility function that writes a whole batch into the data storage in a single operation:
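import jax

def set_pytree_batch_items(tree_batch, index, trees):
    # Write the whole batch into every storage leaf with one dynamic_update_slice,
    # starting at row `index`; the start index is padded with zeros to match each leaf's rank.
    return jax.tree_util.tree_map(
        lambda tb, t: jax.lax.dynamic_update_slice(tb, t, (index,) + (0,) * (tb.ndim - 1)),
        tree_batch, trees,
    )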
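A minimal, self-contained check of what the helper does, assuming a storage pytree of plain arrays with the buffer axis first (the leaf names `obs` and `reward` are hypothetical). The hypothetical `set_items_sequentially` writes one row at a time with `lax.fori_loop`, standing in for the sequential insertion described above; it is not dejax's actual code. The last call shows that `dynamic_update_slice` clamps an out-of-range start index instead of wrapping around.

```python
import jax
import jax.numpy as jnp


def set_pytree_batch_items(tree_batch, index, trees):
    # One dynamic_update_slice per leaf; start index padded with zeros to the leaf's rank.
    return jax.tree_util.tree_map(
        lambda tb, t: jax.lax.dynamic_update_slice(tb, t, (index,) + (0,) * (tb.ndim - 1)),
        tree_batch, trees,
    )


def set_items_sequentially(tree_batch, index, trees):
    # Hypothetical reference: write one row per loop step, wrapping modulo the buffer size.
    n = jax.tree_util.tree_leaves(trees)[0].shape[0]
    size = jax.tree_util.tree_leaves(tree_batch)[0].shape[0]

    def body(i, tb):
        return jax.tree_util.tree_map(
            lambda leaf, t: leaf.at[(index + i) % size].set(t[i]), tb, trees)

    return jax.lax.fori_loop(0, n, body, tree_batch)


storage = {"obs": jnp.zeros((8, 3)), "reward": jnp.zeros((8,))}
batch = {"obs": jnp.ones((4, 3)), "reward": jnp.arange(1.0, 5.0)}

fast = set_pytree_batch_items(storage, 2, batch)
slow = set_items_sequentially(storage, 2, batch)
print(fast["reward"])  # rows 2..5 hold 1, 2, 3, 4; the other rows stay 0

# Start index 6 would need rows 6..9, but the buffer only has 8 rows:
# dynamic_update_slice clamps the start to 8 - 4 = 4 instead of wrapping to row 0.
clamped = set_pytree_batch_items(storage, 6, batch)
print(clamped["reward"])
```

Without wraparound the single write matches the per-row loop; with wraparound it does not, which is why the batched version assumes the batch size evenly divides the buffer size.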
Second, the outer function:
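import jax
import jax.numpy as jnp

# `utils`, `circular_buffer`, `UniformReplayBufferState` and `ItemBatch` come from the dejax code base.
# Assumes all batches have the same size and that it evenly divides the buffer's max size:
# `dynamic_update_slice` clamps an overflowing write instead of wrapping it around.
def add_batch_fn(state: UniformReplayBufferState, batch: ItemBatch) -> UniformReplayBufferState:
    buffer = state.storage
    insert_pos = buffer.head
    batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]  # works for any pytree batch
    new_data = utils.set_pytree_batch_items(buffer.data, insert_pos, batch)
    new_head = (insert_pos + batch_size) % circular_buffer.max_size(buffer)
    new_tail = jnp.where(
        buffer.full,
        0,  # Changed, due to the way `jax.lax.dynamic_update_slice` behaves inside `set_pytree_batch_items`
        buffer.tail,
    )
    # Keep `full` sticky: with the tail pinned to 0, `new_head == new_tail` alone would reset it after the next wrap.
    new_full = buffer.full | (new_head == new_tail)
    return state.replace(storage=buffer.replace(data=new_data, head=new_head, tail=new_tail, full=new_full))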
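If the batch size does not divide the buffer size, one wraparound-safe alternative is to scatter into modular indices with `.at[idx].set(...)` instead of using `dynamic_update_slice`. This is a sketch, not dejax's API: `Buffer`, `max_size` and `add_batch` are hypothetical stand-ins for the circular buffer state, and it assumes the batch size is at most the buffer size (duplicate scatter indices would make the result unspecified). Unlike the snippet above, it moves `tail` to the oldest surviving item once the buffer is full rather than pinning it to 0.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Buffer(NamedTuple):
    # Hypothetical stand-in for a circular buffer state.
    data: Any        # pytree of arrays, leading axis = max size
    head: jax.Array  # next write position
    tail: jax.Array  # position of the oldest stored item
    full: jax.Array  # True once every slot holds data


def max_size(buffer: Buffer) -> int:
    return jax.tree_util.tree_leaves(buffer.data)[0].shape[0]


def add_batch(buffer: Buffer, batch) -> Buffer:
    size = max_size(buffer)
    n = jax.tree_util.tree_leaves(batch)[0].shape[0]  # static batch size, assumed <= size
    # Rows head, head + 1, ..., head + n - 1, wrapped modulo size (no clamping).
    idx = (buffer.head + jnp.arange(n)) % size
    new_data = jax.tree_util.tree_map(lambda tb, t: tb.at[idx].set(t), buffer.data, batch)
    new_head = (buffer.head + n) % size
    # Number of stored items before this write; the buffer is full once it reaches size.
    fill = jnp.where(buffer.full, size, (buffer.head - buffer.tail) % size)
    new_full = fill + n >= size
    # When full, the oldest surviving item sits right after the newest one.
    new_tail = jnp.where(new_full, new_head, buffer.tail)
    return Buffer(new_data, new_head, new_tail, new_full)


zero = jnp.array(0, dtype=jnp.int32)
buf = Buffer({"x": jnp.zeros(5)}, zero, zero, jnp.array(False))
buf = add_batch(buf, {"x": jnp.array([1.0, 2.0, 3.0])})
buf = jax.jit(add_batch)(buf, {"x": jnp.array([4.0, 5.0, 6.0])})
# The second batch wraps around: slot 0 now holds 6.0 and the oldest item (2.0) is at slot 1.
print(buf.data["x"], buf.head, buf.tail, buf.full)
```

The scatter costs roughly the same as a single slice update, so it keeps the one-operation-per-batch speedup while dropping the divisibility requirement.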
I thought about opening a pull request, but I'm too lazy right now to cleanly adapt this quick fix to the code base.
Nice! Unfortunately I don't have much time to support this repo at the moment, but I'll try to integrate your idea.