How to handle broken samples?
Hi, thank you for your project, it is really helpful. I wanted to ask how I should handle broken samples. I work with videos, and sometimes they fail to decode due to random internal ffmpeg failures (and there is really no alternative to ffmpeg: all current Python video libraries rely on it at least partially). My original attempt was to implement the handling in the dataloader via something like this (some code omitted for brevity):
```python
from typing import Any, Dict, Iterable, Iterator, List

from streaming import StreamingDataLoader


class FilteredDataLoader(StreamingDataLoader):
    def __iter__(self) -> Iterator[Any]:
        cur_incomplete_batch = None
        for batch in super().__iter__():
            batch_size = self._get_batch_size(batch)  # helper omitted for brevity
            batch_filtered = self._filter_batch(batch)
            non_broken_batch_size = self._get_batch_size(batch_filtered)
            if cur_incomplete_batch is None and non_broken_batch_size == batch_size:
                yield batch_filtered  # Best situation: no broken samples in this batch.
            else:
                cur_incomplete_batch = (
                    batch_filtered if cur_incomplete_batch is None
                    else self._merge_batches(cur_incomplete_batch, batch_filtered, batch_size))
                if self._get_batch_size(cur_incomplete_batch) == batch_size:
                    yield cur_incomplete_batch  # We have a complete batch now.
                    cur_incomplete_batch = None

    def _filter_batch(self, batch: Dict[str, Iterable]) -> Dict[str, List]:
        # Keep only the entries whose corresponding `is_broken` flag is False.
        return {k: self.collate_fn([v for v, b in zip(vs, batch['is_broken']) if not b])
                for k, vs in batch.items()}

    def _merge_batches(self, lhs: Dict[str, Iterable], rhs: Dict[str, Iterable],
                       max_batch_size: int) -> Dict[str, Iterable]:
        # Concatenate two partial batches and truncate to the target batch size.
        return {k: self.collate_fn((list(lhs[k]) + list(rhs[k]))[:max_batch_size])
                for k in lhs.keys()}
```
My dataset was returning an additional field `is_broken` when the sample was broken, but this approach led to a problem: different processes can now have different epoch sizes, and I had issues combining such logic with StreamingDataset (I had freezes in some parts of my code). So now I have the following logic:
```python
from typing import Any, Dict, Optional

import numpy as np
from streaming import StreamingDataset


class SafeStreamingDataset(StreamingDataset):
    # Note: self.num_retries is assumed to be set in the constructor (omitted here).

    def _unsafe_get_item(self, idx: int) -> Dict[str, Any]:
        return super().get_item(idx)

    def get_item(self, idx: int, num_retries_left: Optional[int] = None) -> Dict[str, Any]:
        try:
            return self._unsafe_get_item(idx)
        except Exception as e:
            print(f"Exception in __getitem__({idx}): {e}")
            num_retries_left = self.num_retries if num_retries_left is None else num_retries_left
            if num_retries_left >= 0:
                # Replace the broken sample with a randomly chosen one and retry.
                return self.get_item(idx=np.random.randint(low=0, high=len(self)),
                                     num_retries_left=num_retries_left - 1)
            else:
                print(f"Failed to load the video even after {self.num_retries} retries. Something is broken.")
                raise e
```
In general, I want my data loading to be robust to failures, so that a single broken sample does not break the entire training process (which is expensive to rewind). The above solution is not great because random access fetches extra shards which are then discarded.

What is the recommended strategy for handling broken samples with StreamingDataset? Do you think it would work fine if I implemented such filtering logic on the side of the IterableDataset (i.e., implementing an `__iter__` method for it which throws away bad samples, as in the sketch below)?
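For reference, what I have in mind is something like this minimal sketch (it assumes the dataset marks bad samples with the `is_broken` field, as above):

```python
from streaming import StreamingDataset


class FilteringStreamingDataset(StreamingDataset):
    """Sketch: drop broken samples while iterating.

    Note that the number of yielded samples then differs across ranks/workers.
    """

    def __iter__(self):
        for sample in super().__iter__():
            # `is_broken` is set by the dataset's get_item() when decoding fails.
            if not sample.get('is_broken', False):
                yield sample
```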
Hey, thanks for raising this issue. Your current solution which filters in the dataloader will work, but as you say, it may be slow and can pose issues with batch construction and the expected number of samples. Unfortunately, since it's impossible to know which samples are broken beforehand, and because StreamingDataset partitions and shuffles samples ahead of time for elastic determinism and resumption, it's not possible to filter on the StreamingDataset side. Would it be possible to convert the videos into a usable format beforehand, and then create a new dataset of MDS shards with that?
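Roughly, such a preprocessing pass could look like the sketch below. It's only an illustration: `iter_raw_samples`, the column schema, and the output path are placeholders, and the check simply asks ffmpeg to decode the whole file and treats any reported error as "broken".

```python
import subprocess

from streaming import MDSWriter


def decodes_cleanly(path: str) -> bool:
    # Decode the whole video, printing nothing but errors; any error output or a
    # non-zero exit code means the file is (at least partially) broken.
    proc = subprocess.run(['ffmpeg', '-v', 'error', '-i', path, '-f', 'null', '-'],
                          capture_output=True)
    return proc.returncode == 0 and not proc.stderr


columns = {'video': 'bytes', 'caption': 'str'}  # adjust to your schema

with MDSWriter(out='mds-clean/', columns=columns, compression='zstd') as writer:
    for path, caption in iter_raw_samples():  # placeholder over your raw data
        if decodes_cleanly(path):
            with open(path, 'rb') as f:
                writer.write({'video': f.read(), 'caption': caption})
```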
Hi @snarayan21, thank you for your response.
Unfortunately, it's not possible to have 100% stable videos, since some decoding errors can happen in external libraries (e.g., ffmpeg) and we want to have some error handling so that the entire training process does not fail due to a random loading error.
Also, we want to filter our existing dataset based on some "business logic". E.g., imagine that we want to train several models on different subsets of LAION-5B, selected by an aesthetic score threshold, to see which one has the best quality. Do I understand correctly that it's impossible to implement something like this with StreamingDataset and we should switch to something else? Can such logic be easily implemented if we drop elastic determinism and resumption?
Hi @snarayan21. After thinking a bit more about the scenario and inspecting your source code, I think that:

- Filtering of the broken samples should work fine with something similar to the `FilteredDataLoader` I mentioned above, since all the dataset partitions seem to be traversed independently, so it shouldn't make the dataset of one rank increment its epoch earlier than its peers (I found an issue in my own metrics computation code when the epochs in different processes had different sizes, but for training with an infinite dataloader this shouldn't be an issue). This will, however, break the assumption that the dataloaders in different ranks yield the same number of samples and will make resumption only approximate, but `num_samples_yielded` shouldn't deviate too much between processes unless shuffling is broken (note that I am talking about the unfiltered `num_samples_yielded` everywhere here). I also tried implementing the filtering logic by tweaking the `__iter__` method of StreamingDataset: it works, but I guess it will lead to worse discrepancies in resumption, since we would only be able to keep `num_samples_yielded` either per worker or per filtered dataloader (depending on how we count).
- The scenario with various aesthetic thresholds for LAION could be easily supported using bucketing and multiple index files (i.e., having `index_aesthetic_thresh_geq_0.json`, `index_aesthetic_thresh_geq_1.json`, ..., `index_aesthetic_thresh_geq_X.json` instead of just a single `index.json`), but I would need to fork your repo to support this logic, since the `index.json` name is currently hardcoded in StreamingDataset. Also, it seems to be impossible to do bucketed sharding with the current MDSWriter in an "online" fashion, where we have N queues for N buckets that write samples appropriately once they are ready.
What do you think?
Hey @universome, this is really useful! Filtering with StreamingDataset is something the team has thought about, but yes, it is hard to support filtering in conjunction with elastic determinism and resumption, since the number of samples is not known in advance. To your points:

- Filtering on the dataloader side, especially with an infinite dataloader / no multi-epoch training, is possible, as you have shown above. Filtering on the dataset side is also possible, but I think it makes more sense to do it through the dataloader; then the resumption logic will still use the unfiltered number of samples on the dataset side. Resumption actually just keeps track of the number of samples seen so far and drops that many samples from the sample partition. So for an infinite dataloader, where you would expect to stop training before any one worker runs out of samples, you could do filtering as described. A better approach may be to use our `dataframe_to_mds` converter: do the filtering on the dataframe, convert to MDS, and run training on that.
- In this scenario, we do support creating different index files for each bucket of your dataset -- in fact, this is the approach we recommend to customers. While each bucket must have its own `index.json` file, since we don't relax the naming requirement, there are two approaches you can take here:
  - If you know which shards belong to each bucket in advance, you can create a new directory for each bucket that contains only the new `index.json` file for that bucket. The shards can still all live in one directory; you just have to change the file paths in that bucket's `index.json` file.
  - If you want to do bucketed sharding, you can create an MDSWriter for each bucket and send each sample to the appropriate writer(s) based on your filtering criteria (a rough sketch follows this list). Each MDSWriter should write to its own directory and will create an `index.json` file for that bucket. If you want to multiprocess this, split your dataset K ways and have each process handle 1/K of the dataset. Then you can merge the index files for each bucket by merging shard groups, using the functionality here.
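Here's a rough sketch of the bucketed-writer approach. The bucketing rule, column schema, paths, and `iter_raw_samples` are only placeholders; adapt them to your data:

```python
from streaming import MDSWriter

columns = {'image': 'bytes', 'caption': 'str'}  # adjust to your schema
thresholds = [0, 1, 2]  # one bucket per "aesthetic score >= X" rule

# One writer per bucket; each gets its own directory and its own index.json.
writers = {t: MDSWriter(out=f'mds/aesthetic_geq_{t}/', columns=columns, compression='zstd')
           for t in thresholds}

for sample, score in iter_raw_samples():  # placeholder: yields (sample dict, aesthetic score)
    for t in thresholds:
        if score >= t:
            writers[t].write(sample)

for writer in writers.values():
    writer.finish()

# If you split this job K ways, merge the per-process index files of each bucket
# afterwards using the merging functionality linked above.
```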
Down the line, we are working to integrate better with the Databricks platform so that we can enable more advanced features like directly streaming and filtering. Let me know if this helps, and happy to further continue this discussion!
Hi @snarayan21, sorry for the late reply and thank you for the advice. For filtering, in my current implementation I pre-filtered the dataset and added some dirty resampling logic (via random access) to replace broken samples on the fly, and it seems to be working fine for the current ratio of broken samples. I will experiment with a more rigorous filtered dataloader in the near future.
Your solution with bucketing through multiple directories is reasonable; I will proceed with that.
I will close the issue for now, since I do not have any more questions about filtering. Thanks again for your help.
Ok, it turns out the discussed approach does not work. I've described the problem in that ticket, but I'm also copying it here for completeness:

The problem is that one worker finishes its epoch earlier, tries to rewind the epoch, and gets stuck on the shared barrier, while the other workers are still training on the first epoch and get stuck on the `torch.distributed.barrier` which is always present in DDP training (it normally syncs on the backward pass). This leads to a deadlock until the entire run is killed on the DDP timeout.

I have made a fork which simply removes all the shared barriers in `_get_work` and in epoch resumption. Could you please take a look at it and say whether there are any terrible side effects of my changes (I didn't have time to analyze the entire codebase)? Can it lead to samples being duplicated among different workers? (As far as I understand, it shouldn't, since we still select `worker_sample_ids` from the same `epoch_sample_ids`.) If it's just each worker computing `epoch_sample_ids` on its own (and `epoch_sample_ids` is still equivalent across all the workers), then it does not seem like too big of a deal, to be honest.

In my use case, I have some filtering happening in the dataloader (filtering out short videos) and often end up with fewer iterations in some workers than in others. When the number of iterations differs among the workers, this leads to a deadlock for the reason described above.

P.S. I also had to change the shuffling strategy so that `next_epoch` is not taken from shared memory but is instead unique for each worker. The rationale is that in `generate_work`, some workers might take the incremented `next_epoch` from the shared memory. Could you please explain the motivation for keeping `next_epoch` in shared memory? When can it diverge?
And with the above approach of removing the shared-memory syncs, my startup time became too slow for large jobs (16+ nodes).