
How to handle broken samples?

Open universome opened this issue 9 months ago • 6 comments

Hi, thank you for your project, it is really helpful. I wanted to ask: how should I handle broken samples? I work with videos, and sometimes they fail to decode due to random internal ffmpeg failures (and there is really no alternative to ffmpeg; all current Python video libraries rely on it at least partially). My original attempt was to implement the handling in the dataloader via something like this (some code omitted for brevity):

from typing import Any, Dict, Iterable, Iterator, List

from streaming import StreamingDataLoader


class FilteredDataLoader(StreamingDataLoader):
    def __iter__(self) -> Iterator[Any]:
        # Accumulates the non-broken leftovers until they add up to a full batch.
        cur_incomplete_batch = None

        for batch in super().__iter__():
            batch_size = self._get_batch_size(batch)
            batch_filtered = self._filter_batch(batch)
            non_broken_batch_size = self._get_batch_size(batch_filtered)

            if cur_incomplete_batch is None and non_broken_batch_size == batch_size:
                yield batch_filtered  # Best situation: no broken samples in this batch.
            else:
                cur_incomplete_batch = (
                    batch_filtered
                    if cur_incomplete_batch is None
                    else self._merge_batches(cur_incomplete_batch, batch_filtered, batch_size)
                )
                if self._get_batch_size(cur_incomplete_batch) == batch_size:
                    yield cur_incomplete_batch  # We have a complete batch now.
                    cur_incomplete_batch = None

    def _filter_batch(self, batch: Dict[str, Iterable]) -> Dict[str, List]:
        # Keep only the samples whose is_broken flag is not set.
        return {
            k: self.collate_fn([v for v, b in zip(vs, batch['is_broken']) if not b])
            for k, vs in batch.items()
        }

    def _merge_batches(self, lhs: Dict[str, Iterable], rhs: Dict[str, Iterable], max_batch_size: int) -> Dict[str, Iterable]:
        # Concatenate two partial batches and truncate to the target batch size.
        return {
            k: self.collate_fn((list(lhs[k]) + list(rhs[k]))[:max_batch_size])
            for k in lhs.keys()
        }
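
The dataloader above relies on the dataset flagging failed decodes instead of raising. For concreteness, the dataset side could look roughly like the following sketch; decode_video, the field names, and the zero-tensor placeholder are illustrative stand-ins rather than my actual code:

import torch
from streaming import StreamingDataset


def decode_video(raw_bytes: bytes) -> torch.Tensor:
    # Placeholder for the ffmpeg-backed decode step (e.g. PyAV / torchvision / decord).
    ...


class FlaggingStreamingDataset(StreamingDataset):
    def get_item(self, idx: int):
        sample = super().get_item(idx)
        try:
            return {'video': decode_video(sample['video_bytes']), 'is_broken': False}
        except Exception:
            # Return a placeholder with the expected shape so default collation still works;
            # the dataloader then drops these samples via the is_broken flag.
            return {'video': torch.zeros(16, 3, 224, 224), 'is_broken': True}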

My dataset was returning an additional field, is_broken, whenever the sample was broken. However, this approach led to a problem: different processes can now end up with different epoch sizes, and I had issues combining such logic with StreamingDataset (I got freezes in some parts of my code). So now I have the following logic:

from typing import Any, Dict, Optional

import numpy as np
from streaming import StreamingDataset


class SafeStreamingDataset(StreamingDataset):
    def _unsafe_get_item(self, idx: int) -> Dict[str, Any]:
        return super().get_item(idx)

    def get_item(self, idx: int, num_retries_left: Optional[int] = None):
        try:
            return self._unsafe_get_item(idx)
        except Exception as e:
            print(f"Exception in get_item({idx}): {e}")
            num_retries_left = self.num_retries if num_retries_left is None else num_retries_left
            if num_retries_left > 0:
                # Retry with a random index instead of the broken one.
                return self.get_item(idx=np.random.randint(low=0, high=len(self)), num_retries_left=num_retries_left - 1)
            else:
                print(f"Failed to load the video even after {self.num_retries} retries. Something is broken.")
                raise e
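
The self.num_retries attribute is initialized elsewhere (omitted above); for illustration, something along these lines, where the default of 3 is just an example:

class SafeStreamingDataset(StreamingDataset):
    def __init__(self, *args, num_retries: int = 3, **kwargs):
        # Forward everything else to StreamingDataset unchanged.
        super().__init__(*args, **kwargs)
        self.num_retries = num_retries  # How many random re-draws before giving up.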

In general, I want my data loading to be robust to failures, so that a single broken sample does not break the entire training process, which is expensive to rewind. The above solution is not great because random access fetches extra shards which are then discarded.

What is the recommended strategy for handling broken samples with StreamingDataset? Do you think it would work fine if I implemented such filtering logic on the side of the IterableDataset, i.e., an __iter__ method which throws away bad samples, roughly as sketched below?
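
For concreteness, a minimal sketch of that __iter__-level filtering, assuming the samples carry an is_broken flag as above (the class name is just a placeholder):

from typing import Any, Dict, Iterator

from streaming import StreamingDataset


class IterFilteringStreamingDataset(StreamingDataset):
    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # Drop broken samples before they reach the collate function. Note that
        # different ranks may then yield different numbers of samples per epoch.
        for sample in super().__iter__():
            if not sample.get('is_broken', False):
                yield sample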

universome · Sep 25 '23 11:09