Add support for closed-source models in the Generalized Knowledge Distillation Trainer

Open imrankh46 opened this issue 1 year ago • 3 comments

Feature request

Support closed-source models, such as OpenAI's GPT-4o and Anthropic's Claude, as the teacher in GKD.

from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=args,
    tokenizer=tokenizer,  # newer TRL releases name this argument `processing_class`
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

Motivation

Your contribution

imrankh46 • Oct 05 '24 04:10

@kashif @lewtun

imrankh46 • Oct 05 '24 04:10

Since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models.

kashif • Oct 07 '24 10:10
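To make the vocabulary constraint concrete, here is a minimal sketch of a per-token forward KL between teacher and student logits. It is not TRL's implementation, which works on batched tensors and uses a generalized Jensen-Shannon divergence controlled by `beta`. The helpers `log_softmax` and `token_kl` are hypothetical, and the logits are plain Python lists for illustration. The sum pairs entry i of the teacher distribution with entry i of the student distribution, so it only makes sense when both models share one vocabulary; an API that returns no logits cannot supply the teacher side at all.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def token_kl(teacher_logits, student_logits):
    """Forward KL(teacher || student) at a single token position.

    Entry i of each list must be the logit for the *same* token id i,
    so both models have to share one vocabulary.
    """
    if len(teacher_logits) != len(student_logits):
        raise ValueError(
            f"vocabulary mismatch: teacher has {len(teacher_logits)} logits, "
            f"student has {len(student_logits)}"
        )
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    # Sum over the shared vocabulary: p_i * (log p_i - log q_i)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


if __name__ == "__main__":
    teacher = [2.0, 1.0, 0.1]  # toy 3-token vocabulary
    student = [1.5, 1.2, 0.3]
    print(token_kl(teacher, student))
```

`token_kl(teacher, teacher)` is exactly 0, a differing student gives a positive value, and logits of different lengths raise `ValueError`. Equal lengths alone are not sufficient: both tokenizers must also map each id to the same token, which is presumably why the example above pairs two Qwen2 checkpoints.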

> Since the logits/vocabulary need to match between the teacher and student models, I don't think it's possible to train with closed models.

The Anthropic API doesn't return any logits or logprobs, and they have no plans to add them; OpenAI returns at most 20 logprobs per token position. It seems they really don't want you to distill. OpenAI recently announced a distillation service, but it only works with their own models, not open-source ones.

August-murr • Oct 08 '24 06:10
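With a cap of 20 logprobs per position, the best an API teacher can provide is a truncated distribution. Below is a minimal sketch of one possible workaround. It assumes you have already extracted the teacher's top-k `(token, logprob)` pairs for a position into a dict, and that you have the student's log-probabilities keyed by token string. The sketch renormalizes the teacher's top-k mass, drops tokens the student vocabulary doesn't contain, and computes forward KL on that truncated support. `truncated_teacher_kl` is a hypothetical helper, not part of TRL or either vendor's SDK. Matching tokens by string is a lossy assumption, because the two tokenizers segment text differently.

```python
import math


def truncated_teacher_kl(teacher_top_k, student_logprobs):
    """Approximate KL(teacher || student) at one position from top-k teacher logprobs.

    teacher_top_k: {token_string: logprob}, e.g. the <= 20 entries an API returns.
    student_logprobs: {token_string: logprob} over the student's vocabulary.
    (Both formats are assumptions made for this sketch.)
    """
    # Keep only teacher tokens that also exist as tokens in the student vocab.
    shared = {tok: lp for tok, lp in teacher_top_k.items() if tok in student_logprobs}
    if not shared:
        raise ValueError("no overlap between teacher top-k tokens and student vocabulary")
    # Renormalize the teacher's truncated mass so it sums to 1 over the shared tokens.
    m = max(shared.values())
    log_z = m + math.log(sum(math.exp(lp - m) for lp in shared.values()))
    kl = 0.0
    for tok, lp in shared.items():
        log_p = lp - log_z
        # Student logprobs are NOT renormalized, so mass the student puts
        # outside the teacher's top-k still increases the loss.
        kl += math.exp(log_p) * (log_p - student_logprobs[tok])
    return kl


if __name__ == "__main__":
    teacher_top_k = {"Hello": -0.2, "Hi": -2.5, "Hey": -3.0}
    # Toy student distribution; the remaining mass sits on other tokens.
    student = {"Hello": -0.7, "Hi": -1.5, "Hey": -2.0, "Greetings": -3.0}
    print(truncated_teacher_kl(teacher_top_k, student))
```

Because the student side is not renormalized, the result stays non-negative. Whatever probability mass the teacher assigns outside its top-k is simply lost, so this signal is much weaker than the full-logit distillation GKD performs.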