Training Loss Differs Based on Batch Size
Please check that this issue hasn't been reported before.
- [x] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
I expect the training loss to be the same when training on the same dataset, especially at the very first step, independent of batch size, given that the training dataset consists of 16 identical samples.
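To illustrate the expectation: with identical samples and mean-reduced cross-entropy, averaging the per-micro-batch losses reproduces the full-batch loss exactly, so gradient accumulation should not change the step-1 loss here. A minimal sketch with toy tensors (the shapes, vocabulary size, and seed are illustrative, not Gemma's):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 32
logits = torch.randn(1, 8, vocab)          # one toy sample: (batch, seq_len, vocab)
labels = torch.randint(0, vocab, (1, 8))

# One forward pass over a full batch of 16 identical samples.
full_loss = F.cross_entropy(
    logits.repeat(16, 1, 1).reshape(-1, vocab),
    labels.repeat(16, 1).reshape(-1),
)

# 16 micro-batches of size 1, losses averaged as gradient accumulation would.
micro_losses = [
    F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1)) for _ in range(16)
]
accum_loss = torch.stack(micro_losses).mean()

print(full_loss.item(), accum_loss.item())  # equal up to floating-point rounding
```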
Current behaviour
With 16 samples, micro_batch_size=16, gradient_accumulation_steps=1, my loss at step 1 is 2.4.
With 16 samples, micro_batch_size=1, gradient_accumulation_steps=16, my loss at step 1 is 4.75.
The effective batch size is 16 in both cases.
Steps to reproduce
- Create a dataset.jsonl file with 16 identical samples (a generation snippet follows this list):
{"messages": [{"role": "user", "content": "What mammal lays the largest eggs?", "context": ""}, {"role": "assistant", "content": "The echidna, an Australian monotreme, is the largest egg-laying mammal."}]}
- Run axolotl train with Config #1 below.
- Run axolotl train with Config #2. You should see that the training losses differ on the initial step.
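For reference, a minimal sketch for generating the dataset file from step 1 (the filename dataset.jsonl is assumed; any path referenced in the config works):

```python
import json

sample = {
    "messages": [
        {"role": "user", "content": "What mammal lays the largest eggs?", "context": ""},
        {"role": "assistant", "content": "The echidna, an Australian monotreme, is the largest egg-laying mammal."},
    ]
}

# Write 16 identical copies of the sample, one JSON object per line (JSONL).
with open("dataset.jsonl", "w") as f:
    for _ in range(16):
        f.write(json.dumps(sample) + "\n")
```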
Config yaml
Config #1:
```yaml
base_model: google/gemma-3-27b-pt
use_kernels: true # Use optimized kernels if available.
ddp_find_unused_parameters: true
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
# Dataset configuration
datasets:
- path: <path to local dataset>
field: messages
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
# Corresponds to seq_length and PackedSequenceSpecs in NeMo.
sequence_len: 4096
sample_packing: false
eval_sample_packing: false
gradient_accumulation_steps: 1
micro_batch_size: 16
# max_steps: 1000 #max_steps=1000,
max_steps: 1
# num_epochs: 1
# Optimizer settings from NeMo's OptimizerConfig.
optimizer: adamw_torch_fused
learning_rate: 5e-6 # lr=1e-6, changed to align with LlamaFactory
weight_decay: 0.1 # weight_decay=0.1,
adam_epsilon: 1e-5 # adam_eps=1e-5
adam_beta1: 0.9 # adam_beta1=0.9
adam_beta2: 0.95 # adam_beta2=0.95
max_grad_norm: 1.0 # clip_grad=1.0
# Scheduler settings from NeMo's CosineAnnealingScheduler.
lr_scheduler: cosine
# warmup_steps: 100 # warmup_steps=100
cosine_min_lr_ratio: 0.1 # min_lr/lr = 1e-7/1e-6 = 0.1
warmup_ratio: 0.03
# Precision settings. Corresponds to MegatronMixedPrecision.
bf16: true
fp16: false
tf32: true # Recommended on Ampere and newer GPUs (e.g., A100/H100) for better performance.
logging_steps: 1
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
# Parallelism strategy.
# FSDP (Fully Sharded Data Parallel) is the modern equivalent to NeMo's MegatronStrategy.
# This replaces tensor_model_parallel_size, pipeline_model_parallel_size, etc.
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
reshard_after_forward: true
# eval_steps: 100 # val_check_interval=100
eval_batch_size: 1
save_steps: 500 # Checkpoints will be saved at the same interval as validation.
save_total_limit: 2 # Keeps the latest and the best checkpoint.
metric_for_best_model: eval_loss
greater_is_better: false
include_tokens_per_second: true
save_safetensors: true
use_tensorboard: true
```
Config #2:
```yaml
base_model: google/gemma-3-27b-pt
use_kernels: true # Use optimized kernels if available.
ddp_find_unused_parameters: true
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
# Dataset configuration
datasets:
- path: <path to local dataset>
field: messages
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
# Corresponds to seq_length and PackedSequenceSpecs in NeMo.
sequence_len: 4096
sample_packing: false
eval_sample_packing: false
gradient_accumulation_steps: 16
micro_batch_size: 1
# max_steps: 1000 #max_steps=1000,
max_steps: 1
# num_epochs: 1
# Optimizer settings from NeMo's OptimizerConfig.
optimizer: adamw_torch_fused
learning_rate: 5e-6 # lr=1e-6, changed to align with LlamaFactory
weight_decay: 0.1 # weight_decay=0.1,
adam_epsilon: 1e-5 # adam_eps=1e-5
adam_beta1: 0.9 # adam_beta1=0.9
adam_beta2: 0.95 # adam_beta2=0.95
max_grad_norm: 1.0 # clip_grad=1.0
# Scheduler settings from NeMo's CosineAnnealingScheduler.
lr_scheduler: cosine
# warmup_steps: 100 # warmup_steps=100
cosine_min_lr_ratio: 0.1 # min_lr/lr = 1e-7/1e-6 = 0.1
warmup_ratio: 0.03
# Precision settings. Corresponds to MegatronMixedPrecision.
bf16: true
fp16: false
tf32: true # Recommended on Ampere and newer GPUs (e.g., A100/H100) for better performance.
logging_steps: 1
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
# Parallelism strategy.
# FSDP (Fully Sharded Data Parallel) is the modern equivalent to NeMo's MegatronStrategy.
# This replaces tensor_model_parallel_size, pipeline_model_parallel_size, etc.
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
reshard_after_forward: true
# eval_steps: 100 # val_check_interval=100
eval_batch_size: 1
save_steps: 500 # Checkpoints will be saved at the same interval as validation.
save_total_limit: 2 # Keeps the latest and the best checkpoint.
metric_for_best_model: eval_loss
greater_is_better: false
include_tokens_per_second: true
save_safetensors: true
use_tensorboard: true
```
Possible solution
No response
Which Operating Systems are you using?
- [x] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.11
axolotl branch-commit
main
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this bug has not been reported yet.
- [x] I am using the latest version of axolotl.
- [x] I have provided enough information for the maintainers to reproduce and diagnose the issue.
Thanks for reporting the issue. Can you try comparing the loss for Config 1 at step 1 with the loss for Config 2 at step 16, after full gradient accumulation?
If I run Config 2 with the change `max_steps: 1` -> `max_steps: 16`, I get:
```
{'loss': 4.75, 'grad_norm': 98.5, 'learning_rate': 5e-06, 'memory/max_mem_active(gib)': 25.45, 'memory/max_mem_allocated(gib)': 25.45, 'memory/device_mem_reserved(gib)': 30.22, 'epoch': 1.0}
{'loss': 2.0312, 'grad_norm': 59.5, 'learning_rate': 4.956766880907269e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 2.0}
{'loss': 0.8984, 'grad_norm': 45.0, 'learning_rate': 4.828728948150395e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 3.0}
{'loss': 0.4102, 'grad_norm': 19.625, 'learning_rate': 4.620806627680728e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 4.0}
{'loss': 0.2031, 'grad_norm': 11.8125, 'learning_rate': 4.340990257669732e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 5.0}
{'loss': 0.0889, 'grad_norm': 7.3125, 'learning_rate': 4.000033024294105e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 6.0}
{'loss': 0.0305, 'grad_norm': 3.734375, 'learning_rate': 3.611037722821452e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 7.0}
{'loss': 0.0125, 'grad_norm': 1.3046875, 'learning_rate': 3.188953224536289e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 8.0}
{'loss': 0.0057, 'grad_norm': 0.703125, 'learning_rate': 2.7500000000000004e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 9.0}
{'loss': 0.0027, 'grad_norm': 0.376953125, 'learning_rate': 2.3110467754637115e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 10.0}
{'loss': 0.0015, 'grad_norm': 0.23046875, 'learning_rate': 1.888962277178548e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 11.0}
{'loss': 0.001, 'grad_norm': 0.1533203125, 'learning_rate': 1.4999669757058956e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 12.0}
{'loss': 0.0008, 'grad_norm': 0.11474609375, 'learning_rate': 1.1590097423302683e-06, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 13.0}
{'loss': 0.0007, 'grad_norm': 0.1005859375, 'learning_rate': 8.791933723192731e-07, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 14.0}
{'loss': 0.0006, 'grad_norm': 0.09326171875, 'learning_rate': 6.712710518496049e-07, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 15.0}
{'loss': 0.0006, 'grad_norm': 0.08837890625, 'learning_rate': 5.432331190927316e-07, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 16.0}
{'train_runtime': 230.7145, 'train_samples_per_second': 8.877, 'train_steps_per_second': 0.069, 'train_tokens_per_second': 71.014, 'train_loss': 0.5274004936218262, 'memory/max_mem_active(gib)': 34.89, 'memory/max_mem_allocated(gib)': 34.89, 'memory/device_mem_reserved(gib)': 41.26, 'epoch': 16.0}
```