Multiple sharding policies in plugin.jax.data_iterator
Describe the question.
Is there a way to set the JAX sharding for each output of plugin.jax.data_iterator separately? For example, I have a pipeline that has two outputs. I want the first output to be PartitionSpec('batch', 'model') and the second to be PartitionSpec('batch', None) or PartitionSpec('batch').
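For reference, the requested per-output layout can be expressed in plain JAX (no DALI involved) by giving each array its own `NamedSharding`. This sketch makes some assumptions. It uses a 2-D `("batch", "model")` mesh built from whatever devices are available, and two dummy arrays stand in for the pipeline outputs. Calling `jax.device_put` on each iterator output afterwards might work as a stopgap, at the cost of an extra data movement, but that is untested with DALI.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 2-D mesh from the available devices. With a single device this
# degenerates to a (1, 1) mesh, which is still enough to show the API.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("batch", "model"))

# One sharding per output: the first is split over both mesh axes,
# the second only over "batch" (and replicated along "model").
first_sharding = NamedSharding(mesh, PartitionSpec("batch", "model"))
second_sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Dummy stand-ins for two pipeline outputs; the leading dimension is the batch.
n = mesh.shape["batch"]
first_out = jnp.ones((8 * n, 4))
second_out = jnp.arange(8 * n)

# jax.device_put places (or reshards) each array with its own layout.
first_out = jax.device_put(first_out, first_sharding)
second_out = jax.device_put(second_out, second_sharding)

print(first_out.sharding)
print(second_out.sharding)
```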
Check for duplicates
- [X] I have searched the open bugs/issues and have found no duplicates for this bug report
Hello @sali1997s
Thanks for the question. Unfortunately, this is not currently supported. This enhancement is on our TODO list for the JAX integration.
Could you tell us more about your use case and how you would need this to work? In particular, how would you map the outputs onto devices? Do you need both the CPU and the GPU? I am asking because, with DALI pipelines running on a particular GPU, there are some design and performance considerations for this feature, and we would like input from users to influence these decisions. Thanks!
Thank you for answering, @awolant! Sorry, after thinking about my task more deeply, I came to the conclusion that partitioning the data over the batch dimension fully covers my needs. I thought I needed more control over partitioning, but I don't currently.
However, I've found that the data loader works only for data-parallel training; it currently doesn't support model parallelism.
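The difference between the two meshes can be seen from the JAX side alone. The sketch below uses plain JAX on 8 emulated CPU devices, with no DALI involved. It counts how many distinct batch slices `PartitionSpec("batch")` requires. An (8, 1) mesh needs 8 slices, one per device. A (4, 2) mesh needs only 4, each replicated on two devices. One explanation would be that the iterator builds one pipeline shard per device, which would fit the first layout but not the second. That is an assumption, not something confirmed from DALI's source. `distinct_shards` is a hypothetical helper.

```python
import os

# Emulate 8 devices on the host CPU; this must happen before importing jax.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

GLOBAL_SHAPE = (64, 4)  # (global batch, features), as in the repro below


def distinct_shards(mesh_shape):
    """Return (number of devices, number of distinct slices) for a batch-sharded array."""
    devices = np.array(jax.devices()[:8]).reshape(mesh_shape)
    mesh = Mesh(devices, axis_names=("batch", "model"))
    sharding = NamedSharding(mesh, PartitionSpec("batch"))
    # devices_indices_map: device -> tuple of slices into the global array.
    index_map = sharding.devices_indices_map(GLOBAL_SHAPE)
    # slice objects are not hashable on older Pythons, so compare (start, stop).
    keys = {tuple((s.start, s.stop) for s in idx) for idx in index_map.values()}
    return len(index_map), len(keys)


print(distinct_shards((8, 1)))  # (8, 8): every device gets its own slice
print(distinct_shards((4, 2)))  # (8, 4): pairs of devices share a slice
```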
Here is a minimal reproducible example. If I change device_mesh to mesh_utils.create_device_mesh((4, 2)), it fails.
I also have a question about how the size parameter of @data_iterator interacts with external source. When the number of samples is divisible by the shard size, it works as expected. Otherwise, it fails with
WARNING:root:DALI iterator does not support resetting while epoch is not finished. Ignoring...
and does not proceed to the second epoch. Is there anything I can do about it?
```python
from nvidia.dali.plugin.jax import data_iterator
from jax.experimental import mesh_utils
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import numpy as np

GLOBAL_BATCH_SIZE = 64


class DataSorceCallable:
    def __init__(self, batch_size, seed, shard_id, num_shards):
        self.rng = np.random.default_rng(seed=seed)
        self.batch_size = batch_size
        # Use the seeded generator so every shard sees the same "files"
        # (np.random.rand would give each shard different data).
        self.files = self.rng.random((GLOBAL_BATCH_SIZE * 10, 4), dtype=np.float32)
        self.shard_id = shard_id
        self.num_shards = num_shards
        self.shard_size = len(self.files) // num_shards
        self.shard_offset = self.shard_size * shard_id
        # If the shard size is not divisible by the batch size, the last incomplete batch
        # will be omitted.
        self.full_iterations = self.shard_size // batch_size
        self.perm = None
        # Cache the epoch so `self.perm` is not recomputed for every sample.
        self.last_seen_epoch = None

    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration()
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            # Same permutation on every shard, reshuffled each epoch.
            self.perm = np.random.default_rng(seed=42 + sample_info.epoch_idx).permutation(
                len(self.files)
            )
        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]
        return self.files[sample_idx, :]


if __name__ == "__main__":
    # (8, 1) works; changing this to mesh_utils.create_device_mesh((4, 2)) fails.
    device_mesh = mesh_utils.create_device_mesh((8, 1))
    mesh = Mesh(device_mesh, axis_names=("batch", "model"))
    sharding = NamedSharding(mesh, PartitionSpec("batch",))

    @data_iterator(output_map=["out"], sharding=sharding, size=GLOBAL_BATCH_SIZE * 10, prepare_first_batch=False)
    def callable_pipeline(num_shards, shard_id):
        out, = fn.external_source(
            source=DataSorceCallable(GLOBAL_BATCH_SIZE // num_shards, num_shards=num_shards, shard_id=shard_id, seed=42),
            num_outputs=1,
            batch=False,
            # parallel=True,
            dtype=[types.FLOAT],
        )
        return out.gpu()

    dataloader = callable_pipeline(batch_size=GLOBAL_BATCH_SIZE)
    for el in dataloader:
        print(el["out"].sharding)
```
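The following is one possible explanation for the warning, not verified against DALI's source. `DataSorceCallable` drops the remainder when splitting the files into shards. It then drops the last incomplete batch of each shard. So when the sample count is not divisible, the source raises `StopIteration` before `size` samples have been returned, and the iterator may still consider the epoch unfinished. The hypothetical helper below mirrors that arithmetic so the two numbers can be compared. Passing its result as `size` is one thing to try, but it is untested.

```python
def samples_per_epoch(total_samples, num_shards, global_batch_size):
    """Mirror DataSorceCallable's arithmetic (hypothetical helper, not a DALI API).

    Returns how many samples the external source yields per epoch, summed over
    all shards, before it raises StopIteration.
    """
    per_shard_batch = global_batch_size // num_shards
    shard_size = total_samples // num_shards         # remainder across shards dropped
    full_iterations = shard_size // per_shard_batch  # last partial batch dropped
    return full_iterations * per_shard_batch * num_shards


# Divisible case from the repro: 640 samples, 8 shards, global batch 64.
print(samples_per_epoch(640, 8, 64))  # 640, equal to size
# Non-divisible case: 650 samples with the same setup.
print(samples_per_epoch(650, 8, 64))  # 640, fewer than size=650
```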
Thanks for the reproduction. This is definitely a feature that could be added to enhance DALI's JAX support. For the first version of this integration layer, we focused only on the most common and simple cases.
As for your question about external source: unfortunately, there is currently no way to handle this. As I said, for this first version we wanted it to work in the most common and stable cases.
In your use case, how would you expect this to work? I am asking to get feedback about possible improvements to the JAX integration. Would you like the missing samples to be filled or duplicated somehow?
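The options for the last, incomplete global batch can be sketched in plain NumPy. `pad_epoch_indices` and its policy names are hypothetical. They are loosely modeled on the idea behind DALI's `LastBatchPolicy` (fill by wrapping around or repeating the last sample, or drop), and this is not how DALI implements it.

```python
import numpy as np


def pad_epoch_indices(num_samples, global_batch_size, policy="duplicate"):
    """Return one epoch of sample indices, padded or trimmed to whole global batches.

    policy="duplicate":   wrap around and reuse samples from the start.
    policy="repeat_last": repeat the last sample until the batch is full.
    policy="drop":        discard the trailing partial batch.
    """
    indices = np.arange(num_samples)
    remainder = num_samples % global_batch_size
    if remainder == 0:
        return indices
    if policy == "drop":
        return indices[: num_samples - remainder]
    missing = global_batch_size - remainder
    if policy == "duplicate":
        pad = np.arange(missing) % num_samples
    elif policy == "repeat_last":
        pad = np.full(missing, num_samples - 1)
    else:
        raise ValueError(f"unknown policy: {policy}")
    return np.concatenate([indices, pad])


# 650 samples with a global batch of 64: the last batch has 10 samples, 54 missing.
print(len(pad_epoch_indices(650, 64, "duplicate")))  # 704
print(len(pad_epoch_indices(650, 64, "drop")))       # 640
```

Each padded epoch is a multiple of the global batch size, so every shard receives the same number of full batches.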