
FSDP Full-finetuned Model params and weights are NAN

Open hahmad2008 opened this issue 1 year ago • 9 comments

Please check that this issue hasn't been reported before.

  • [X] I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I ran a full fine-tune of TinyLlama (4G) with FSDP on a single machine with two GPUs. The run completes as expected, and the saved model is 2G, as expected for float16. The problem is that all of the model's parameters and weights are NaN.

  • The finetuned model: weights = torch.load("model-out/pytorch_model.bin")
OrderedDict([('model.embed_tokens.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    ...,
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)), ('model.layers.0.self_attn.q_proj.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    ...,
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)), ('model.layers.0.self_attn.k_proj.weight', tensor([[nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    ...,
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan],
    [nan, nan, nan, ..., nan, nan, nan]], dtype=torch.float16)), 
  • configuration.yaml:
base_model: TinyLlama/TinyLlama-1.1B-step-50K-105b
base_model_config: TinyLlama/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_legacy: false
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: data.jsonl
    type: completion
    field: text

dataset_prepared_path: 
val_set_size: 0.08
output_dir: model-out

adapter: 
lora_model_dir:

sequence_len: 512
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 12
eval_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
evals_per_epoch: 1
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: 
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Log

  • I noticed that the loss becomes zero after a few iterations!
[2024-01-16 16:35:25,268] [INFO] [axolotl.calculate_total_num_steps:529] [PID:426] [RANK:0] total_num_steps: 689
[2024-01-16 16:35:25,277] [INFO] [axolotl.train.train:48] [PID:426] [RANK:0] loading tokenizer... TinyLlama/TinyLlama-1.1B-step-50K-105b
[2024-01-16 16:35:25,577] [DEBUG] [axolotl.load_tokenizer:75] [PID:426] [RANK:0] EOS: 2 / </s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:76] [PID:426] [RANK:0] BOS: 1 / <s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:77] [PID:426] [RANK:0] PAD: 2 / </s>
[2024-01-16 16:35:25,578] [DEBUG] [axolotl.load_tokenizer:78] [PID:426] [RANK:0] UNK: 0 / <unk>
Filter (num_proc=16):   0%|                                                                                                                                                                                          | 0/8260 [00:00<?, ? examples/s]
[2024-01-16 16:35:25,910] [INFO] [axolotl.train.train:56] [PID:426] [RANK:0] loading model and (optionally) peft_config...
Filter (num_proc=16): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8260/8260 [00:00<00:00, 8321.14 examples/s]
Downloading model.safetensors:   1%|█▏                                                                                                                                                                           | 31.5M/4.40G [00:00<00:28, 155MB/s]
/usr/local/lib/python3.10/dist-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
Filter (num_proc=16): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 719/719 [00:00<00:00, 2785.06 examples/s]
[2024-01-16 16:35:27,248] [INFO] [axolotl.calculate_total_num_steps:529] [PID:427] [RANK:1] total_num_steps: 689
[2024-01-16 16:35:27,260] [INFO] [axolotl.train.train:48] [PID:427] [RANK:1] loading tokenizer... TinyLlama/TinyLlama-1.1B-step-50K-105b
Downloading model.safetensors:   5%|████████▋                                                                                                                                                                     | 220M/4.40G [00:01<00:20, 201MB/s]
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:75] [PID:427] [RANK:1] EOS: 2 / </s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:76] [PID:427] [RANK:1] BOS: 1 / <s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:77] [PID:427] [RANK:1] PAD: 2 / </s>
[2024-01-16 16:35:27,556] [DEBUG] [axolotl.load_tokenizer:78] [PID:427] [RANK:1] UNK: 0 / <unk>
Downloading model.safetensors:   6%|██████████▎                                                                                                                                                                   | 262M/4.40G [00:01<00:20, 203MB/s]
[2024-01-16 16:35:27,828] [INFO] [axolotl.train.train:56] [PID:427] [RANK:1] loading model and (optionally) peft_config...
Downloading model.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.40G/4.40G [00:21<00:00, 203MB/s]
Downloading generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 502kB/s]
[2024-01-16 16:35:49,636] [INFO] [axolotl.load_model:396] [PID:427] [RANK:1] GPU memory usage after model load: 2.062GB (+0.087GB cache, +1.026GB misc)
[2024-01-16 16:35:49,642] [INFO] [axolotl.load_model:424] [PID:427] [RANK:1] converting modules to torch.float16 for flash attention
[2024-01-16 16:35:50,064] [INFO] [axolotl.train.train:108] [PID:427] [RANK:1] Starting trainer...
[2024-01-16 16:35:51,523] [INFO] [axolotl.load_model:396] [PID:426] [RANK:0] GPU memory usage after model load: 2.062GB (+0.087GB cache, +1.026GB misc)
[2024-01-16 16:35:51,529] [INFO] [axolotl.load_model:424] [PID:426] [RANK:0] converting modules to torch.float16 for flash attention
[2024-01-16 16:35:51,933] [INFO] [axolotl.train.train:108] [PID:426] [RANK:0] Starting trainer...
{'loss': 6.6129, 'learning_rate': 0.0, 'epoch': 0.0}                                                                                                                                                                                                 
  0%|▎                                                                                                                                                                                                             | 1/688 [00:47<9:00:26, 47.20s/it]
[2024-01-16 16:37:09,838] [INFO] [axolotl.callbacks.on_step_end:122] [PID:427] [RANK:1] GPU memory usage while training: 0.368GB (+9.704GB cache, +1.219GB misc)
[2024-01-16 16:37:09,839] [INFO] [axolotl.callbacks.on_step_end:122] [PID:426] [RANK:0] GPU memory usage while training: 0.368GB (+9.704GB cache, +1.219GB misc)
{'loss': 5.6427, 'learning_rate': 0.0, 'epoch': 0.01}                                                                                                                                                                                                
{'loss': 4.7613, 'learning_rate': 0.0, 'epoch': 0.01}                                                                                                                                                                                                
{'loss': 6.0507, 'learning_rate': 0.0, 'epoch': 0.01}                                                                                                                                                                                                
{'loss': 4.0042, 'learning_rate': 2e-05, 'epoch': 0.01}                                                                                                                                                                                              
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}                                                                                                                                                                                                 
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}                                                                                                                                                                                                 
{'loss': 0.0, 'learning_rate': 2e-05, 'epoch': 0.02}                                                                                                                                                                                                 
  1%|██▍                                          | 8/688 [02:44<3:25:46, 18.16s/it]

Current behaviour

The saved model should not have NaN values for all of its parameters and weights.
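To quantify how much of a checkpoint is affected, a small helper can report the NaN count per entry. This is a minimal sketch, not part of axolotl: `nan_report` and `toy_state` are hypothetical names, and the values are plain Python floats so the snippet runs without PyTorch. With a real checkpoint loaded via `torch.load`, `torch.isnan(t).sum().item()` gives the same per-tensor count.

```python
import math
from typing import Dict, Iterable, Tuple


def nan_report(state: Dict[str, Iterable[float]]) -> Dict[str, Tuple[int, int]]:
    """Return {name: (nan_count, total_count)} for every entry that contains NaN."""
    report = {}
    for name, values in state.items():
        values = list(values)
        nan_count = sum(1 for v in values if math.isnan(v))
        if nan_count:
            report[name] = (nan_count, len(values))
    return report


# Toy stand-in for a flattened state dict (illustrative values only).
toy_state = {
    "model.embed_tokens.weight": [float("nan")] * 4,
    "healthy.weight": [1.0, 0.5, 0.25],  # hypothetical key with finite values
}

print(nan_report(toy_state))
```

An empty report means no entry contains NaN. In the dump under Expected Behavior, every tensor shown would appear in the report.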

Steps to reproduce

Same as described under Expected Behavior.

Config yaml

Identical to the configuration.yaml listed under Expected Behavior.

Possible solution

Which Operating Systems are you using?

  • [X] Linux
  • [ ] macOS
  • [ ] Windows

Python Version

3.10

axolotl branch-commit

a045db02146751548fec57a5d3f31382ce4e5959

Acknowledgements

  • [X] My issue title is concise, descriptive, and in title casing.
  • [X] I have searched the existing issues to make sure this bug has not been reported yet.
  • [X] I am using the latest version of axolotl.
  • [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.

hahmad2008 avatar Jan 24 '24 19:01 hahmad2008

If you're using fp16, you'll likely have to turn your learning rate way down. You're getting over/underflows of the fp16 values, leading to 0 loss.

winglian avatar Jan 24 '24 23:01 winglian
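The fp16 range limits mentioned above can be seen with the standard library alone. This sketch is an illustration of half-precision overflow and underflow in general, not a trace of what happened in this training run; `to_fp16` is a hypothetical helper that rounds a Python float to the nearest IEEE 754 half-precision value using `struct`'s `e` format.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value, returned as a float."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond the fp16 range; fp16 arithmetic
        # on hardware would produce +/-inf instead.
        return math.copysign(math.inf, x)


largest = to_fp16(65504.0)     # largest finite fp16 value, kept as-is
overflowed = to_fp16(70000.0)  # beyond the fp16 range: becomes inf
underflowed = to_fp16(1e-8)    # below the smallest fp16 subnormal: becomes 0.0

print(largest, overflowed, underflowed)
print(overflowed - overflowed)  # inf - inf yields nan
```

Once a single inf or NaN enters a weight update, it can spread to every value that depends on it, which is one way an entire checkpoint can end up as NaN.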

@winglian But I fully fine-tuned the same TinyLlama model using fp16 with DeepSpeed ZeRO-2, and there was no problem: no NaN weights.

hahmad2008 avatar Jan 25 '24 08:01 hahmad2008

@winglian By the way, I am using the Docker image with the following package versions: cuda 11.8, pytorch 2.0.1+cu118, accelerate 0.24.0.dev0, transformers 4.35.0.dev0.

hahmad2008 avatar Jan 25 '24 09:01 hahmad2008

@winglian I changed the learning rate to learning_rate: 0.000002 and the loss still becomes zero.

hahmad2008 avatar Jan 25 '24 22:01 hahmad2008

@winglian Any idea, please?

hahmad2008 avatar Jan 31 '24 16:01 hahmad2008

Which GPU are you using? Are you able to use bf16? It should stabilize loss better.

NanoCode012 avatar Feb 17 '24 04:02 NanoCode012

@NanoCode012 I am using 2 x Tesla T4 with fp16.

hahmad2008 avatar Feb 28 '24 17:02 hahmad2008

@hahmad2008, it's possible that there is instability. Would you be able to use a newer, Ampere-generation GPU? I would recommend enabling bf16: true to prevent this issue.

Alternatively, can you try deepspeed? I believe there was also some issue with fsdp a while back.

NanoCode012 avatar Feb 28 '24 19:02 NanoCode012
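The Ampere requirement can be expressed as a compute-capability check. This is a minimal sketch under stated assumptions: `has_native_bf16` and the lookup table are hypothetical, and the capabilities listed (7.5 for the Tesla T4, 8.6 for the A10) are NVIDIA's published values for the two GPUs named in this thread. With PyTorch installed, `torch.cuda.get_device_capability()` returns the same kind of `(major, minor)` tuple, and `torch.cuda.is_bf16_supported()` answers the question directly.

```python
from typing import Tuple

# Compute capabilities of the GPUs mentioned in this thread.
COMPUTE_CAPABILITY = {
    "Tesla T4": (7, 5),  # Turing
    "A10": (8, 6),       # Ampere
}


def has_native_bf16(capability: Tuple[int, int]) -> bool:
    """Ampere (compute capability 8.0) and newer have native bf16 support."""
    return capability >= (8, 0)


for gpu, capability in COMPUTE_CAPABILITY.items():
    print(gpu, "bf16" if has_native_bf16(capability) else "fp16 only")
```

On a T4 the `bf16: true` setting is therefore not an option, which is why the suggestion involves switching hardware.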

@NanoCode012 Thanks, I will give it a try and will come back to you.

hahmad2008 avatar Feb 29 '24 21:02 hahmad2008

@NanoCode012 @winglian I tried bf16 on an A10 GPU and the training loss was stable, but with fp16 it was not: the loss jumped to zero and the weights of the saved model were NaN.

By the way, when training with fp16 the trainer should use GradScaler. Does axolotl with FSDP use GradScaler for fp16 mixed precision or not?

hahmad2008 avatar Mar 11 '24 19:03 hahmad2008
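For readers unfamiliar with the GradScaler question above, dynamic loss scaling can be sketched in plain Python. This is a toy model of the idea, not PyTorch's implementation and not a statement about what axolotl does with FSDP: `ToyLossScaler` is a hypothetical class whose default values mirror the documented defaults of `torch.cuda.amp.GradScaler`. PyTorch also ships a `ShardedGradScaler` for FSDP; whether it was in use at this commit is not established in this thread.

```python
import math
from typing import Callable, List


class ToyLossScaler:
    """Toy dynamic loss scaler: skip bad steps, shrink the scale, regrow it later."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss: float) -> float:
        # Scaling the loss scales every gradient, lifting small ones above fp16 underflow.
        return loss * self.scale

    def step(self, scaled_grads: List[float],
             apply_update: Callable[[List[float]], None]) -> bool:
        """Unscale and apply gradients, or skip the step if any gradient is inf/NaN."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # overflow seen: back off, skip the update
            self._good_steps = 0
            return False
        apply_update([g / self.scale for g in scaled_grads])
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= self.growth_factor  # stable for a while: try a larger scale
            self._good_steps = 0
        return True


scaler = ToyLossScaler(growth_interval=2)
applied: List[List[float]] = []
print(scaler.step([math.inf, 1.0], applied.append), scaler.scale)  # skipped step
print(scaler.step([32768.0], applied.append), scaler.scale)
print(scaler.step([16384.0], applied.append), scaler.scale)
print(applied)
```

The key property is that a step with non-finite gradients never reaches the weights. Without that guard, a single overflow in fp16 can write inf or NaN into the parameters.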