InternVL
[Bug] Weird implementation of sampler used in training
Checklist
- [X] 1. I have searched related issues but cannot get the expected help.
- [X] 2. The bug has not been fixed in the latest version.
- [X] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
As I understand it, the class LengthGroupedSampler (link) is meant to keep the sample lengths within a micro batch close to each other, so that less computation is wasted on padding tokens. However, it seems that this sampler produces more padding tokens than a random sampler.
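To make the notion of "padding waste" concrete, here is a minimal sketch of the metric, assuming every micro batch is padded to its longest sample. The helper name `average_padding` and the toy lengths are illustrative only and are not part of the repo.

```python
def average_padding(ordered_lengths, batch_size):
    """Average padding tokens per sample when consecutive samples form a
    micro batch and each batch is padded to its longest sample."""
    total_pad = 0
    for start in range(0, len(ordered_lengths), batch_size):
        batch = ordered_lengths[start:start + batch_size]
        total_pad += max(batch) * len(batch) - sum(batch)
    return total_pad / len(ordered_lengths)


# Toy lengths, chosen only for illustration.
grouped = [10, 12, 20, 22]      # similar lengths share a micro batch
interleaved = [10, 20, 12, 22]  # dissimilar lengths share a micro batch

print(average_padding(grouped, batch_size=2))      # 1.0
print(average_padding(interleaved, batch_size=2))  # 5.0
```

An ordering that groups by length should land closer to the first case than to the second.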
Reproduction
import torch
# Assumes the script is run from internvl_chat/ so that the sampler patch is importable.
from internvl.patch.train_sampler_patch import get_length_grouped_indices

max_length = 1920
min_length = 1680
lengths = list(range(min_length, max_length + 1)) * 1024
batch_size = 2
generator = torch.Generator().manual_seed(42)
for world_size in [4, 8, 16, 32, 64, 128]:
    for indices in [
        get_length_grouped_indices(lengths, batch_size, world_size, generator=generator),  # https://github.com/OpenGVLab/InternVL/blob/main/internvl_chat/internvl/patch/train_sampler_patch.py#L36
        torch.randperm(len(lengths), generator=generator),
    ]:
        # Reorder the lengths as the sampler would; this only permutes the same set of lengths.
        lengths = [lengths[i] for i in indices]
        batched_lengths = [lengths[i : i + batch_size] for i in range(0, len(lengths), batch_size)]
        total_pad = 0
        for batch in batched_lengths:
            # Every sample in a micro batch is padded to the longest sample in that batch.
            batch_max = max(batch)
            total_pad += batch_max * len(batch) - sum(batch)
        # get_length_grouped_indices returns a list, torch.randperm returns a tensor.
        sampler_name = "LengthGroupedSampler" if isinstance(indices, list) else "RandomSampler"
        print(f"{sampler_name} Avg padding: {total_pad / len(lengths):.2f} for world size {world_size}")
Environment
The reproduction does not rely on any third-party environment beyond PyTorch.
I use commit `80776deaecbe4` of this repo.
Error traceback
The result shows LengthGroupedSampler adds more padding tokens than RandomSampler.
LengthGroupedSampler Avg padding: 67.03 for world size 4
RandomSampler Avg padding: 49.64 for world size 4
LengthGroupedSampler Avg padding: 71.27 for world size 8
RandomSampler Avg padding: 49.43 for world size 8
LengthGroupedSampler Avg padding: 73.32 for world size 16
RandomSampler Avg padding: 49.51 for world size 16
LengthGroupedSampler Avg padding: 74.44 for world size 32
RandomSampler Avg padding: 49.35 for world size 32
LengthGroupedSampler Avg padding: 75.13 for world size 64
RandomSampler Avg padding: 49.27 for world size 64
LengthGroupedSampler Avg padding: 75.39 for world size 128
RandomSampler Avg padding: 49.36 for world size 128
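For comparison, here is a rough sketch of what I would expect a length-grouped ordering to achieve. It is a hypothetical baseline, not the repo's implementation: shuffle the indices, cut them into mega batches of `batch_size * world_size`, and sort each mega batch by length, so that consecutive samples have similar lengths. The function names, the seeds, and the reduced dataset size are assumptions made for illustration, and padding is measured the same way as in the reproduction above (consecutive samples form a micro batch).

```python
import random


def sorted_megabatch_order(lengths, batch_size, world_size, seed=0):
    """Hypothetical baseline: shuffle, split into mega batches, then sort
    each mega batch by length (longest first)."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    mega = batch_size * world_size
    ordered = []
    for start in range(0, len(indices), mega):
        chunk = indices[start:start + mega]
        ordered.extend(sorted(chunk, key=lambda i: lengths[i], reverse=True))
    return ordered


def average_padding(lengths, order, batch_size):
    """Average padding tokens per sample when consecutive indices in `order`
    form a micro batch padded to its longest sample."""
    total_pad = 0
    for start in range(0, len(order), batch_size):
        batch = [lengths[i] for i in order[start:start + batch_size]]
        total_pad += max(batch) * len(batch) - sum(batch)
    return total_pad / len(order)


# Same length range as the reproduction, with fewer repeats to keep it fast.
lengths = list(range(1680, 1921)) * 8
random_order = list(range(len(lengths)))
random.Random(1).shuffle(random_order)
baseline_order = sorted_megabatch_order(lengths, batch_size=2, world_size=4)

pad_random = average_padding(lengths, random_order, batch_size=2)
pad_baseline = average_padding(lengths, baseline_order, batch_size=2)
print(f"random: {pad_random:.2f}, sorted mega batch: {pad_baseline:.2f}")
```

With this baseline the average padding falls below that of the random ordering, which is the direction I expected from LengthGroupedSampler.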