Why is my Medusa head training result all NaN: {'medusa0_top1': nan, 'medusa0_loss': nan, 'medusa1_top1': nan, 'medusa1_loss': nan, 'medusa2_top1': nan, 'medusa2_loss': nan, 'epoch': 0}?
I'm training the Medusa heads following the README, and I first ran into the following issue:
```
/data/lx/demo/Medusa/medusa/train/train_legacy.py:392: FutureWarning: tokenizer is deprecated and will be removed in version 5.0.0 for CustomizedTrainer.__init__. Use processing_class instead.
  trainer = CustomizedTrainer(
Loading data...
Formatting inputs...Skip in lazy mode
/data/lx/demo/Medusa/medusa/train/train_legacy.py:392: FutureWarning: tokenizer is deprecated and will be removed in version 5.0.0 for CustomizedTrainer.__init__. Use processing_class instead.
  trainer = CustomizedTrainer(
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Parameter Offload: Total persistent parameters: 278528 in 68 params
[rank1]: Traceback (most recent call last):
[rank1]:   File "/data/lx/demo/Medusa/medusa/train/train_legacy.py", line 424, in <module>
```
I added `num_items_in_batch=None` to the `compute_loss` method of the `CustomizedTrainer` class in Medusa/medusa/train/train_legacy.py (roughly as sketched below), and the script then ran, but the training metrics come out as all NaN. Where did I go wrong that makes the training loss NaN?
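For reference, the change is roughly the following — a minimal sketch assuming the upstream `Trainer` signature; the Medusa per-head loss body is unchanged and elided here, with `_medusa_loss` as a hypothetical stand-in for it:

```python
from transformers import Trainer

class CustomizedTrainer(Trainer):
    # Newer transformers versions pass num_items_in_batch into compute_loss,
    # so the override must accept it (here it is accepted and ignored).
    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        # _medusa_loss is a hypothetical placeholder for the repo's
        # unchanged per-head cross-entropy computation.
        loss, outputs = self._medusa_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss
```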
My run command is (I have two GPU cards, so I changed nproc_per_node to 2):
```bash
torchrun --nproc_per_node=2 medusa/train/train_legacy.py \
    --model_name_or_path /data/models/Mistral-7B-Instruct-v0.2 \
    --data_path mistral.json \
    --bf16 True \
    --output_dir test \
    --num_train_epochs 2 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 1e-3 \
    --weight_decay 0.0 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --lazy_preprocess True \
    --medusa_num_heads 3 \
    --medusa_num_layers 1 \
    --deepspeed deepspeed.json
```
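The command points at deepspeed.json, which isn't shown here. One thing worth checking with NaN losses is that the precision section of that config matches `--bf16 True` (an enabled `fp16` section would conflict with it). A config consistent with the flags above and with the ZeRO-3 "Parameter Offload" log would look roughly like this — a sketch, not the actual file:

```json
{
  "bf16": { "enabled": true },
  "zero_optimization": { "stage": 3 },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto"
}
```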
And the result is the all-NaN metrics shown in the title.
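For anyone debugging the same thing: since every metric is NaN from the very first log step, one way to localize the problem is to fail fast on the first non-finite head loss inside `compute_loss` instead of letting it be logged. A minimal sketch (the helper and its placement are assumptions, not the repo's code):

```python
import torch

def assert_finite(name: str, value: torch.Tensor) -> None:
    # Raise on the first non-finite loss so the offending batch and
    # head index can be inspected, instead of silently logging NaN.
    if not torch.isfinite(value).all():
        raise RuntimeError(f"{name} became non-finite: {value}")

# Hypothetical placement inside the per-head loss loop of compute_loss:
# assert_finite(f"medusa{i}_loss", head_loss)
```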
Hi, have you solved this problem?