Add support for closed-source teacher models in the Generalized Knowledge Distillation (GKD) Trainer
Feature request
Closed-source teacher model support for GKD, e.g. OpenAI GPT-4o, Claude, etc. For reference, the current usage with an open-weights teacher model:
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
NUM_DUMMY_SAMPLES = 100
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
Motivation
Your contribution
@kashif @lewtun
Since the logits and vocabulary need to match between the teacher and student model, I don't think it's possible to train with closed-source models.
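For illustration, here is a minimal sketch of the alignment the trainer relies on (the two Qwen2 checkpoints in the example above share a tokenizer, something a closed model never exposes):

```python
from transformers import AutoTokenizer

# GKD compares student and teacher logits position by position, so both models
# must produce distributions over the same vocabulary.
student_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
print(len(student_tok), len(teacher_tok))  # same size for the Qwen2 family

# A closed-source teacher (GPT-4o, Claude) exposes neither its tokenizer
# nor its full per-token logits, so this alignment can't be established.
```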
The Anthropic API doesn't output any logits or logprobs and they have no plans to add them, and OpenAI only returns a maximum of 20 logprobs per token. It seems like they really don't want you to distill. OpenAI recently announced a distillation service, but it only works between their own models, not with open-source ones.
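For reference, the most you can get out of the OpenAI Chat Completions API is the top 20 log-probabilities per generated token, and they are expressed over OpenAI's own tokenizer rather than the student's vocabulary. A minimal sketch, assuming the official openai Python client and an OPENAI_API_KEY in the environment:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

resp = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hi, how are you?"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=20,  # hard API maximum; the full vocab has >100k entries
)

# Each generated token comes with at most 20 alternative logprobs,
# far from the full distribution GKD needs for its divergence loss.
for tok in resp.choices[0].logprobs.content:
    print(tok.token, [(alt.token, alt.logprob) for alt in tok.top_logprobs])
```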