unsloth
unsloth copied to clipboard
No Validation Loss logged (possibly related to train_on_responses_only?)
Evaluations are being run, but no validation loss is logged or sent to WandB
The console shows that eval is running, but displays a table along the lines of:
| eval loss | validation loss |
|---|---|
| 1 | no log |
| .9 | no log |
| .8 | no log |
WandB shows evidence that the validation run occurs, but it doesn't display the loss either. The training setup is:
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
import os
# Weights & Biases settings, set before the trainer is built.
os.environ["WANDB_PROJECT"] = "my_project"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints

# Collator handed to the trainer for both the train and eval datasets.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

training_args = TrainingArguments(
    # --- bookkeeping / reporting ---
    output_dir="outputs",
    report_to="wandb",
    run_name="run-name-here",
    seed=3407,
    logging_steps=1,
    logging_first_step=True,
    # --- optimisation ---
    num_train_epochs=3,  # Set this for 1 full training run.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimiser step per device
    gradient_checkpointing=True,
    learning_rate=1e-4,
    warmup_steps=5,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="adamw_8bit",
    # Train in bf16 when the hardware supports it, otherwise fall back to fp16.
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    # --- evaluation / checkpointing (same 88-step cadence for both) ---
    do_eval=True,
    do_predict=True,
    evaluation_strategy="steps",
    eval_steps=88,
    save_steps=88,
    per_device_eval_batch_size=2,
    eval_accumulation_steps=4,
    fp16_full_eval=True,
)

# NOTE(review): the eval loss is only reported if the eval batches carry
# labels — confirm that `validation_dataset` ends up with a label column
# after any later post-processing of the trainer.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,  # Can make training 5x faster for short sequences.
    data_collator=collator,
    args=training_args,
)
---
from unsloth.chat_templates import train_on_responses_only
# Chat-template markers passed as the instruction (user) and response
# (assistant) delimiters.
user_header = "<|start_header_id|>user<|end_header_id|>\n\n"
assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"

# NOTE(review): going by its name, this wrapper should restrict the training
# loss to the response spans — confirm whether the installed unsloth version
# also applies the same processing to the trainer's eval_dataset.
trainer = train_on_responses_only(
    trainer,
    instruction_part=user_header,
    response_part=assistant_header,
)

# Start fine-tuning with the wrapped trainer.
trainer.train()
Very similar settings work when using a plain SFTTrainer in another project.