axolotl
FSDP Full-Finetuned Model Params and Weights Are NaN
Please check that this issue hasn't been reported before.
- [X] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
I fully finetuned TinyLlama (4 GB) using FSDP on a single machine with two GPUs. The run completed as expected and the saved model is 2 GB, as expected for float16. The problem is that the model params and weights are all NaN.
- The finetuned model:
weights = torch.load("model-out/pytorch_model.bin")
OrderedDict([('model.embed_tokens.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)), ('model.layers.0.self_attn.q_proj.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)), ('model.layers.0.self_attn.k_proj.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)),
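For reference, a quick sketch of how I verified this (plain PyTorch, using the checkpoint path from above):
import torch

# Load the saved full-state checkpoint on CPU.
weights = torch.load("model-out/pytorch_model.bin", map_location="cpu")

# Count how many floating-point tensors contain NaNs.
nan_keys = [name for name, t in weights.items()
            if torch.is_floating_point(t) and torch.isnan(t).any()]
print(f"{len(nan_keys)} / {len(weights)} tensors contain NaN values")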
- configuration.yaml:
base_model: TinyLlama/TinyLlama-1.1B-step-50K-105b
base_model_config: TinyLlama/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_legacy: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: data.jsonl
type: completion
field: text
dataset_prepared_path:
val_set_size: 0.08
output_dir: model-out
adapter:
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 12
eval_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 1
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
Log
- I noticed that the loss becomes zero after a few iterations!
[2024-01-16 16:35:25,268] [INFO] [axolotl.calculate_total_num_steps:529] [PID:426] [RANK:0] total_num_steps: 689
[2024-01-16 16:35:25,277] [INFO] [axolotl.train.train:48] [PID:426] [RANK:0] loading tokenizer... TinyLlama/TinyLlama-1.1B-step-50K-105b
[2024-01-16 16:35:25,577] [DEBUG] [axolotl.load_tokenizer:75] [PID:426] [RANK:0] EOS: 2 / </s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:76] [PID:426] [RANK:0] BOS: 1 / <s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:77] [PID:426] [RANK:0] PAD: 2 / </s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:78] [PID:426] [RANK:0] UNK: 0 / <unk>
Filter (num_proc=16): 0%| | 0/8260 [00:00<?, ? examples/s]
[2024-01-16 16:35:25,910] [INFO] [axolotl.train.train:56] [PID:426] [RANK:0] loading model and (optionally) peft_config...
Filter (num_proc=16): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8260/8260 [00:00<00:00, 8321.14 examples/s]
Downloading model.safetensors: 1%|█▏ | 31.5M/4.40G [00:00<00:28, 155MB/s]
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
table = cls._concat_blocks(blocks, axis=0)
Filter (num_proc=16): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 719/719 [00:00<00:00, 2785.06 examples/s]
[2024-01-16 16:35:27,248] [INFO] [axolotl.calculate_total_num_steps:529] [PID:427] [RANK:1] total_num_steps: 689
[2024-01-16 16:35:27,260] [INFO] [axolotl.train.train:48] [PID:427] [RANK:1] loading tokenizer... TinyLlama/TinyLlama-1.1B-step-50K-105b
Downloading model.safetensors: 5%|████████▋ | 220M/4.40G [00:01<00:20, 201MB/s]
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:75] [PID:427] [RANK:1] EOS: 2 / </s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:76] [PID:427] [RANK:1] BOS: 1 / <s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:77] [PID:427] [RANK:1] PAD: 2 / </s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:78] [PID:427] [RANK:1] UNK: 0 / <unk>
Downloading model.safetensors: 6%|██████████▎ | 262M/4.40G [00:01<00:20, 203MB/s]
[2024-01-16 16:35:27,828] [INFO] [axolotl.train.train:56] [PID:427] [RANK:1] loading model and (optionally) peft_config...
Downloading model.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.40G/4.40G [00:21<00:00, 203MB/s]
Downloading generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 502kB/s]
[2024-01-16 16:35:49,636] [INFO] [axolotl.load_model:396] [PID:427] [RANK:1] GPU memory usage after model load: 2.062GB (+0.087GB cache, +1.026GB misc)
[2024-01-16 16:35:49,642] [INFO] [axolotl.load_model:424] [PID:427] [RANK:1] converting modules to torch.float16 for flash attention
[2024-01-16 16:35:50,064] [INFO] [axolotl.train.train:108] [PID:427] [RANK:1] Starting trainer...
[2024-01-16 16:35:51,523] [INFO] [axolotl.load_model:396] [PID:426] [RANK:0] GPU memory usage after model load: 2.062GB (+0.087GB cache, +1.026GB misc)
[2024-01-16 16:35:51,529] [INFO] [axolotl.load_model:424] [PID:426] [RANK:0] converting modules to torch.float16 for flash attention
[2024-01-16 16:35:51,933] [INFO] [axolotl.train.train:108] [PID:426] [RANK:0] Starting trainer...
{'loss': 6.6129, 'learning_rate': 0.0, 'epoch': 0.0}
0%|▎ | 1/688 [00:47<9:00:26, 47.20s/it]
[2024-01-16 16:37:09,838] [INFO] [axolotl.callbacks.on_step_end:122] [PID:427] [RANK:1] GPU memory usage while training: 0.368GB (+9.704GB cache, +1.219GB misc)
[2024-01-16 16:37:09,839] [INFO] [axolotl.callbacks.on_step_end:122] [PID:426] [RANK:0] GPU memory usage while training: 0.368GB (+9.704GB cache, +1.219GB misc)
{'loss': 5.6427, 'learning_rate': 0.0, 'epoch': 0.01}
{'loss': 4.7613, 'learning_rate': 0.0, 'epoch': 0.01}
{'loss': 6.0507, 'learning_rate': 0.0, 'epoch': 0.01}
{'loss': 4.0042, 'learning_rate': 2e-05, 'epoch': 0.01}
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}
1%|██▍ | 8/688 [02:44<3:25:46, 18.16s/it]
Current behaviour
The saved model should not have all-NaN values for its params and weights.
Steps to reproduce
Same as in Expected Behavior above.
Config yaml
Same as the configuration.yaml listed above.
Possible solution
Which Operating Systems are you using?
- [X] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.10
axolotl branch-commit
a045db02146751548fec57a5d3f31382ce4e5959
Acknowledgements
- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.
If you're using fp16, you'll likely have to bring your learning rate way down. You're getting over/underflows of the fp16 values, leading to 0 loss.
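For illustration, a minimal sketch of how fp16 overflow turns into NaN (my own example, nothing axolotl-specific):
import torch

# fp16 can only represent magnitudes up to ~65504; anything larger overflows to inf.
x32 = torch.tensor([70000.0], dtype=torch.float32)
x16 = x32.to(torch.float16)
print(x16)        # tensor([inf], dtype=torch.float16)

# inf - inf is NaN, and once a NaN enters the weights it propagates everywhere.
print(x16 - x16)  # tensor([nan], dtype=torch.float16)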
@winglian But I fully finetuned the same TinyLlama model using fp16 with DeepSpeed ZeRO-2, and there was no problem with it, no NaN weights.
@winglian By the way, I am using the Docker version with the following package versions: CUDA 11.8, PyTorch 2.0.1+cu118, accelerate 0.24.0.dev0, transformers 4.35.0.dev0.
@winglian I changed the learning rate to learning_rate: 0.000002 and the loss still becomes zero.
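To narrow this down, a hypothetical debugging helper (not part of axolotl) that reports non-finite gradients right after backward():
import torch

def report_bad_grads(model: torch.nn.Module) -> None:
    # Call right after loss.backward() to see which parameters already
    # carry NaN or inf gradients before the optimizer step.
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"non-finite grad in {name}")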
@winglian Any idea, please?
Which GPU are you using? Are you able to use bf16? It should stabilize the loss better.
@NanoCode012 I am using 2 x Tesla T4 GPUs with fp16.
@hahmad2008, it's possible that there is numerical instability. Would you be able to use a newer, Ampere-generation GPU? I would recommend enabling bf16: true to prevent this issue.
Alternatively, can you try DeepSpeed? I believe there was also some issue with FSDP a while back.
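For reference, a quick way to check whether a given GPU supports bf16 (Ampere cards such as the A10/A100 do; a Tesla T4 does not):
import torch

# True on Ampere-or-newer GPUs (e.g. A10, A100); False on a Tesla T4.
print(torch.cuda.is_bf16_supported())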
@NanoCode012 Thanks, I will give it a try and will come back to you.
@NanoCode012 @winglian I tried bf16 on an A10 GPU and the training loss was stable, but with fp16 it was not: the loss jumped to zero and the weights of the generated model were NaN.
By the way, when training with fp16, the trainer should use a GradScaler. I am wondering whether axolotl with FSDP uses a GradScaler with fp16 mixed precision or not?
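For reference, the usual plain-PyTorch fp16 pattern I have in mind (a sketch, not axolotl code; model, optimizer, and dataloader are assumed to exist):
import torch

# Standard fp16 mixed-precision loop with gradient scaling.
# FSDP has a sharded variant:
#   from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(**batch).loss
    scaler.scale(loss).backward()  # scale loss so small grads don't underflow in fp16
    scaler.step(optimizer)         # unscales grads; skips the step if they are inf/NaN
    scaler.update()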