Data corruption with JAX plugin
Version
nvidia-dali-cuda120==1.40.0, jax==0.4.31
Describe the bug.
We recently discovered that when we used DALI, our training curves were mysteriously worse. We were able to fix this by adding a jax.block_until_ready() call after each train step, which defeats JAX's asynchronous dispatch. I therefore hypothesized that there was some sort of data corruption going on, caused by a lack of synchronization between JAX and DALI.
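For context, the workaround looked roughly like the sketch below (train_step and the batch here are simplified stand-ins, not our actual training code):
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # stand-in for a real parameter update
    return params + batch.mean()

params = jnp.zeros(())
for step in range(10):
    batch = jnp.ones((8, 128))  # stand-in for a batch coming out of DALI
    params = train_step(params, batch)
    # the workaround: force the step to finish before the next batch is fetched,
    # so no pending reads of the inputs can overlap with DALI reusing its buffers
    jax.block_until_ready(params)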
If I understand correctly, when you request the next element, the JAX plugin roughly does the following steps:
def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element
I believe the problem is that jax.dlpack.from_dlpack and jnp.copy are both asynchronous, so pipe.release_outputs() is called before the copy has actually completed. When the next element is requested, DALI can overwrite the output buffers before JAX is done reading from them.
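To illustrate the asynchrony (this is plain JAX behavior, not code from the plugin): a device computation returns immediately after being enqueued, and only jax.block_until_ready (or a host-side use of the value) waits for it to finish, so host calls placed right after it can easily run first.
import time
import jax
import jax.numpy as jnp

x = jnp.ones((8192, 8192))
jax.block_until_ready(x)   # make sure setup work is finished before timing

t0 = time.perf_counter()
y = x @ x                  # enqueued asynchronously; returns almost immediately
t1 = time.perf_counter()
jax.block_until_ready(y)   # actually waits for the matmul to complete on device
t2 = time.perf_counter()

print(f"dispatch took {t1 - t0:.4f}s, waiting for completion took {t2 - t1:.4f}s")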
I was able to fix the problem by bypassing the DALI JAX plugin, and adding the following line to my code:
def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # wait until the copy is done (THIS IS THE MISSING STEP)
    element = jax.block_until_ready(element)
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element
Of course, I suspect this forces JAX to flush the entire GPU pipeline, which is somewhat inefficient. However, I don't think there's any way around this without deeper integration between JAX and DALI (specifically, I think DALI needs to provide a DLPack deleter, although I'm no expert).
Minimum reproducible example
No response
Relevant log output
No response
Other/Misc.
No response
Check for duplicates
- [x] I have searched the open bugs/issues and have found no duplicates for this bug report
Hello @kvablack
thank you for reporting this issue. Do you have a standalone script that reproduces it reliably? It would help a lot, as I have had no luck reproducing it so far.
As far as I remember, DALI should pass a DLPack deleter to JAX that is called when the capsule is no longer needed. There might be some issue with it, as you pointed out.
Here you go; this reproduces on a 4090. You do need a couple of things to trigger the issue, namely large enough arrays and a fake "train step".
import jax
import jax.numpy as jnp
import numpy as np
from nvidia import dali
import tqdm

# this needs to be large to trigger the bug.
ARR_SIZE = 2**16


class ExternalSource:
    def __call__(self, sample_info: dali.types.SampleInfo):
        return [
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
        ]


def get_pipe():
    @dali.pipeline_def(
        batch_size=2,
        num_threads=1,
        prefetch_queue_depth=6,
        py_start_method="spawn",
    )
    def pipeline():
        outputs = dali.fn.external_source(
            source=ExternalSource(),
            num_outputs=2,
            batch=False,
            parallel=True,
        )
        outputs = [arr.gpu() for arr in outputs]
        return tuple(outputs)

    pipe = pipeline(device_id=0)
    pipe.build()
    pipe.schedule_run()
    return pipe


def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in outputs]
    # UNCOMMENT THIS LINE TO MAKE THE ASSERTIONS PASS.
    # jax.block_until_ready(element)
    pipe.release_outputs()
    pipe.schedule_run()
    return element


@jax.jit
def f(x):
    # matmul to simulate the train step. interestingly, this is required to reproduce the bug.
    return [y @ y.T for y in x]


if __name__ == "__main__":
    pipe = get_pipe()

    batches = []
    for _ in tqdm.trange(10):
        batches.append(f(get_next(pipe)))

    for i in tqdm.trange(len(batches)):
        for elem in batches[i]:
            # `x` is what the ExternalSource should have returned
            x = jnp.broadcast_to(jnp.arange(elem.shape[0])[:, None] + i * elem.shape[0], (elem.shape[0], ARR_SIZE))
            # `y` is what the "train step" should have returned
            y = x @ x.T
            # check that it matches the batch that came out of the pipeline
            assert jnp.all(y == elem), f"Failed on batch {i}\n\nExpected:\n{y}\n\nGot:\n{elem}"
+1
Hello @kvablack,
The problem has been resolved in the new (opt-in) "dynamic" executor, available in recent DALI versions.
Please update DALI to 1.44 (or newer, when available) and specify exec_dynamic=True in your pipeline.
Also, the function _expose_dlpack_capsule was a private, temporary solution and has been removed. DALI tensors now fully implement the DLPack protocol, so you can pass a DALI tensor directly to from_dlpack; JAX will use __dlpack__ under the hood.
@dali.pipeline_def(
    batch_size=2,
    num_threads=1,
    exec_dynamic=True,  # <-------------------- add this
    prefetch_queue_depth=6,
    py_start_method="spawn",
)
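With that, the get_next function from your repro can be simplified roughly like this (an untested sketch that assumes DALI 1.44 or newer with exec_dynamic=True, relying on the __dlpack__ support described above):
def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    # DALI tensors now expose __dlpack__/__dlpack_device__, so they can be passed
    # to jax.dlpack.from_dlpack directly; copy=True keeps the data owned by JAX.
    element = [jax.dlpack.from_dlpack(x.as_tensor(), copy=True) for x in outputs]
    pipe.release_outputs()
    pipe.schedule_run()
    return element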