
Bug in trainer.py when accum_steps > 1

Open xsddff opened this issue 1 month ago • 1 comment

See the function get_chunk_from_data at line 836. Actually, data's type is Dict.
When the code runs len(data) // num_chunks, wouldn't it return the number of keys rather than the batch size? The real batch size is inside the value of each key.

xsddff — Nov 19 '25 05:11

Hi, here is the response from an LLM (I did not carefully check it, though).

Based on the code provided, the comment/bug report is incorrect.

The logic in get_chunk_from_data handles dictionaries correctly. The user who reported the bug likely misunderstood the order of operations in the if/elif block.

Here is the step-by-step breakdown of why the code is correct:

1. The Logic Flow
When chunk_batch_for_accum_steps passes a batch (which is a Dict) to get_chunk_from_data, the following happens:

Check 1: if isinstance(data, torch.Tensor) or is_sequence_of_primitives(data):

A Python dict is not a torch.Tensor.

A Python dict is not a Sequence (it is a Mapping).

Result: The code skips this block. It does not execute len(data) // num_chunks on the dictionary here.

Check 2: elif isinstance(data, Mapping):

A Python dict is a Mapping.

Result: The code enters this block.

2. The Recursion
Inside the Mapping block, the code executes this:

```python
return {
    key: get_chunk_from_data(value, chunk_id, num_chunks)
    for key, value in data.items()
}
```

It does not attempt to slice the dictionary using len(data). Instead, it iterates over the keys and calls get_chunk_from_data recursively on the values (which are the actual Tensors).
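To make the dispatch order concrete, here is a minimal, self-contained sketch of the logic described in steps 1 and 2. The body of is_sequence_of_primitives and the exact slicing arithmetic are assumptions based on this thread, not a copy of the vggt source.

```python
from collections.abc import Mapping, Sequence

import torch

def is_sequence_of_primitives(data):
    # Hypothetical helper: a non-string sequence of plain Python scalars.
    return (
        isinstance(data, Sequence)
        and not isinstance(data, (str, bytes))
        and all(isinstance(v, (int, float, bool, str)) for v in data)
    )

def get_chunk_from_data(data, chunk_id, num_chunks):
    if isinstance(data, torch.Tensor) or is_sequence_of_primitives(data):
        # len() on a Tensor is the size of dim 0, i.e. the batch size.
        chunk_size = len(data) // num_chunks
        return data[chunk_id * chunk_size : (chunk_id + 1) * chunk_size]
    elif isinstance(data, Mapping):
        # A dict fails the first check (it is a Mapping, not a Sequence),
        # so len(data) is never taken on the dict itself; it recurses here.
        return {k: get_chunk_from_data(v, chunk_id, num_chunks)
                for k, v in data.items()}
    return data

batch = {"images": torch.zeros(8, 3, 32, 32), "ids": [0, 1, 2, 3, 4, 5, 6, 7]}
chunk = get_chunk_from_data(batch, chunk_id=1, num_chunks=2)
print(chunk["images"].shape[0])  # 4 — half the batch, not half the key count
print(chunk["ids"])              # [4, 5, 6, 7]
```

Note that the dict keeps all of its keys in every chunk; only the tensor (and primitive-sequence) values are split.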

3. The Actual Slicing
The slicing happens in the recursive call, when data is finally a torch.Tensor:

The function is called with a Tensor (e.g., batch['images']).

Check 1: if isinstance(data, torch.Tensor)... evaluates to True.

Action: It calculates len(data) // num_chunks.

For a PyTorch Tensor, len() returns the size of the first dimension (the batch size).
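A quick check of that claim: len() on a Tensor equals size(0), so dividing by num_chunks splits along the batch dimension (the tensor name and shape here are illustrative, not from the vggt source).

```python
import torch

images = torch.zeros(8, 3, 224, 224)   # batch of 8 images
print(len(images))                     # 8 — the size of the first dimension

num_chunks, chunk_id = 2, 0
chunk_size = len(images) // num_chunks
chunk = images[chunk_id * chunk_size : (chunk_id + 1) * chunk_size]
print(chunk.shape)                     # torch.Size([4, 3, 224, 224])
```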

Conclusion
The code correctly drills down into the dictionary keys until it finds the tensors, and then slices the tensors based on the batch size. It does not attempt to slice the dictionary keys or use the number of keys as the batch size.

The code provided is safe and functional for accum_steps > 1.

jytime — Nov 28 '25 10:11