
Dali pipeline with jax framework

Open zhaoyang-0204 opened this issue 4 years ago • 2 comments

Hi,

I am currently using JAX for training and plan to use DALI as the data processing pipeline. I found that there are TensorFlow plugins (which I have used before) and so on, but there is no wrapped plugin for JAX. I would like to find out the proper way to pass data to JAX.

In my understanding, JAX generally utilizes the function pmap() to operate in parallel across devices. This function maps the given target function over array axes. So, for inputs, JAX uses an additional leading axis to indicate which device each slice belongs to (much like sharding). For example, batch inputs of shape (batch, H, W, C) would be reshaped to (num_devices, batch // num_devices, H, W, C). The inputs are usually of type np.ndarray.
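To make the layout concrete, here is a minimal NumPy sketch of the reshape described above; the sizes (4 devices, batch of 32, 224x224x3 images) are illustrative assumptions only:

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_devices = 4
batch, H, W, C = 32, 224, 224, 3

inputs = np.zeros((batch, H, W, C), dtype=np.float32)

# Split the leading batch axis so each device gets its own slice,
# which is the layout jax.pmap expects for its mapped axis.
sharded = inputs.reshape(num_devices, batch // num_devices, H, W, C)

print(sharded.shape)  # (4, 8, 224, 224, 3)
```

jax.pmap would then map the per-device function over axis 0, so each of the 4 devices sees a (8, 224, 224, 3) sub-batch.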

My understanding of DALI is that it uses a sharding methodology to support parallel operation: different devices each process a different partition of the inputs, selected by explicitly specifying the shard_id argument. I think there may be some practical gaps between how DALI and JAX implement parallelism. Currently, I follow the conventional way of building a DALI pipeline and then convert the output to a NumPy array:

```python
pipe = Pipeline(...)
pipe.build()
output = pipe.run()
output = output.as_array()
output = reshape(output)
```

I built it this way and found it is very costly in terms of GPU memory; only a very low batch size is possible, compared to using sharding in TensorFlow.

So, I wonder: is this the proper way of using DALI to pass data to JAX? Alternatively, could we reshape the batch axis inside the pipeline? And is there some way to pass the data processed by DALI on the GPU directly to JAX's mapping functions, instead of this CPU -> GPU route?
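For context on the sharding side of this question, the partitioning that shard_id selects can be sketched in pure Python; the helper name shard_slice is hypothetical and only mirrors the general idea of a reader splitting a dataset into contiguous, roughly equal shards:

```python
def shard_slice(dataset_size, shard_id, num_shards):
    # Compute the [begin, end) sample range a given shard reads.
    # Integer arithmetic keeps shard sizes within one sample of each other.
    begin = dataset_size * shard_id // num_shards
    end = dataset_size * (shard_id + 1) // num_shards
    return begin, end

# Example: 10 samples split over 3 shards.
print([shard_slice(10, i, 3) for i in range(3)])  # [(0, 3), (3, 6), (6, 10)]
```

Each device would run a pipeline with its own shard_id, so the per-device pipelines together cover the whole dataset without overlap.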

Thanks for the excellent work on DALI; I am very interested in gaining a deeper understanding of it. Thanks for your time and responses.

zhaoyang-0204 avatar Dec 23 '21 04:12 zhaoyang-0204

Hi @zhaoyang-0204,

JAX is still advertised as an experimental project and we haven't investigated how to integrate DALI with it, so there is no recommended way of using the two together. I think it may require building a native plugin/Python module similar to the TensorFlow case, where DALI would run a separate pipeline for each GPU and pass the relevant memory to each XLA instance. At the current stage of JAX development, we don't plan to support it officially, and we welcome any external contribution that would enable it.

JanuszL avatar Dec 24 '21 00:12 JanuszL

Hi @JanuszL,

Thanks for your kind reply. I will study this further; thanks for your help.

zhaoyang-0204 avatar Dec 24 '21 06:12 zhaoyang-0204

Hello @zhaoyang-0204

I wanted to let you know that we have developed official support for DALI and JAX integration. To learn more, take a look at the Getting started with JAX and DALI tutorial. We support both pmap and sharding to scale the workflow to multiple GPUs; you can find examples of how to do it in that tutorial.
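For readers landing here later, a rough sketch of what the integration looks like, assuming a recent DALI release with the JAX plugin installed; the dataset path, pipeline shape, and batch size are placeholders, and the exact iterator signature should be checked against the tutorial linked above:

```python
# Sketch only: assumes nvidia-dali with the JAX plugin is installed
# and a GPU is available; "data/train" is a placeholder path.
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import DALIGenericIterator

@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def image_pipeline():
    jpegs, labels = fn.readers.file(file_root="data/train", name="Reader")
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    images = fn.resize(images, size=(224, 224))
    return images, labels

# The iterator yields batches as JAX arrays, keeping data on the GPU
# instead of the CPU round trip discussed earlier in this thread.
iterator = DALIGenericIterator([image_pipeline()], ["images", "labels"],
                               reader_name="Reader")
for batch in iterator:
    images = batch["images"]  # jax.Array, ready for pmap/sharding
```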

awolant avatar Jan 24 '24 21:01 awolant

Thanks very much for notifying me. Really cool, @awolant @JanuszL. I will give it a try.

zhaoyang-0204 avatar Jan 25 '24 08:01 zhaoyang-0204