transformers
Early stopping abnormality
System Info
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- transformers version: 4.32.1
- Platform: Linux-4.18.0-372.9.1.el8_lustre.x86_64-x86_64-with-glibc2.17
- Python version: 3.11.5
- Huggingface_hub version: 0.15.1
- Safetensors version: 0.3.2
- Accelerate version: 0.23.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
@ArthurZucker and @younesbelkada
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
1. Function definition
import numpy as np
import torch
from sklearn.metrics import (accuracy_score, average_precision_score, f1_score,
                             matthews_corrcoef, precision_score, recall_score, roc_auc_score)

def compute_metrics_binary(eval_preds):
    logits, labels = eval_preds
    # Convert logits to class probabilities, then take the most likely class
    prediction_scores = torch.nn.functional.softmax(torch.from_numpy(logits).double(), dim=-1).numpy()
    predictions = np.argmax(prediction_scores, axis=-1)

    # Compute the evaluation metrics
    accuracy = accuracy_score(labels, predictions)
    eval_loss = 1 - accuracy
    f1 = f1_score(labels, predictions)
    recall = recall_score(labels, predictions)
    precision = precision_score(labels, predictions)
    # For binary labels scikit-learn ignores `average`, so these two AUC values are identical
    roc_auc_macro = roc_auc_score(labels, prediction_scores[:, 1], average='macro')
    roc_auc_weighted = roc_auc_score(labels, prediction_scores[:, 1], average='weighted')
    pr_auc = average_precision_score(labels, prediction_scores[:, 1])
    mcc = matthews_corrcoef(labels, predictions)  # add the MCC metric
    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Hello eval_loss", eval_loss)

    # Return the evaluation metrics
    return {
        'loss': eval_loss,
        'accuracy': accuracy,
        'f1': f1,
        'recall': recall,
        'precision': precision,
        'roc_auc_macro': roc_auc_macro,
        'roc_auc_weighted': roc_auc_weighted,
        'pr_auc': pr_auc,
        'mcc': mcc,  # return the MCC metric
    }

compute_metrics = compute_metrics_binary if len(labels) == 2 else compute_metrics_multi
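The issue does not show `compute_metrics_multi`. For `metric_for_best_model='eval_f1'` to work on the multi-class branch as well, that function would also have to return an `f1` key. Below is a hypothetical, stdlib-only version (not the author's code) that computes accuracy and macro-averaged F1, i.e. the unweighted mean of the per-class F1 scores, which is what scikit-learn's `f1_score(..., average='macro')` reports.

```python
def compute_metrics_multi(eval_preds):
    """Hypothetical multi-class metrics: accuracy and macro-averaged F1 (stdlib only)."""
    logits, labels = eval_preds
    # argmax over the class dimension; softmax is unnecessary because it preserves the ordering
    predictions = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
    labels = [int(y) for y in labels]

    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

    # Per-class F1 over every class that appears in the labels or the predictions
    f1_per_class = []
    for c in sorted(set(labels) | set(predictions)):
        tp = sum(p == c and y == c for p, y in zip(predictions, labels))
        fp = sum(p == c and y != c for p, y in zip(predictions, labels))
        fn = sum(p != c and y == c for p, y in zip(predictions, labels))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_per_class.append(f1)

    # Return the key 'f1' so that metric_for_best_model='eval_f1' can find it after prefixing
    return {"accuracy": accuracy, "f1": sum(f1_per_class) / len(f1_per_class)}
```

Macro averaging gives every class the same say regardless of its size, which is usually the behaviour you want when the classes are imbalanced.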
2. Model training parameters
from transformers import EarlyStoppingCallback, TrainingArguments
from transformers.integrations import CometCallback

args = TrainingArguments(
    output_dir='outputsN', learning_rate=LEARNING_RATE, warmup_ratio=warmup_ratio,
    lr_scheduler_type='cosine',  # linear
    fp16=True, evaluation_strategy="epoch",
    per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=ACCUMULATION, num_train_epochs=EPOCHS, weight_decay=0.01,
    save_strategy='epoch', report_to='none', load_best_model_at_end=True, seed=seed_val,
    metric_for_best_model='eval_f1',
    eval_steps=5,  # only used with evaluation_strategy="steps"; ignored with "epoch"
)
# early stopping: 5 epochs
callbacks = [EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.05), CometCallback()]
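With `early_stopping_patience=5` and `early_stopping_threshold=0.05`, the patience counter only resets when `eval_f1` beats the best value so far by more than 0.05. The stdlib sketch below models that interaction under two assumptions: a checkpoint is saved after every evaluation (as with `save_strategy='epoch'` here), and the logic matches my reading of `EarlyStoppingCallback` and `Trainer` around v4.32. `simulate_early_stopping` is a hypothetical helper, not library code.

```python
def simulate_early_stopping(scores, patience=5, threshold=0.05, greater_is_better=True):
    """Return the 1-based evaluation at which training would stop, or None if it never stops."""
    best = None
    counter = 0
    for step, score in enumerate(scores, start=1):
        improved = best is not None and (score > best if greater_is_better else score < best)
        # Callback check: only an improvement larger than `threshold` resets patience
        if best is None or (improved and abs(score - best) > threshold):
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            return step
        # The best checkpoint is then updated on *any* improvement (no threshold applied)
        if best is None or improved:
            best = score
    return None


scores = [0.60, 0.63, 0.64, 0.66, 0.67, 0.68, 0.69]  # made-up eval_f1 values for illustration
print(simulate_early_stopping(scores))                 # 6
print(simulate_early_stopping(scores, threshold=0.0))  # None
```

Here F1 improves after every epoch but never by more than 0.05, so training would stop after the sixth evaluation; with `threshold=0.0` any strict improvement resets the counter.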
3. Training output
early stopping required metric_for_best_model, but did not find eval_f1 so early stopping is disabled
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Hello eval_loss 0.3282442748091603
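The warning is emitted when the metric named by `metric_for_best_model` is missing from the metrics dict produced by evaluation. The stdlib sketch below is a simplified re-creation of that lookup, based on my reading of the `Trainer` evaluation loop around v4.32 (the helper names are hypothetical): keys returned by `compute_metrics` receive an `eval_` prefix, and `metric_for_best_model` gets the same prefix if it lacks one. It suggests checking which `compute_metrics` function actually ran and which keys it returned. It also shows that a `'loss'` key returned from `compute_metrics` ends up replacing the model's own `eval_loss`.

```python
def evaluate_like_trainer(user_metrics, model_loss=None, prefix="eval"):
    """Simplified model of how evaluation combines and renames metric keys."""
    metrics = dict(user_metrics)
    if model_loss is not None:
        metrics[f"{prefix}_loss"] = model_loss  # loss computed from the model outputs
    # Prefix every key that is not already prefixed; a user 'loss' overwrites eval_loss here
    for key in list(metrics.keys()):
        if not key.startswith(f"{prefix}_"):
            metrics[f"{prefix}_{key}"] = metrics.pop(key)
    return metrics


def find_early_stopping_metric(metrics, metric_for_best_model):
    """Look up the monitored metric; None corresponds to the 'early stopping is disabled' warning."""
    key = metric_for_best_model
    if not key.startswith("eval_"):
        key = f"eval_{key}"
    return metrics.get(key)


m = evaluate_like_trainer({"loss": 0.33, "f1": 0.7}, model_loss=0.5)
print(m)                                                  # {'eval_loss': 0.33, 'eval_f1': 0.7}
print(find_early_stopping_metric(m, "eval_f1"))           # 0.7
print(find_early_stopping_metric(evaluate_like_trainer({"f1_macro": 0.7}), "eval_f1"))  # None
```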
Expected behavior
Question 1: Do you have any suggestions about early stopping? Is there a complete example of fine-tuning a pre-trained model that I could use as a reference?
Question 2: How should I modify the Trainer to make it suitable for multi-class classification problems with class imbalance?
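One common way to handle class imbalance (a swapped-in technique, not something proposed in this thread) is to weight the cross-entropy loss by inverse class frequency. With the Trainer this usually means subclassing it, overriding `compute_loss`, and passing the weights to `torch.nn.CrossEntropyLoss(weight=...)`. The stdlib sketch below only shows the arithmetic such a weighted loss performs; `balanced_class_weights` and `weighted_cross_entropy` are hypothetical helpers.

```python
import math
from collections import Counter


def balanced_class_weights(labels, num_classes):
    """n_samples / (num_classes * count_c), the formula behind scikit-learn's class_weight='balanced'.
    Classes absent from `labels` get weight 0.0 here (an arbitrary choice for this sketch)."""
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]


def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-sample cross-entropy, normalised by the sum of the
    weights of the target classes (as CrossEntropyLoss(weight=...) does with reduction='mean')."""
    total, weight_sum = 0.0, 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[y]  # -log softmax(row)[y]
        total += weights[y] * nll
        weight_sum += weights[y]
    return total / weight_sum


weights = balanced_class_weights([0, 0, 0, 1], num_classes=2)
print(weights)  # the rare class 1 gets weight 2.0, the frequent class 0 about 0.667
print(weighted_cross_entropy([[2.0, 0.0], [2.0, 0.0]], [0, 1], weights))
```

Errors on the rare class now cost more, so the model is pushed away from always predicting the majority class.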
Maybe cc @pacman100 regarding the Trainer
Dear @pacman100,
Thanks in advance; I'm looking forward to your reply.
Best,
Du
Gentle ping @pacman100
Another ping @pacman100
Hello @GeorgeBGM, please let me know explicitly if there is an issue, as this looks more like a set of clarification questions:
> Question 1: Do you have any suggestions about early stopping? Is there a complete example of fine-tuning a pre-trained model that I could use as a reference?
> Question 2: How should I modify the Trainer to make it suitable for multi-class classification problems with class imbalance?
Let me quickly look around and share a minimal example using the early stopping functionality.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.