
[BUG] Replay buffer on CPU is much slower than the performance reported in the docs

Open HeavyCrab opened this issue 5 months ago • 12 comments

Describe the bug

I am trying to put the buffer on the CPU to relieve the memory shortage on the GPU, but I encountered a large slowdown. According to README.md, the buffer should be much faster on CPU than on GPU or TPU. In my case, however, it is much slower on CPU than on both.

To Reproduce

This is a minimal script to test the speed.

import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

from types import SimpleNamespace
import jax
import jax.numpy as jnp
from flax import struct
import timeit

import flashbax as fbx
from flashbax.buffers.prioritised_trajectory_buffer import PrioritisedTrajectoryBufferState

config = SimpleNamespace()
config.num_envs = 16
config.batch_size = 256
config.obs_n_stack = 1
config.num_unroll_steps = 5
config.buffer_size = 2000000
config.priority_prob_alpha = 0.5
config.max_moves = 100
config.td_steps = 5
config.start_transitions = 400
config.obj_size = 100
config.device = jax.default_backend()

print(f"jax backend: {jax.default_backend()}")

@struct.dataclass
class BufferState:
    _state: PrioritisedTrajectoryBufferState

class ReplayBuffer:
    def __init__(self, config):
        self.buffer = fbx.make_prioritised_trajectory_buffer(
            add_batch_size=config.num_envs,
            sample_batch_size=config.batch_size,
            sample_sequence_length=config.obs_n_stack + config.num_unroll_steps + config.td_steps,
            period=1,
            min_length_time_axis=(config.start_transitions + config.num_envs - 1) // config.num_envs,
            max_length_time_axis=(config.buffer_size + config.num_envs - 1) // config.num_envs,
            priority_exponent=config.priority_prob_alpha,
            device=config.device,
        )

        self.config = config

    def init(self):
        xx = jnp.zeros((self.config.obj_size,), dtype=jnp.float32)
        _state = self.buffer.init(xx)

        init_state = BufferState(
            _state=_state,
        )
        return init_state

    def add(self, state: BufferState, batched_trajs) -> BufferState:
        _state = self.buffer.add(state._state, batched_trajs)

        return BufferState(
            _state=_state
        )

# initialize training components
buffer = ReplayBuffer(config)
buffer_state = buffer.init()

def buf_add(buffer_state):
    his = jnp.ones((config.num_envs, config.max_moves, config.obj_size), dtype=jnp.float32)
    buffer_state = buffer.add(buffer_state, his)
    return buffer_state


jit_buf_add = jax.jit(buf_add)
jit_buf_add_donate = jax.jit(buf_add, donate_argnums=(0,))
# warmup
buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)

tot = 10

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add execution_time (sec):', execution_time / tot)

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add_donate execution_time (sec):', execution_time / tot)


The output is:

$ python buf.py 
jax backend: cpu
jit_buf_add execution_time (sec): 0.4437830951996148
jit_buf_add_donate execution_time (sec): 0.20866907988674938
$ python buf.py 
jax backend: gpu
jit_buf_add execution_time (sec): 0.002522601559758186
jit_buf_add_donate execution_time (sec): 0.0025197524111717938

Context (Environment)

brax                      0.10.3                   pypi_0    pypi
flashbax                  0.1.3                    pypi_0    pypi
flax                      0.10.7                   pypi_0    pypi
gymnax                    0.0.9                    pypi_0    pypi
jax                       0.6.2                    pypi_0    pypi
jax-cuda12-pjrt           0.6.2                    pypi_0    pypi
jax-cuda12-plugin         0.6.2                    pypi_0    pypi
jaxlib                    0.6.2                    pypi_0    pypi
jaxmarl                   0.0.7                    pypi_0    pypi
jaxopt                    0.8.5                    pypi_0    pypi
optax                     0.2.5                    pypi_0    pypi
orbax-checkpoint          0.11.14                  pypi_0    pypi

HeavyCrab • Aug 04 '25 03:08

Looks like you're not donating the correct argument: argument 0 is self, but you need to donate the buffer state, which is argument 1. You can also use donate_argnames to make sure you're donating the correct argument.
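A minimal sketch of donating by name with jax.jit's donate_argnames parameter. The add function below is a hypothetical toy stand-in for a buffer add, not a Flashbax call:

```python
import jax
import jax.numpy as jnp


def add(state, batch):
    # Toy stand-in for a buffer add: overwrite the first rows of `state` with `batch`.
    return jax.lax.dynamic_update_slice(state, batch, (0, 0))


# Donating by name means reordering or adding parameters cannot silently
# shift donation onto the wrong argument, as can happen with donate_argnums.
add_jit = jax.jit(add, donate_argnames="state")

state = jnp.zeros((8, 4))
batch = jnp.ones((2, 4))
# The old `state` buffer may be reused for the output, so rebind the name
# and never touch the donated array again.
state = add_jit(state, batch)
```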

For an MWE, I'd recommend not using a class and not wrapping the state, as those could be interfering with the performance. If you must use a class, it's often better to jit inside the constructor of that class: then you're jitting the fbx functions directly and you're less likely to donate the wrong argument.
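A minimal sketch of the "jit in the constructor" pattern. JittedBuffer is a hypothetical name, and it assumes only that the wrapped object exposes pure init(item) and add(state, batch) functions, like the Flashbax buffer used in the scripts here. Because the raw add is jitted directly, the state is argument 0 and there is no self to get in the way:

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


class JittedBuffer:
    """Hypothetical wrapper that jits a buffer's pure functions once, at construction."""

    def __init__(self, buffer):
        self.init = jax.jit(buffer.init)
        # `add(state, batch)`: the state is argument 0 of the raw function,
        # so donating argument 0 donates the buffer state, not `self`.
        self.add = jax.jit(buffer.add, donate_argnums=(0,))


# Toy buffer standing in for a real one: a fixed-size array written at row 0.
toy = SimpleNamespace(
    init=lambda item: jnp.zeros((4,) + item.shape, dtype=item.dtype),
    add=lambda state, batch: state.at[0].set(batch),
)

buf = JittedBuffer(toy)
state = buf.init(jnp.zeros(3))
state = buf.add(state, jnp.ones(3))
```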

sash-a • Aug 04 '25 03:08

@sash-a Thanks for your reply. I think I am donating the correct argument, since I am jitting the wrapper function buf_add rather than buffer.add, and buf_add takes only one argument.

HeavyCrab • Aug 04 '25 04:08

I have also removed the wrapper class, and the result is similar.

import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

from types import SimpleNamespace
import jax
import jax.numpy as jnp
import timeit

import flashbax as fbx

config = SimpleNamespace()
config.num_envs = 16
config.batch_size = 256
config.obs_n_stack = 1
config.num_unroll_steps = 5
config.buffer_size = 2000000
config.priority_prob_alpha = 0.5
config.max_moves = 100
config.td_steps = 5
config.start_transitions = 400
config.obj_size = 100
config.device = jax.default_backend()

print(f"jax backend: {jax.default_backend()}")

buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=config.num_envs,
    sample_batch_size=config.batch_size,
    sample_sequence_length=config.obs_n_stack + config.num_unroll_steps + config.td_steps,
    period=1,
    min_length_time_axis=(config.start_transitions + config.num_envs - 1) // config.num_envs,
    max_length_time_axis=(config.buffer_size + config.num_envs - 1) // config.num_envs,
    priority_exponent=config.priority_prob_alpha,
    device=config.device,
)
template = jnp.zeros((config.obj_size,), dtype=jnp.float32)
buffer_state = buffer.init(template)

def buf_add(buffer_state):
    his = jnp.ones((config.num_envs, config.max_moves, config.obj_size), dtype=jnp.float32)
    buffer_state = buffer.add(buffer_state, his)
    return buffer_state


jit_buf_add = jax.jit(buf_add)
jit_buf_add_donate = jax.jit(buf_add, donate_argnums=(0,))
# warmup
buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)

tot = 10

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add execution_time (sec):', execution_time / tot)

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add_donate execution_time (sec):', execution_time / tot)

Output:

$ python buf.py 
jax backend: cpu
jit_buf_add execution_time (sec): 0.4400634744670242
jit_buf_add_donate execution_time (sec): 0.2075462588109076
$ python buf.py 
jax backend: gpu
jit_buf_add execution_time (sec): 0.0025221142917871476
jit_buf_add_donate execution_time (sec): 0.0025158045347779988

HeavyCrab • Aug 04 '25 04:08

Ah, my bad. I was on my phone and thought you were jitting the class method.

I had a look, and the cause seems to be that it's a prioritised buffer. If you switch to a normal trajectory buffer, the speed is in line with what we show in the docs. My guess is that the prioritised buffer has to do some extra computation that is much faster on the GPU. Unfortunately, I doubt this can be changed, but maybe @EdanToledo can weigh in here?

sash-a • Aug 05 '25 08:08

For the record, this is the speed I get with a TrajectoryBuffer. However, I needed to switch to the same config as in the docs: with your config I see a similar trend of the GPU being faster, likely because the higher num_envs lets the GPU do more in parallel.

❯ python fbx_test.py
jax backend: cpu
jit_buf_add execution_time (sec): 0.1259669284001575
jit_buf_add_donate execution_time (sec): 3.787140012718737e-05

❯ python fbx_test.py
jax backend: gpu
jit_buf_add execution_time (sec): 0.014787551900008111
jit_buf_add_donate execution_time (sec): 0.00020687850010290277

Config:

config = SimpleNamespace()
config.num_envs = 1
config.batch_size = 256
config.obs_n_stack = 1
config.num_unroll_steps = 5
config.buffer_size = 100_000
config.priority_prob_alpha = 0.5
config.max_moves = 1
config.td_steps = 5
config.start_transitions = 400
config.obj_size = (32, 32, 3)
config.device = jax.default_backend()

So it may be worth doing a more extensive benchmark for the docs, to show which parameters suit which devices.

sash-a • Aug 05 '25 08:08

Hey, yeah. What I will say is that prioritised replay buffers on CPU use sequential operations when adding data to the buffer, because that was faster (at least when Flashbax was originally written). It would be worth testing the GPU add implementation on CPU for the prioritised buffer with newer versions of JAX. The prioritised buffer is much slower than the trajectory buffer when adding large quantities of data because of the sum-tree update. Writing a new benchmarking script for all the buffers would actually be quite useful, so we can update the results in the README. If I find time, I don't mind giving it a crack to get new numbers, but I am quite busy right now. Ultimately, to summarise: different buffers on different devices have parameters whose cost scales in different ways. Crazy sentence, but it's complicated to show the proper relationships for all of them.
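As a starting point for such a benchmarking script, here is a hypothetical helper that factors out the timing pattern the scripts in this thread already use (a warm-up call, timeit.default_timer, and jax.block_until_ready). The name benchmark and its parameters are illustrative, not part of Flashbax:

```python
import timeit

import jax
import jax.numpy as jnp


def benchmark(step_fn, state, n_iters=10, donate=True):
    """Mean wall-clock seconds per call of jit(step_fn), plus the final state.

    `step_fn` maps a pytree state to a new state with the same structure
    (e.g. a wrapper around a buffer's add). The first call triggers compilation
    and is excluded from the timing.
    """
    jitted = jax.jit(step_fn, donate_argnums=(0,) if donate else ())
    state = jax.block_until_ready(jitted(state))  # compile + warm up
    start = timeit.default_timer()
    for _ in range(n_iters):
        state = jitted(state)
    jax.block_until_ready(state)  # JAX dispatches asynchronously; wait for the result
    return (timeit.default_timer() - start) / n_iters, state


# Toy usage: a real sweep would repeat this per backend and per parameter setting.
for size in (1_000, 100_000):
    seconds, _ = benchmark(lambda s: s + 1.0, jnp.zeros(size), n_iters=5)
    print(f"{jax.default_backend()} size={size}: {seconds:.2e} s/call")
```

Passing a fresh state to each call matters when donate=True, since the donated input may be invalidated.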

EdanToledo • Aug 05 '25 10:08

@EdanToledo Thank you for the clarification! I have a few more questions. By the way, what is the main factor behind the difference in sum-tree update speed between CPU and GPU? If we denote the number of newly added transitions as $K = B \times M$ (i.e., add batch size (num_envs) × max_moves) and the size of the sum tree as $N$, the insertion complexity can be $O(K/p \times \log N)$, or $O(K/p + (B/p) \log N)$ with an appropriate implementation, where $p$ is the degree of parallelism. On the GPU, thanks to its higher degree of parallelism, this can reach $O(\log N)$ when $K$ is not too large. I tested a simple fori_loop and found that it can run up to a thousand times faster on the CPU than on the GPU. Based on this, the difference here should not be so large: at least in my case, with $K = 1600$, I would expect the speeds to be roughly comparable (or even faster on CPU). So I don't really understand why there is such a huge performance gap here.

In addition, inspired by @sash-a, I tested the prioritised buffer with max_move = 1. When max_move = 1, add with donation is extremely fast, but when max_move = 2 it becomes much slower (over a thousand times slower, rather than just twice as slow). This doesn't seem reasonable, since I can manually split a max_move = 2 insertion into two separate max_move = 1 insertions. I'm not sure whether I'm missing something here. Any further explanation would be greatly appreciated.
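Plugging in the numbers from the first script gives a sense of scale for this argument. This assumes $N$ is the number of sum-tree leaves, i.e. max_length_time_axis × num_envs, and that the tree depth is $\lceil \log_2 N \rceil$:

$$
K = 16 \times 100 = 1600, \qquad N = 125{,}000 \times 16 = 2{,}000{,}000, \qquad \lceil \log_2 N \rceil = 21
$$

$$
K \lceil \log_2 N \rceil = 1600 \times 21 = 33{,}600 \text{ node updates per add call}
$$

That modest count is what underlies the expectation that CPU and GPU should be roughly comparable for this configuration.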

I used the same script:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

from types import SimpleNamespace
import jax
import jax.numpy as jnp
import timeit

import flashbax as fbx

config = SimpleNamespace()
config.num_envs = 16
config.batch_size = 256
config.obs_n_stack = 1
config.num_unroll_steps = 5
config.buffer_size = 2000000
config.priority_prob_alpha = 0.5
config.max_moves = 2
config.td_steps = 5
config.start_transitions = 400
config.obj_size = 100
config.device = jax.default_backend()

print(f"jax backend: {jax.default_backend()}")
print(f"max_moves: {config.max_moves}")

buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=config.num_envs,
    sample_batch_size=config.batch_size,
    sample_sequence_length=config.obs_n_stack + config.num_unroll_steps + config.td_steps,
    period=1,
    min_length_time_axis=(config.start_transitions + config.num_envs - 1) // config.num_envs,
    max_length_time_axis=(config.buffer_size + config.num_envs - 1) // config.num_envs,
    priority_exponent=config.priority_prob_alpha,
    device=config.device,
)
template = jnp.zeros((config.obj_size,), dtype=jnp.float32)
buffer_state = buffer.init(template)

def buf_add(buffer_state):
    his = jnp.ones((config.num_envs, config.max_moves, config.obj_size), dtype=jnp.float32)
    buffer_state = buffer.add(buffer_state, his)
    return buffer_state


jit_buf_add = jax.jit(buf_add)
jit_buf_add_donate = jax.jit(buf_add, donate_argnums=(0,))
# warmup
buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)

tot = 10

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add execution_time (sec):', execution_time / tot)

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add_donate execution_time (sec):', execution_time / tot)

And the result is:

$ python buf.py 
jax backend: cpu
max_moves: 1
jit_buf_add execution_time (sec): 0.11380044710822404
jit_buf_add_donate execution_time (sec): 7.206676527857781e-05
$ python buf.py 
jax backend: cpu
max_moves: 2
jit_buf_add execution_time (sec): 0.4524868482723832
jit_buf_add_donate execution_time (sec): 0.20394585207104682

HeavyCrab • Aug 05 '25 23:08

To tell you the honest truth, I also don't fully understand the speed difference. I remember that, during development, a lot of things had to be tested empirically. For clarity, though, here is a more detailed breakdown of the two implementations (at least from skimming the code):

CPU implementation: the complexity is $O(K \log N)$. We perform the $O(\log N)$ leaf-to-root update $K$ times, sequentially.
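A minimal sketch of such a sequential update, assuming an array-backed sum tree stored as a 1-indexed heap with a power-of-two capacity. The layout and the name set_leaves_sequential are illustrative assumptions, not Flashbax's actual code:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="capacity")
def set_leaves_sequential(tree, leaf_idx, values, capacity):
    """Set leaf_idx[k] to values[k] one update at a time: O(K log N) work.

    Hypothetical layout: `tree` has length 2 * capacity, node 1 is the root,
    node i has children 2i and 2i + 1, and leaves occupy [capacity, 2 * capacity).
    `capacity` must be a power of two.
    """
    depth = capacity.bit_length() - 1  # levels above the leaves

    def one_update(k, tree):
        node = capacity + leaf_idx[k]
        delta = values[k] - tree[node]  # change relative to the current leaf value

        def climb(_, carry):
            tree, node = carry
            return tree.at[node].add(delta), node // 2  # leaf, parent, ..., root

        tree, _ = jax.lax.fori_loop(0, depth + 1, climb, (tree, node))
        return tree

    return jax.lax.fori_loop(0, leaf_idx.shape[0], one_update, tree)


tree = set_leaves_sequential(
    jnp.zeros(16), jnp.array([0, 3, 3]), jnp.array([1.0, 2.0, 5.0]), capacity=8
)
```

Because each delta is computed against the leaf's current value, repeated indices are handled naturally: the last write to leaf 3 wins and the root ends up at 6.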

GPU implementation: the complexity is $O(K + N + \log N)$. To parallelise across the $K$ updates, we had to use a bincount operation, which is $O(K + N)$ (I think). The $N$ term comes from the fact that JIT compilation requires a statically sized operation, so we have to size it by the capacity to make sure every possible update is accounted for. We chose bincount for this through empirical testing.
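For contrast, a sketch of a bincount-style update using the same hypothetical 1-indexed heap layout. It is one plausible reading of the description above, not Flashbax's implementation, and unlike the sequential version it assumes the leaf indices in one call are unique:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="capacity")
def set_leaves_parallel(tree, leaf_idx, values, capacity):
    """Set all leaves at once, propagating deltas level by level with bincount.

    Hypothetical layout: a 1-indexed heap of length 2 * capacity with leaves at
    [capacity, 2 * capacity). Assumes `leaf_idx` has no duplicates: bincount
    sums the weights of repeated indices, so duplicate deltas would be
    double-counted.
    """
    depth = capacity.bit_length() - 1
    nodes = capacity + leaf_idx
    deltas = values - tree[nodes]  # all K deltas computed in parallel
    for _ in range(depth + 1):  # Python loop, unrolled at trace time
        # Each bincount has a static length of 2 * capacity (required under jit),
        # so it does O(K + N) work regardless of how few leaves changed.
        tree = tree + jnp.bincount(nodes, weights=deltas, length=2 * capacity)
        nodes = nodes // 2  # move every pending update to its parent
    return tree


tree = set_leaves_parallel(
    jnp.zeros(16), jnp.array([0, 3, 5]), jnp.array([1.0, 2.0, 4.0]), capacity=8
)
```

Note that this sketch runs one full-size bincount per level, so its total work is $O((K + N) \log N)$; it illustrates the data flow and the capacity-sized term, not the exact complexity quoted above.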

The counter-intuitive part is that, despite its worse theoretical complexity, the GPU version is empirically much faster when run on a GPU. This is the reason we have two update implementations in the first place (I really did not want this during development, but I could not find a single implementation that worked well on both CPU and GPU). Why it is faster comes down to the magic of GPUs and JAX, which I unfortunately do not fully understand. I imagine that when operations are fused, the computation somehow becomes more parallel or lower in complexity through black magic.

(Disclaimer: It's been a while since I developed this, so my apologies for any potential inaccuracies in my recollection.)

My recommendation: I suggest benchmarking the Sum Tree on CPU and GPU independently, separate from the rest of the buffer. It's possible that other factors within the PER buffer logic are contributing to the slowdown, and isolating the Sum Tree will help us confirm where the bottleneck is.

Let me know if this makes sense!

EdanToledo • Aug 06 '25 10:08

@EdanToledo Thank you so much for the detailed explanation! That answers some of my questions. However, based on my latest script above, I think the CPU add code might still have a bug. Do you have any comments on it? Specifically, add with max_move=2 is thousands of times slower than with max_move=1. Yet we can implement an add with max_move=2 by calling add with max_move=1 twice, which should cost at most about twice as much as max_move=1. Given your explanation that the CPU add is sequential, I would expect the single max_move=2 add to behave like the latter.

Here is an updated script to help you reproduce the results.

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

from types import SimpleNamespace
import jax
import jax.numpy as jnp
import timeit

import flashbax as fbx

config = SimpleNamespace()
config.num_envs = 16
config.batch_size = 256
config.obs_n_stack = 1
config.num_unroll_steps = 5
config.buffer_size = 2000000
config.priority_prob_alpha = 0.5
config.max_moves = 2
config.td_steps = 5
config.start_transitions = 400
config.obj_size = 100
config.device = jax.default_backend()

print(f"jax backend: {jax.default_backend()}")
print(f"max_moves: {config.max_moves}")

buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=config.num_envs,
    sample_batch_size=config.batch_size,
    sample_sequence_length=config.obs_n_stack + config.num_unroll_steps + config.td_steps,
    period=1,
    min_length_time_axis=(config.start_transitions + config.num_envs - 1) // config.num_envs,
    max_length_time_axis=(config.buffer_size + config.num_envs - 1) // config.num_envs,
    priority_exponent=config.priority_prob_alpha,
    device=config.device,
)
template = jnp.zeros((config.obj_size,), dtype=jnp.float32)
buffer_state = buffer.init(template)

his = jnp.ones((config.num_envs, config.max_moves, config.obj_size), dtype=jnp.float32)

def buf_add(buffer_state):
    buffer_state = buffer.add(buffer_state, his)
    return buffer_state

def buf_add_sequential(buffer_state):
    for i in range(config.max_moves):
        buffer_state = buffer.add(buffer_state, his[:, i:i+1, :])
    return buffer_state


jit_buf_add_donate = jax.jit(buf_add, donate_argnums=(0,))
jit_buf_add_sequential_donate = jax.jit(buf_add_sequential, donate_argnums=(0,))
# warmup
buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)
buffer_state = jit_buf_add_sequential_donate(buffer_state)
jax.block_until_ready(buffer_state)

tot = 10

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add_donate(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add_donate execution_time (sec):', execution_time / tot)

start_time = timeit.default_timer()
for i in range(tot):
    buffer_state = jit_buf_add_sequential_donate(buffer_state)
jax.block_until_ready(buffer_state)
execution_time = timeit.default_timer() - start_time
print('jit_buf_add_sequential_donate execution_time (sec):', execution_time / tot)

Output:

$ python buf.py 
jax backend: cpu
max_moves: 2
jit_buf_add_donate execution_time (sec): 0.19228420229628682
jit_buf_add_sequential_donate execution_time (sec): 7.678167894482612e-05

HeavyCrab • Aug 06 '25 16:08

Ah! That is suuuuper weird! What I don't get is why this happens only on CPU. The sequential-add result really throws me; it makes me even more curious whether the bug is not in the sum tree at all but somewhere else in the buffer logic. Unfortunately, I don't have a lot of time right now (nor am I the maintainer of the library anymore 😭). Thanks for providing these scripts, though: if I find the time, or my curiosity drives me insane this week, I can try to look into it. @sash-a, do you have time to look? Basically, I'd try to confirm ASAP whether this is isolated to the sum tree; if it is, that massively reduces the complexity of debugging.
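One possible first step toward isolating the problem is to compile buf_add for max_moves=1 and max_moves=2 and compare the optimized HLO. The helper below is a hypothetical diagnostic (not part of Flashbax) built on JAX's ahead-of-time API, jax.jit(...).lower(...).compile().as_text(); it counts copy instructions as a rough signal of whether the donated state is still updated in place. Whether an extra copy explains the timings in this thread has not been established:

```python
import jax
import jax.numpy as jnp


def count_hlo_copies(fn, *args, donate_argnums=(0,)):
    """Compile jit(fn) for `args` and count `copy` instructions in its HLO.

    Diagnostic heuristic only: a copy in the optimized HLO can mean a donated
    buffer is not being updated in place, but not every copy is a full-state copy.
    Lowering and compiling do not execute fn, so the args are not consumed.
    """
    compiled = jax.jit(fn, donate_argnums=donate_argnums).lower(*args).compile()
    hlo = compiled.as_text()  # optimized HLO for the current backend
    return hlo.count(" copy("), hlo


n_copies, hlo = count_hlo_copies(
    lambda state, batch: state.at[:2].set(batch), jnp.zeros((8, 4)), jnp.ones((2, 4))
)
print(n_copies)
```

With the scripts above, this would be run as count_hlo_copies(buf_add, buffer_state) once per max_moves setting, and the two HLO dumps diffed.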

EdanToledo • Aug 06 '25 17:08

Thanks for all the info, @EdanToledo. I don't think I'll have time to look into this over the next month, but I can try in September. It's possible that @SimonDuToit may have some time, although he will be off for a while very soon.

sash-a • Aug 07 '25 14:08

Hi there, any update on this issue?

HeavyCrab • Sep 13 '25 22:09