Wrong epoch count when turning on `context_parallel_size`
Please check that this issue hasn't been reported before.
- [x] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
The reported number of epochs should be the same regardless of whether `context_parallel_size` is configured.
Current behaviour
When `context_parallel_size: 8` is set, the final training log looks like this:
{'train_runtime': 79.4145, 'train_samples_per_second': 25.789, 'train_steps_per_second': 0.201, 'train_loss': 3.3714769408106804, 'memory/max_mem_active(gib)': 13.62, 'memory/max_mem_allocated(gib)': 13.62, 'memory/device_mem_reserved(gib)': 14.68, 'epoch': 5.42}
When it is not set, the final training log looks like this:
{'train_runtime': 49.8935, 'train_samples_per_second': 5.131, 'train_steps_per_second': 0.04, 'train_loss': 5.478998422622681, 'memory/max_mem_active(gib)': 46.82, 'memory/max_mem_allocated(gib)': 46.82, 'memory/device_mem_reserved(gib)': 47.9, 'epoch': 0.84}
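As a quick back-of-the-envelope check using only the numbers from the two logs above (step counts approximated as `train_runtime * train_steps_per_second`; the dictionaries and names below are just for illustration):

```python
# Recompute approximate optimizer step counts from the two final logs above.
with_cp = {"runtime": 79.4145, "steps_per_s": 0.201, "epoch": 5.42}
without_cp = {"runtime": 49.8935, "steps_per_s": 0.04, "epoch": 0.84}

for name, log in (("context_parallel_size: 8", with_cp),
                  ("no context parallelism", without_cp)):
    steps = log["runtime"] * log["steps_per_s"]
    print(f"{name}: ~{steps:.0f} optimizer steps, reported epoch {log['epoch']}")

# context_parallel_size: 8: ~16 optimizer steps, reported epoch 5.42
# no context parallelism: ~2 optimizer steps, reported epoch 0.84
```

Both runs use `num_epochs: 1`, so both should finish near `epoch: 1.0`; only the context-parallel run overshoots, ending past epoch 5.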
Steps to reproduce
axolotl train placeholder.yaml
Config yaml
base_model: /model/Llama-3.2-1B-Instruct
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
# processor_type: AutoProcessor
remove_unused_columns: false
shuffle_merged_datasets: true
# Training configuration
output_dir: /data/axolotl/outputs/Llama3.2-1B-test-1015-wo-cp
num_epochs: 1
# Sequence and packing settings
sequence_len: 32768
sample_packing: true
pad_to_sequence_len: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# context_parallel_size: 8  # Toggle this line to compare the two runs.
sample_packing_bin_size: 1000
sample_packing_group_size: 200000
# Batch size settings
gradient_accumulation_steps: 16
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch_fused
lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
learning_rate: 2e-4
warmup_ratio: 0.05
weight_decay: 0.01
# Precision and performance
bf16: true
tf32: true
# Logging and checkpointing
logging_steps: 1
save_strategy: steps
save_steps: 2000
save_total_limit: 3
# Data
dataset_prepared_path: /data/axolotl/prepared_path/Llama3.2-1B-test-1015
dataloader_num_workers: 4
dataset_processes: 128
deepspeed: /data/axolotl/deepspeed_configs/zero1.json
special_tokens:
  pad_token: <|finetune_right_pad_id|>
datasets:
  - path: /data/datasets/pt/fineweb_sample10BT/fineweb_sample10BT.parquet
    ds_type: parquet
    type: completion
    field: text
    split: "train[:10000]"
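For reference, here is a minimal sketch of the epoch accounting I would expect, assuming an 8-GPU run (matching the commented `context_parallel_size: 8`) and assuming context-parallel ranks consume the same samples, so that only the data-parallel group contributes to the samples seen per optimizer step. The function, the `world_size`, and the packed dataset length are illustrative assumptions, not axolotl or transformers internals:

```python
# Illustrative sketch (not axolotl internals): how the epoch fraction should
# be computed if context-parallel ranks share a batch, so only the
# data-parallel group size contributes to samples consumed per step.
def expected_epoch(optimizer_steps: int, packed_dataset_len: int,
                   world_size: int, context_parallel_size: int,
                   micro_batch_size: int = 1, grad_accum: int = 16) -> float:
    dp_size = world_size // context_parallel_size
    samples_per_step = micro_batch_size * grad_accum * dp_size
    return optimizer_steps * samples_per_step / packed_dataset_len

# With a hypothetical packed dataset of 256 sequences on 8 GPUs, both setups
# should land at ~1.0 after their respective step counts:
print(expected_epoch(2, 256, world_size=8, context_parallel_size=1))   # 1.0
print(expected_epoch(16, 256, world_size=8, context_parallel_size=8))  # 1.0
```

Under this kind of accounting both runs should report roughly `epoch: 1.0`, so the `5.42` above looks as if the counter does not account for context-parallel ranks sharing data, though that is only a guess.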
Possible solution
No response
Which Operating Systems are you using?
- [x] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.12.3
axolotl branch-commit
main
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this bug has not been reported yet.
- [x] I am using the latest version of axolotl.
- [x] I have provided enough information for the maintainers to reproduce and diagnose the issue.
Discord thread for reference: https://discord.com/channels/1104757954588196865/1426831119340273787/1427939353446977607