GKD training slowdown from trl v0.21.0 to v0.25.1
Reproduction
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import trl
from trl import GKDConfig, GKDTrainer

print(trl.__version__)

# Placeholders: the teacher and student checkpoints are not named in this report.
t_model_name = "<teacher-model-id>"
s_model_name = "<student-model-id>"

devices = "cuda:0"
# Tokenizers are not placed on a device, so no device_map is passed here.
tokenizer = AutoTokenizer.from_pretrained(t_model_name)
# Teacher on cuda:1, student on cuda:0.
t_model = AutoModelForCausalLM.from_pretrained(t_model_name, dtype=torch.bfloat16, device_map="cuda:1")
s_model = AutoModelForCausalLM.from_pretrained(s_model_name, dtype=torch.bfloat16, device_map=devices)
train_dataset = load_dataset("parquet", data_files="abc.parquet")["train"]

training_args = GKDConfig(
    output_dir="KD_Test_1",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    beta=0,  # generalized JSD interpolation: 0.0 gives the KL divergence loss
    lmbda=0,  # fraction of on-policy student-generated outputs (none here)
    temperature=0.9,
    learning_rate=1e-4,
    lr_scheduler_type="constant",
    logging_steps=100,
    save_steps=1000,
    bf16=True,
    run_name="KD_Test_1",
    report_to="none",
)

trainer = GKDTrainer(
    model=s_model,
    teacher_model=t_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
Output with trl 0.21.0:
[ 206/3750 01:10 < 20:17, 2.91 it/s, Epoch 0.16/3]
Output with trl 0.25.1:
[ 151/3750 02:58 < 1:11:49, 0.84 it/s, Epoch 0.12/3]
Training speed dropped by more than 3x (2.91 it/s vs. 0.84 it/s), with the code unchanged.
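The slowdown factor follows directly from the two progress bars. The sketch below extracts the rate field from each line and takes the ratio. `iterations_per_second` and `slowdown` are hypothetical helper names, and the regex assumes the rate is printed as `it/s`, as in the bars above; a bar that reports `s/it` instead would need the value inverted.

```python
import re

# Matches the rate field of a progress-bar line, e.g. "2.91 it/s".
RATE_RE = re.compile(r"([\d.]+)\s*it/s")


def iterations_per_second(progress_line: str) -> float:
    """Extract the it/s value from a single progress-bar line."""
    match = RATE_RE.search(progress_line)
    if match is None:
        raise ValueError(f"no it/s rate found in: {progress_line!r}")
    return float(match.group(1))


def slowdown(old_line: str, new_line: str) -> float:
    """Return how many times slower the new run is than the old one."""
    return iterations_per_second(old_line) / iterations_per_second(new_line)


old = "[ 206/3750 01:10 < 20:17, 2.91 it/s, Epoch 0.16/3]"  # trl 0.21.0
new = "[ 151/3750 02:58 < 1:11:49, 0.84 it/s, Epoch 0.12/3]"  # trl 0.25.1
print(f"{slowdown(old, new):.2f}x slower")
```

Run on the two bars from this report, the ratio comes out at roughly 3.46, consistent with the "more than 3x" figure.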
Edit: On further investigation, the slowdown appears to be caused by initializing the two models on different devices (teacher on cuda:1, student on cuda:0). The same setup does not slow down with the older trl version. After GKDTrainer is initialized, the teacher and student both appear to end up on the same CUDA device in both versions. Digging further, the backward pass seems to be taking more time.
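Both observations in the edit (where the models' parameters end up, and how long the backward pass takes) can be checked with two small helpers. This is a minimal sketch using only the standard library; `parameter_devices` and `timed` are hypothetical names, not part of trl or transformers. The timing helper takes an optional `sync` callable because CUDA kernels launch asynchronously, so a wall-clock timer read without synchronizing can stop before the GPU work finishes.

```python
import time
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List, Optional, Set


def parameter_devices(model) -> Set[str]:
    """Return the set of devices holding a model's parameters.

    Accepts any object with a .parameters() method whose items expose
    .device (for example a torch.nn.Module).
    """
    return {str(p.device) for p in model.parameters()}


@contextmanager
def timed(
    label: str,
    results: Dict[str, List[float]],
    sync: Optional[Callable[[], None]] = None,
) -> Iterator[None]:
    """Append the wall-clock duration of the enclosed block to results[label].

    When timing CUDA work, pass a `sync` callable that synchronizes every GPU
    in use; otherwise the timer can stop before queued kernels have finished.
    """
    if sync is not None:
        sync()  # drain work queued before the block
    start = time.perf_counter()
    try:
        yield
    finally:
        if sync is not None:
            sync()  # wait for work launched inside the block
        results.setdefault(label, []).append(time.perf_counter() - start)


# CPU-only usage: plain Python work needs no synchronization.
stats: Dict[str, List[float]] = {}
with timed("sum", stats):
    total = sum(range(100_000))
print({label: len(times) for label, times in stats.items()})
```

On the setup in this report, one could call `parameter_devices(trainer.model)` and `parameter_devices(trainer.teacher_model)` right after constructing `GKDTrainer` under each trl version (the `teacher_model` attribute name is an assumption about trl internals), and wrap a student backward call in `timed("backward", stats, sync=lambda: (torch.cuda.synchronize("cuda:0"), torch.cuda.synchronize("cuda:1")))`. For a per-operator breakdown of where the backward time goes, `torch.profiler` is the more thorough tool.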
System Info
OS Environment:
- Platform: Linux-5.15.0-1046-nvidia-x86_64-with-glibc2.35
- Python version: 3.10.12
- TRL version: 0.25.1
- PyTorch version: 2.7.1
- accelerator(s): NVIDIA H100 80GB HBM3, NVIDIA H100 80GB HBM3
- Transformers version: 4.57.1
- Accelerate version: 1.7.0
- Accelerate config: not found
- Datasets version: 3.6.0
- HF Hub version: 0.34.4
- bitsandbytes version: 0.45.3
- DeepSpeed version: 0.18.2
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 1.100.1
- PEFT version: 0.15.2
- vLLM version: 0.10.1.1
Checklist
- [x] I have checked that my issue isn't already filed (see open issues)
- [x] I have included my system information
- [x] Any code provided is minimal, complete, and reproducible (more on MREs)
- [x] Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
- [x] Any traceback provided is complete