
[BUG] Large floating point errors

Open · adzcai opened this issue 10 months ago • 12 comments

Describe the bug

The sum tree implementation is sensitive to floating point errors. I noticed during my RL training runs that, at a certain point, sampling from the buffer would return experiences that were all zeros. Setting `JAX_ENABLE_X64=true` fixed the issue.
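To illustrate why 64-bit floats help, here is a minimal NumPy sketch. In JAX, 64-bit mode is enabled with the `JAX_ENABLE_X64=true` environment variable or `jax.config.update("jax_enable_x64", True)`. The sketch assumes a simplified, hypothetical model in which the root is maintained incrementally by adding `new - old` on every priority update. This is how delta-propagating sum trees work, but not necessarily flashbax's exact update path. It measures how far that running root drifts from the true sum of the leaves in float32 and in float64.

```python
import numpy as np


def root_drift(dtype, num_leaves=64, num_updates=1000, seed=0):
    """Return |running root - sum(leaves)| after many incremental updates.

    Simplified model (assumption): each priority update adds (new - old)
    to the root, as in a delta-propagating sum tree.
    """
    rng = np.random.default_rng(seed)
    leaves = rng.uniform(0.1, 5.0, size=num_leaves).astype(dtype)
    root = leaves.sum(dtype=dtype)
    for _ in range(num_updates):
        idx = rng.integers(num_leaves)
        new = dtype(rng.uniform(0.1, 5.0))
        root = dtype(root + (new - leaves[idx]))  # rounded to `dtype` every step
        leaves[idx] = new
    # Reference: an (effectively exact) float64 sum of the current leaves.
    return abs(float(root) - float(leaves.astype(np.float64).sum()))


drift32 = root_drift(np.float32)
drift64 = root_drift(np.float64)
print(f"float32 drift: {drift32:.3e}")
print(f"float64 drift: {drift64:.3e}")
```

The float32 error grows with the number of updates, while the float64 error stays many orders of magnitude smaller for the same sequence.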

To Reproduce

Here's a code snippet that illustrates the error.

```python
import flashbax as fbx
import jax
import jax.numpy as jnp

batch_size = 32
length = 64
buffer = fbx.make_prioritised_trajectory_buffer(
    add_batch_size=batch_size,
    sample_batch_size=batch_size,
    sample_sequence_length=length,
    period=1,
    min_length_time_axis=length,
    max_length_time_axis=length * 4,
)

# Initialise the state (just add once).
state = buffer.init(0.0)
state = buffer.add(state, jnp.zeros((batch_size, length * 2)))
assert buffer.can_sample(state)
key = jax.random.key(42)
key, key_ = jax.random.split(key)
sample = buffer.sample(state, key_)

# Update the priorities of the same sampled items many times.
for i in range(1000):
    key, key_ = jax.random.split(key)
    # `shape` must be a tuple, hence (batch_size,).
    priorities = jax.random.uniform(key_, (batch_size,), minval=0.1, maxval=5.0)
    state = buffer.set_priorities(state, sample.indices, priorities)
    nodes = state.priority_state.nodes
    if i % 100 == 0:
        # Root value minus the sum of the leaves (the second half of `nodes`);
        # this should stay at (approximately) zero.
        print(nodes[0] - nodes[nodes.size // 2 :].sum())
```

Expected behavior

I expected all the printed values to be zero (or at least on the order of 1e-6). Instead, I get:

```
-0.00024414062
0.0014648438
0.008056641
0.0075683594
0.0024414062
0.007080078
0.014404297
0.02368164
0.020996094
0.024414062
```

Because the root value drifts away from the true sum of the leaves, an out-of-bounds index can be sampled.

Context (Environment)

I'm running macOS 15.3.1. I ran `pip install flashbax` in a fresh environment, which gives these package versions:

absl-py==2.1.0
chex==0.1.88
etils==1.12.0
flashbax==0.1.2
flax==0.10.3
fsspec==2025.2.0
humanize==4.12.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.3
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.5
protobuf==5.29.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.2
simplejson==3.20.1
tensorstore==0.1.71
toolz==1.0.0
treescope==0.1.8
typing-extensions==4.12.2
zipp==3.21.0

Additional context

No other context.

Possible Solution

It might be a good idea to occasionally recompute the entire tree from the leaves once the discrepancy starts getting large.
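Recomputing the tree from the leaves could look like the NumPy sketch below. The array layout is an assumption chosen to match the `nodes[nodes.size // 2 :]` leaf slice in the reproduction script: a complete binary tree of size `2 * capacity - 1`, with leaves in the last `capacity` slots and the children of node `i` at `2 * i + 1` and `2 * i + 2`. `rebuild_sum_tree` is a hypothetical helper, not a flashbax function.

```python
import numpy as np


def rebuild_sum_tree(nodes):
    """Return a copy of `nodes` with every internal node recomputed from the leaves.

    Assumed layout (hypothetical): size 2 * capacity - 1, capacity a power of
    two, leaves in the last `capacity` slots, children of i at 2i + 1, 2i + 2.
    """
    nodes = np.array(nodes, copy=True)
    first = nodes.size // 2  # index of the first leaf
    while first > 0:
        parent_first = (first - 1) // 2
        # Children of parents [parent_first, first) occupy [first, 2 * first + 1).
        children = nodes[first : 2 * first + 1]
        nodes[parent_first:first] = children[0::2] + children[1::2]
        first = parent_first
    return nodes


# Build a small tree, corrupt its root, then repair it.
leaves = np.array([1.0, 2.0, 3.0, 0.0], dtype=np.float32)
tree = rebuild_sum_tree(np.concatenate([np.zeros(3, dtype=np.float32), leaves]))
tree[0] += 0.5  # simulate accumulated drift at the root
repaired = rebuild_sum_tree(tree)
print(tree[0], repaired[0])  # 6.5 6.0
```

A rebuild costs O(capacity), so it would be run every so many updates rather than on every one. In JAX, the slice assignment would become `nodes = nodes.at[parent_first:first].set(...)`.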

adzcai, Feb 17 '25

There was a bug with the sum tree and PER implementation that has now been fixed.

EdanToledo, Mar 17 '25

I still observe this behavior (floating point errors building up) after the fix. It might be something else.

Sometimes, when `nodes[0] > nodes[nodes.size // 2 :].sum()`, you sample nodes that are currently invalid, which leads to `sample.probabilities == 0.0`. A workaround is to (1) rebuild the entire sum tree from the bottom up from time to time, and (2) mask out sampled transitions with probability 0.0, since they lead to NaN importance weights.
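Workaround (2) can be sketched with the standard PER weight formula, w_i = (N · P(i))^(-β) / max_j w_j. This is a NumPy illustration, not flashbax code. `masked_importance_weights` is a hypothetical helper, and `buffer_size` stands for N. The same `where` pattern carries over to `jax.numpy` unchanged.

```python
import numpy as np


def masked_importance_weights(probabilities, buffer_size, beta):
    """PER importance weights w_i = (N * P(i)) ** -beta, normalised by max_j w_j.

    Samples with P(i) == 0 (invalid leaves) get weight 0 instead of inf/NaN,
    so they drop out of a weighted loss.
    """
    probabilities = np.asarray(probabilities, dtype=np.float64)
    valid = probabilities > 0.0
    # Put a harmless 1.0 in the invalid slots so the power never sees a zero.
    safe_p = np.where(valid, probabilities, 1.0)
    weights = np.where(valid, (buffer_size * safe_p) ** -beta, 0.0)
    max_w = weights.max()
    return weights / max_w if max_w > 0 else weights


probs = np.array([0.25, 0.5, 0.0, 0.25])
w = masked_importance_weights(probs, buffer_size=4, beta=0.4)
print(w)  # the third (invalid) sample gets weight 0.0
```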

JINKEHE, Apr 08 '25

Can you try to get a reproduction script? I'm currently using the new PER and I haven't observed any NaNs yet, but I understand this might only happen in a fairly specific setting.

EdanToledo, Apr 08 '25

Running the original script with the current version of flashbax gives me:

```
0.0
0.0014648438
0.008056641
0.0078125
0.0026855469
0.007080078
0.014404297
0.02368164
0.020996094
0.024414062
```

EdanToledo, Apr 08 '25

Sorry, it will take some time for me to produce a minimal reproduction script.

But perhaps I can explain how a zero probability ends up in a sample. As far as I understand, sampling from the sum tree first draws a query value and scales it, `query_value = query_value * _total_priority(state)`, where `_total_priority` is the value at the root node. However, because of the floating-point error discussed here, `_total_priority(state)` can become larger than the sum of the leaf priorities, which makes it possible to sample beyond the valid leaves. Those invalid leaves have zero priority, and therefore zero probability. When you then compute 1/probability for the importance weights, you get inf or NaN. This is easy to prevent by adding a small number to the divisor and masking out samples with probability 0.0. It happens rarely, roughly once in a batch of 256.
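The mechanism can be reproduced on a four-leaf toy tree. This NumPy sketch assumes the usual array-backed layout, with the root at 0 and the children of node `i` at `2 * i + 1` and `2 * i + 2`. `descend` is a hypothetical helper that follows the standard PER top-down descent, not flashbax's code. When the root has drifted above the true leaf sum, any query in `[leaves.sum(), root)` walks off the end of the valid mass. In this example, that lands on the empty leaf.

```python
import numpy as np


def descend(nodes, x):
    """Standard PER descent: go left if x < left child's value, else subtract it and go right."""
    node = 0
    while 2 * node + 1 < nodes.size:
        left = 2 * node + 1
        if x < nodes[left]:
            node = left
        else:
            x = x - nodes[left]
            node = left + 1
    return node - nodes.size // 2  # leaf index


leaves = np.array([1.0, 2.0, 3.0, 0.0], dtype=np.float32)  # leaf 3 is empty
exact = np.array([6.0, 3.0, 3.0, *leaves], dtype=np.float32)
drifted = exact.copy()
drifted[0] = 6.5  # root has drifted above leaves.sum() == 6.0

u = 0.95  # a uniform draw in [0, 1)
idx_exact = descend(exact, u * exact[0])        # query 5.7   -> leaf 2
idx_drifted = descend(drifted, u * drifted[0])  # query 6.175 -> leaf 3
print(idx_exact, idx_drifted, leaves[idx_drifted])  # 2 3 0.0
```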

JINKEHE, Apr 09 '25

Same issue here. Any updates on this?

HeavyCrab, Jul 08 '25

Do you have a reproduction script? I imagine this bug may also be hardware-specific to some extent, because of floating-point errors on the GPU. One solution I can imagine is to round priority updates to a lower precision so that this is less likely to happen.
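One way to read the lower-precision suggestion is to round every priority to a fixed grid, for example multiples of 2**-8. If all priorities are such multiples and the total stays well below 2**24 · 2**-8 = 65536, then every float32 sum and difference is exact, so incremental updates cannot drift at all. The NumPy sketch below checks this under a hypothetical delta-update model. `quantize`, `root_error` and the 8-bit grid are illustrative assumptions, not flashbax APIs.

```python
import numpy as np


def quantize(p, bits=8):
    """Round priorities to the nearest multiple of 2 ** -bits (hypothetical helper)."""
    scale = np.float32(2.0**bits)
    return np.round(p * scale) / scale


def root_error(quantized, num_leaves=64, num_updates=1000, seed=0):
    """Incrementally maintained float32 root minus the exact sum of the leaves."""
    rng = np.random.default_rng(seed)
    prep = quantize if quantized else (lambda p: p)
    leaves = prep(rng.uniform(0.1, 5.0, num_leaves).astype(np.float32))
    root = leaves.sum()
    for _ in range(num_updates):
        i = rng.integers(num_leaves)
        new = prep(np.float32(rng.uniform(0.1, 5.0)))
        root += new - leaves[i]  # delta update, as in a delta-propagating tree
        leaves[i] = new
    return float(root) - float(leaves.astype(np.float64).sum())


print("raw priorities:      ", root_error(quantized=False))
print("quantized priorities:", root_error(quantized=True))
```

The trade-off is resolution. The grid must be fine enough that no valid priority rounds to 0, and coarse enough that the largest possible total still fits within float32's 24-bit significand.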

EdanToledo, Jul 08 '25

Script

Here is a minimal reproduction script: https://colab.research.google.com/drive/1cq2C8VTPUja--PvbKq21M8ULglME2vrU?usp=sharing

The output shows:

```
trained_steps: 207
if raw_batch is finite: False
raw_batch: (Array(True, dtype=bool), Array(True, dtype=bool), Array(False, dtype=bool))
```

This means the importance weights contain NaN, which indicates that a sampled probability is zero.

The output on GPU may vary from run to run because of GPU nondeterminism.

Possible solution

If @JINKEHE is correct:

> But perhaps I can explain the logic behind 0 probability appearing in a sample. As far as I understand, the way you sample in the sum tree is first you sample query_key with query_value = query_value * _total_priority(state) where _total_priority is the value at the root node. However, due to the floating error that we are talking about here, _total_priority(state) can become bigger than the sum of leaf priorities, making it possible to sample outside of the valid leaves. Those invalid leaves have 0 priorities/probabilities. When you then do 1/probability for importance weights you get inf or nan but this is easy to prevent by adding a small number to the divisor and masking out samples with probability 0.0. It happens rarely, like once in a batch of size 256.

Instead of using global cumulative sums and binary search (sampling from $[0,R)$, where $R$ is the root value), I suggest implementing a sampling scheme that makes local decisions at each tree node. At each internal node with children of value $p$ and $q$, one can select the left child with probability $p / (p + q)$ and otherwise the right child, recursively traversing to a leaf.

This method guarantees that the selected leaf has positive probability, and it requires only minimal modification.
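A NumPy sketch of the local-decision descent proposed above is shown below. It uses a per-sample Python loop for clarity. The tree layout, with children of node `i` at `2 * i + 1` and `2 * i + 2`, is an assumption, and `sample_local` is a hypothetical helper rather than flashbax code. The walk reads only the two child values at each step and never the root total, so a drifted root cannot push it onto an empty leaf. The guarantee does rely on `p + q > 0` at every visited node, that is, a node holding exactly 0 whenever its whole subtree is empty.

```python
import numpy as np


def sample_local(nodes, rng):
    """Descend by picking the left child with probability p / (p + q).

    Assumes children of node i at 2i + 1 and 2i + 2 and p + q > 0 at every
    visited node. Returns the leaf index.
    """
    node = 0
    while 2 * node + 1 < nodes.size:
        left, right = 2 * node + 1, 2 * node + 2
        p, q = float(nodes[left]), float(nodes[right])
        # rng.random() is in [0, 1): p == 0 never goes left, q == 0 never goes right.
        node = left if rng.random() < p / (p + q) else right
    return node - nodes.size // 2


# Internal nodes are consistent, but the root has drifted to 6.5.
nodes = np.array([6.5, 3.0, 3.0, 1.0, 2.0, 3.0, 0.0], dtype=np.float32)
rng = np.random.default_rng(0)
counts = np.bincount([sample_local(nodes, rng) for _ in range(10_000)], minlength=4)
print(counts)  # roughly proportional to [1, 2, 3, 0]; leaf 3 is never drawn
```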

HeavyCrab, Jul 09 '25

Thanks for making this script. When I find the time, I will try to properly nail down the issue and solve it. While your sampling idea is not bad, we ultimately aim to stay true to the original paper.

EdanToledo, Jul 09 '25

https://github.com/instadeepai/flashbax/pull/58

There was another bug in PER, which this thread helped me find: the priority value of new transitions was being rounded. I also just made it configurable. Essentially, the cause of this issue is that the sum tree is implemented in float32, and I am not sure there is any way around it other than simply using float64. That is what Dopamine's sum tree implementation does as well, so they probably ran into the same problem.

Beyond changing the sampling method so that it differs from the original PER paper, the only other solution is to replace zero-priority transitions in the batch with another random transition from the batch. Since this usually affects only one transition in a batch and doesn't happen often, replacing it with the most probable transition would almost certainly not affect the agent's performance. The reproduction script works with float64 enabled even if everything else is in float32; only the sum tree needs to be in float64.
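The replacement fallback can be sketched in NumPy as follows, using the most probable transition in the batch as the substitute. A random valid transition would work the same way. `repair_batch` and its signature are hypothetical, not a flashbax API. In JAX, the same row gather could be applied across a whole experience pytree with `jax.tree_util.tree_map`.

```python
import numpy as np


def repair_batch(probabilities, *batch_arrays):
    """Replace every zero-probability row with the batch's most probable row.

    Returns the probabilities and each batch array gathered with the same
    row indices (hypothetical post-processing step).
    """
    probabilities = np.asarray(probabilities)
    rows = np.arange(probabilities.shape[0])
    best = np.argmax(probabilities)
    take = np.where(probabilities > 0.0, rows, best)  # row to read for each slot
    return [np.asarray(a)[take] for a in (probabilities, *batch_arrays)]


probs = np.array([0.2, 0.0, 0.5, 0.3])
indices = np.array([10, 11, 12, 13])
obs = np.arange(8).reshape(4, 2)
new_probs, new_indices, new_obs = repair_batch(probs, indices, obs)
print(new_indices)  # [10 12 12 13]
```

The repaired batch duplicates one transition, which slightly biases that update. This is the trade-off described above as negligible when it happens about once per batch.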

EdanToledo, Jul 09 '25

Thanks for the quick update.

If the goal is to stay aligned with the original sampling approach (i.e., sampling a query value $x \sim \text{Unif}[0, R)$ and traversing the tree top-down), we can still improve numerical robustness without switching to a local sampling method or to float64.

One idea is to normalize the query value x at each node, like this:

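```python
def descend(value, x):
    """Top-down sum-tree descent that renormalises x at every node.

    `value` is an array-backed sum tree: root at index 0, children of
    node i at 2 * i + 1 and 2 * i + 2, leaves in the last half.
    On entry, x must lie in [0, value[0]).
    """
    node = 0
    while 2 * node + 1 < len(value):
        left_child, right_child = 2 * node + 1, 2 * node + 2
        # Ideally value[node] == value[left_child] + value[right_child],
        # which often does not hold in floating point. Rescaling by a factor
        # close to 1 maps x from [0, value[node]) into
        # [0, value[left_child] + value[right_child]).
        x = x / value[node] * (value[left_child] + value[right_child])
        # `>=` so that an empty left subtree (value 0) is never entered.
        if x >= value[left_child]:
            x = x - value[left_child]
            node = right_child  # go right
        else:
            node = left_child  # go left
    return node - len(value) // 2  # index of the selected leaf
```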

Clipping, `x = clip(x, 0, value[left_child] + value[right_child])`, should also work if we don't mind the tiny increase in probability for the last element.

This is just a quick idea I had. I'm not sure whether it introduces other issues, so I'd love to hear your thoughts.

HeavyCrab, Jul 09 '25

Let me think about this; when I have time, I'll play around with this idea. Input from @SimonDuToit, who is now the maintainer, would also be useful.

EdanToledo, Jul 10 '25