
Trainer: add predict with generate

Open zucchini-nlp opened this issue 1 year ago • 9 comments

What does this PR do?

Fixes #26474, fixes #31462 and fixes #31672. This PR adds the ability to generate during evaluation and to compute metrics on the generated text for decoder-only models.

The basic idea is almost the same as in the Seq2Seq Trainer, but decoder-only models need a prompt-only input for generation, while loss computation needs the whole input (prompt + answer). Users therefore prepare the eval dataset so that each eval batch also contains the generation inputs used only for generation. Additionally, to make users' lives easier, I added the possibility to pass different collators for the train and eval/test datasets.
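For illustration, an eval batch under this convention could look roughly like the following (a minimal sketch with dummy token ids; the generation_* key names follow the collator examples further down):

import torch

pad = 0
eval_batch = {
    # right-padded full sequence (prompt + answer), used for the loss
    "input_ids": torch.tensor([[5, 6, 7, 8, pad]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 0]]),
    "labels": torch.tensor([[5, 6, 7, 8, -100]]),            # -100 is ignored by the CE loss
    # left-padded, prompt-only inputs used only by generate()
    "generation_input_ids": torch.tensor([[pad, 5, 6]]),
    "generation_attention_mask": torch.tensor([[0, 1, 1]]),
}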

The args used for generation are set via a GenerationConfig, which imo makes more sense than exposing only max_length and num_beams as the Seq2SeqTrainer does.
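For example, a minimal sketch assuming the new TrainingArguments fields added by this PR (the generation settings themselves are placeholders):

from transformers import GenerationConfig, TrainingArguments

gen_config = GenerationConfig(max_new_tokens=64, num_beams=1)

args = TrainingArguments(
    output_dir="tmp",
    predict_with_generate=True,     # added by this PR
    generation_config=gen_config,   # any GenerationConfig attribute is respected
)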

The code was tested with the below dummy train script.

import random
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration, TrainingArguments, Trainer

DEVICE = "cuda:0"
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
pad_token_id = processor.tokenizer.pad_token_id

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
    use_dora=False,
    init_lora_weights="gaussian"
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

model.add_adapter(lora_config)
model.enable_adapters()


eval_dataset = load_dataset("nielsr/docvqa_1200_examples", split="test")
eval_dataset = eval_dataset.remove_columns(['id', 'words', 'bounding_boxes', 'answer'])
eval_dataset = eval_dataset.select(range(10))


class DataCollatorForGeneration:
    def __init__(self, processor, eval_mode=False):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]
        self.eval_mode = eval_mode

    def __call__(self, examples):
        texts, texts_eval = [], []
        images = []
        for example in examples:
            image = example["image"]
            question = example["query"]["en"]
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": question}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = processor.apply_chat_template(messages, add_generation_prompt=False)
            text_eval = processor.apply_chat_template([messages[0]], add_generation_prompt=True)
            texts.append(text.strip())
            texts_eval.append(text_eval.strip())
            images.append([image])

        # Make sure we have right padding in train and left padding for eval parts
        processor.tokenizer.padding_side = "right"
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True) 
        
        if self.eval_mode:
            processor.tokenizer.padding_side = "left"
            batch_eval = processor(text=texts_eval, images=images, return_tensors="pt", padding=True)
            batch['generation_input_ids'] = batch_eval['input_ids']
            batch['generation_attention_mask'] = batch_eval['attention_mask']

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch

gen_config = model.generation_config
gen_config.max_length = 200

training_args = TrainingArguments(
    max_steps=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=5,
    gradient_accumulation_steps=2,
    output_dir="tmp_delete",
    eval_strategy="steps",
    fp16=True,
    remove_unused_columns=False,
    report_to="none",
    predict_with_generate=True,
    generation_config=gen_config,
)

def custom_metrics(prediction_dict):
    # unmask for correct detokenization, because preds are padded to max length with -100
    preds = prediction_dict.predictions
    preds[preds == -100] = pad_token_id
    lbls = prediction_dict.label_ids
    lbls[lbls == -100] = pad_token_id

    # Decode and do magic for metrics
    preds = processor.batch_decode(preds)
    lbls = processor.batch_decode(lbls)
    bleu = rouge = 0
    return {"bleu" : bleu, "rouge": rouge}


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForGeneration(processor),
    eval_data_collator=DataCollatorForGeneration(processor, eval_mode=True),
    train_dataset=eval_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=custom_metrics,
)

print(trainer.evaluate())

zucchini-nlp avatar Jul 31 '24 06:07 zucchini-nlp

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This seems good to me from a quick read, pinging @SunMarc and @muellerzr who have more experience with that code than I do.

LysandreJik avatar Jul 31 '24 09:07 LysandreJik

Can you please show an example of how this code can work with an HF text dataset (not a multimodal dataset) and without the Idefics2 processor, i.e. using tokenizer.apply_chat_template? How would right and left padding be handled in that case?

salrowili avatar Aug 01 '24 04:08 salrowili

@salrowili it should be similar to Idefics, with the only difference that instead of processor.tokenizer you simply have tokenizer. The main thing to note is that the Trainer needs inputs for loss calculation, which are prepared the same way as they always were, and also needs inputs for generation, which should be prefixed with generation_ and left-padded.

Below is a modified version of the Idefics script that should work for text models

class DataCollatorForGeneration:
    def __init__(self, tokenizer, eval_mode=False):
        self.tokenizer = tokenizer
        self.eval_mode = eval_mode

    def __call__(self, examples):
        texts, texts_eval = [], []
        for example in examples:
            question = example["query"]["en"]
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": f"Answer question: {question}"
                },
                {
                    "role": "assistant",
                    "content": answer
                }
            ]
            text = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
            text_eval = tokenizer.apply_chat_template([messages[0]], add_generation_prompt=True, tokenize=False)
            texts.append(text.strip())
            texts_eval.append(text_eval.strip())

        # Make sure we have right padding in train and left padding for eval parts
        tokenizer.padding_side = "right"
        batch = tokenizer(text=texts, return_tensors="pt", padding=True) 
        
        if self.eval_mode:
            tokenizer.padding_side = "left"
            batch_eval = tokenizer(text=texts_eval, return_tensors="pt", padding=True)
            batch['generation_input_ids'] = batch_eval['input_ids']
            batch['generation_attention_mask'] = batch_eval['attention_mask']

        labels = batch["input_ids"].clone()
        labels[labels == tokenizer.pad_token_id] = -100  # ignore index for the CE loss
        batch["labels"] = labels

        return batch

zucchini-nlp avatar Aug 01 '24 04:08 zucchini-nlp

@zucchini-nlp Thank you for the update. I have added some lines to the code to make a complete example for a QA task.

from transformers import AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig,TrainingArguments, Trainer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig
import torch
from torchmetrics.text import SQuAD
from random import randrange
from transformers.utils import logging

dataset = load_dataset("Stanford/web_questions")

train_dataset=dataset["train"]
eval_dataset=dataset["test"]

eval_dataset = eval_dataset.select(range(256))

quant_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)


model_id="meta-llama/Meta-Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.add_special_tokens({"pad_token":"</s>"})
pad_token_id = tokenizer.pad_token_id


model.resize_token_embeddings(len(tokenizer),pad_to_multiple_of=8)


gen_config = model.generation_config
gen_config.max_length = 256

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model.add_adapter(peft_config)
model.enable_adapters()

tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}"

class DataCollatorForGeneration:
    def __init__(self, tokenizer, eval_mode=False):
        self.tokenizer = tokenizer
        self.eval_mode = eval_mode

    def __call__(self, examples):
        texts, texts_eval = [], []
        for example in examples:
            question = example["question"]
            answer = example["answers"][0]  # WebQuestions has multiple answers; to keep the code simple we choose the first one
            messages = [
                {
                    "role": "user",
                    "content": f"Answer the following question: {question}"
                },
                {
                    "role": "assistant",
                    "content": answer
                }
            ]
            
            text = tokenizer.apply_chat_template(messages, add_generation_prompt=False,tokenize=False)
            text_eval = tokenizer.apply_chat_template([messages[0]],add_generation_prompt=True,tokenize=False)
            texts.append(text.strip())
            texts_eval.append(text_eval.strip())
            ## uncomment to check template format
            # print(text)
            #print(text_eval)
            #exit()
        # Make sure we have right padding in train and left padding for eval parts
        tokenizer.padding_side = "right"
        batch = tokenizer(text=texts, return_tensors="pt", padding=True) 
        
        if self.eval_mode:
            tokenizer.padding_side = "left"
            batch_eval = tokenizer(text=texts_eval, return_tensors="pt", padding=True)
            batch['generation_input_ids'] = batch_eval['input_ids']
            batch['generation_attention_mask'] = batch_eval['attention_mask']
        labels = batch["input_ids"].clone()
        labels[labels == tokenizer.pad_token_id] = -100 # Ignore index for CE-loss
        batch["labels"] = labels
        return batch




def custom_metrics(prediction_dict):
    # unmask for correct detokenization, because preds are padded to max length with -100
    preds = prediction_dict.predictions
    preds[preds == -100] = pad_token_id
    lbls = prediction_dict.label_ids
    lbls[lbls == -100] = pad_token_id

    # Decode and do magic for metrics
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    lbls = tokenizer.batch_decode(lbls, skip_special_tokens=True)
    ## remove skip_special_tokens=True if you want to see all special tokens (e.g. EOS)
    print("\n\n\n", '=' * 40, "Labels", '=' * 40)
    for item_x in lbls[:5]:
        print(item_x, "\n")
    print("\n", '=' * 40, "Predictions", '=' * 40)
    for item_x in preds[:5]:
        print(item_x, "\n")
    print("\n", '=' * 80)

    ## visit https://lightning.ai/docs/torchmetrics/stable/text/squad.html for reference ##
    pred_list = []
    label_list = []
    for idx, (x, y) in enumerate(zip(preds, lbls)):
        # keep only the text after the question mark, i.e. the answer part
        pred_list.append({"prediction_text": x.split("?")[1], "id": idx})
        label_list.append({"answers": {"text": [y.split("?")[1]]}, "id": idx})
    squad = SQuAD()(pred_list, label_list)  # compute once over all examples
    em_score = squad["exact_match"].item()
    f1_score = squad["f1"].item()
    return {"exact_match": em_score, "f1_score": f1_score}

def preprocess_logits_for_metrics(logits, labels):
    """Helper function for logits preprocessing for metrics"""
    preds = torch.argmax(logits, dim=-1)
    return preds, labels

training_args = TrainingArguments(
    per_device_train_batch_size=8,
    per_device_eval_batch_size=128,
    num_train_epochs=20,
    do_train=True,
    do_eval=True,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500000,
    bf16=True,
    output_dir="./test_predict",
    overwrite_output_dir=True,
    optim="adafactor",
    report_to="none",
    logging_steps=100000,
    remove_unused_columns=False,
    predict_with_generate=True,
    generation_config=gen_config)


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForGeneration(tokenizer),
    eval_data_collator=DataCollatorForGeneration(tokenizer, eval_mode=True),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=custom_metrics,
)

trainer.train()

Here are my comments and questions to think about:

  • Evaluation is very slow. I had to increase the batch size to 128. Usually with the Seq2Seq trainer, batch sizes beyond 8 would not affect the speed, but in this example a larger batch size led to faster inference and evaluation. This becomes a problem with XLA (e.g. TPU) with FSDP: the evaluation will freeze forever.
  • Is there any way to speed up the evaluation with packing=True?
  • I think we need to integrate the code into the trl library, including all the trainer code in the trl repo (e.g. https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py).
  • The code needs "report_to" to be set to none, otherwise there is a problem with the generation config format: the code you posted accepts the generation config in JSON format, whereas model.generation_config will be different (e.g. for Llama).
  • We need to feed the prompt (i.e. the instruction and question) to compute_metrics or preprocess_logits_for_metrics in the Trainer and SFT trainer. This way we can split the completion (the answer) from the whole sequence (instruction + question).
  • In this code I did not use the preprocess_logits_for_metrics function. Do you think doing some work in preprocess_logits_for_metrics could speed up the evaluation?

The output for this code is pretty decent:


 ======================================== Labels ========================================
Answer the following question: what does jamaican people speak?  Jamaican Creole English Language

Answer the following question: what did james k polk do before he was president?  Lawyer

Answer the following question: what is the oregon ducks 2012 football schedule?  University of Oregon

Answer the following question: who plays ken barlow in coronation street?  Tony Warren

Answer the following question: what happened after mr. sugihara died?  Yaotsu


 ======================================== Predictions ========================================
Answer the following question: what does jamaican people speak?  Jamaican Creole English Language#include<fstream.h>ndependent institution?

Answer the following question: what did james k polk do before he was president?  Lawyer(Redirected from James K. Polk

Answer the following question: what is the oregon ducks 2012 football schedule?  University of Oregon"""
University of Oregon Ducks football

Answer the following question: who plays ken barlow in coronation street?  Tony Warrendef乍

Answer the following question: what happened after mr. sugihara died?  Yaotsu``

Thank you

salrowili avatar Aug 02 '24 01:08 salrowili

Thanks for the feedback and for testing the feature!

  • Hmm, let me look into the speed issue. My guess is that generation in general will require more time, as you are doing extra forward passes through the model; increasing the batch size makes it faster because you now generate 128 examples in one go. I will see if I can make generate faster: since we already do two forward passes, we could somehow reuse the past key values. I will also look into packing.
  • Regarding the TRL library, let's merge the PR in transformers first; we can port the changes to TRL if everyone is happy with the code.
  • Are you reporting to WandB? I'll try to report and see how to make it work.
  • The question about prompts should be doable with the current implementation if we ask the Trainer to include_inputs_for_metrics; iiuc that will let us see the prompt used by the model and the whole generated completion.
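A rough sketch of that last point (include_inputs_for_metrics is an existing TrainingArguments flag; reading the inputs field on the EvalPrediction in compute_metrics is how iiuc it would be consumed):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tmp",
    include_inputs_for_metrics=True,   # forward the model inputs to compute_metrics
    predict_with_generate=True,        # from this PR
)

def compute_metrics(eval_pred):
    prompt_ids = eval_pred.inputs      # the prompts the model saw
    pred_ids = eval_pred.predictions   # generated token ids
    label_ids = eval_pred.label_ids
    # strip the prompt tokens from each prediction before decoding and scoring
    return {}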

zucchini-nlp avatar Aug 02 '24 05:08 zucchini-nlp

Seems like there's not much we can do about the long evaluation time when generating. I tracked how long it takes with decoder-only and encoder-decoder models. Indeed most encoder-decoder models are fast, mainly because they're lightweight, while the model you tried is 8B parameters. I did several checks to verify that the evaluation speed is approximately the same for models of similar size. Increasing the batch size is one option to generate faster, as you already tried. Another option is to generate only on a small sample of the eval set and let users enable generation on the whole dataset if they want to.

Also, logging to WandB is working for me and the generation config is logged as a simple dict. Can you share what errors you got there, @salrowili?

zucchini-nlp avatar Aug 16 '24 06:08 zucchini-nlp

Hi @zucchini-nlp. When I state that the prediction is slow, I am comparing it to this script: https://huggingface.co/docs/trl/en/sft_trainer, which is much faster. I think one possible way to solve this problem is to integrate your code into the SFTTrainer class from the trl repo and see if the speed changes. Another way is to do it through eval_packing, which groups a couple of examples together to fill the sequence; see: https://github.com/huggingface/trl/blob/314e8eb367cbfaf74c2e9717085346360e779508/trl/trainer/sft_trainer.py#L110 .

For wandb logging: to reproduce the error, just change report_to from none to wandb and you will get the error. This issue is minor, though, as we can work around it by calling wandb.init and wandb.log inside the code itself.

salrowili avatar Aug 16 '24 11:08 salrowili

Oke, so it's SFTTrainer; then I'll see what is different there. For packing, we can calculate the loss with packing but not generate, since generation continues a single sequence with the next several tokens, and a packed sequence contains more than one prompt.

In general, we had an idea to try out torch.nested_tensor, which is the closest thing to packing, but that won't happen soon.

zucchini-nlp avatar Aug 16 '24 11:08 zucchini-nlp

Thanks for this PR @zucchini-nlp, hoping it gets merged soon. I am using something similar internally to train decoder-only models for information extraction. I saw a concern that this is slower than the traditional SFT Trainer, which is something I experienced as well.

My belief is that this is mainly because in the SFT Trainer, prediction of token n+1 uses the previous n tokens that come from the prompt + ground truth, while in this case the previous n tokens come from the prompt + the model's own predictions. So it cannot be parallelized the same way as in the SFT Trainer, where you can literally predict tokens n+1, n+2, ... in parallel.
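A toy illustration of the difference (not the Trainer's actual code; model is any Hugging Face causal LM, and the real generate() uses a KV cache, but the per-token loop remains):

import torch

@torch.no_grad()
def eval_without_generate(model, input_ids):
    # Teacher forcing: one forward pass scores all next-token positions in parallel,
    # because the "previous tokens" are the ground-truth ones (prompt + answer).
    return model(input_ids=input_ids).logits.argmax(dim=-1)

@torch.no_grad()
def eval_with_generate(model, prompt_ids, max_new_tokens=20):
    # Autoregressive decoding: one forward pass per new token, each conditioned on the
    # model's own previous predictions, so it cannot be parallelized the same way.
    out = prompt_ids
    for _ in range(max_new_tokens):
        next_token = model(input_ids=out).logits[:, -1:].argmax(dim=-1)
        out = torch.cat([out, next_token], dim=-1)
    return out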

My belief comes from the fact that I saw a drop in eval performance and an increase in time when using predict_with_generate, compared to using the SFTTrainer as it is.

shubhamjain0594 avatar Aug 28 '24 15:08 shubhamjain0594

@shubhamjain0594 Yes, that's exactly what I meant: generation is expected to take more time than a simple forward pass. As per the last comment from @salrowili, I compared the SFT Trainer and the HF Trainer, with and without generation. I don't see any slowdown caused by the HF Trainer specifically, as both of them rely on the same code for training and evaluation. The only difference is that SFT supports packed datasets while the HF Trainer doesn't.

The current implementation of predict_with_generate doesn't support packing, for the reasons mentioned above. To be more specific, generation cannot happen if we pack the dataset, because we don't support that in transformers yet: the model generates the next token given all previous tokens, and packing would merge several prompts together, resulting in bad-quality generation.

I think we can let users run generation on a small sample of the eval set if they don't want to slow down the evaluation loop. Applying generation-optimization tricks here might not be the best solution, as we are trying to verify how well the model is learning. The only technique I can think of is torch.compile, but it is still a very new feature and I would rather not integrate it into the HF Trainer yet.

As for WandB, it worked for me with both the SFT Trainer and the HF Trainer; the generation config is logged as a dict in the parameters.

So, I think I can request a review from @muellerzr now. The PR isn't very high priority, so feel free to take a look whenever you have bandwidth.

zucchini-nlp avatar Aug 30 '24 10:08 zucchini-nlp

Hi @zucchini-nlp ! Thank you for adding this PR. I have been testing it and I have a few questions/thoughts:

  • When DeepSpeed ZeRO-3 is enabled, I hit the error mentioned in #32641 during model.generate(). Hope we can merge that as well, @SunMarc.
  • For the slower performance, I think this is expected, since the new calls to model.generate() involve iterative decoding with many forward passes: my understanding is one forward pass for each new token, while the existing call to model() involves only a single forward pass per input sample.
  • For packing in the SFT trainer, I think we can set packing=True and eval_packing=False to pack the training dataset only (see the sketch after this list).
  • In your example code, I think you intended to get generation_input_ids from texts_eval (instead of texts). Otherwise, as we can see in the example outputs from @salrowili, every prediction includes the entire label.
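For the packing point, a minimal sketch with trl's SFTConfig (packing and eval_packing are existing SFTConfig fields; whether this composes cleanly with predict_with_generate is exactly what would need testing):

from trl import SFTConfig

config = SFTConfig(
    output_dir="tmp",
    packing=True,        # pack the training dataset for throughput
    eval_packing=False,  # keep eval examples separate so the generation prompts stay intact
)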

qiuosier avatar Aug 31 '24 17:08 qiuosier

@qiuosier Yes, in SFT one can pack the train dataset and not pack the evaluation set. I am not 100% sure it works with SFT out of the box, since afair SFT doesn't accept the predict_with_generate arg. We can work on adding SFT support after this PR is merged.

  • In your example code, I think you intended to get generation_input_ids from texts_eval (instead of texts).

Oh, right, a typo hehe

zucchini-nlp avatar Sep 04 '24 11:09 zucchini-nlp

Hi @zucchini-nlp, SFTConfig (the args argument) is a subclass of TrainingArguments, so once your version of transformers is installed we will be able to use predict_with_generate in the args for the SFTTrainer. I do have to set dataset_kwargs={"skip_prepare_dataset": True} to customize the data preparation and use a customized data collator, though, because by default the SFTTrainer tokenizes (and packs) the data in the __init__() call and the data collator is used for padding only. I think tokenizing the data in __init__() is more efficient, as we only need to do it once instead of every epoch in the data collator.
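A sketch of that setup, reusing the names from the Llama example above (dataset_kwargs={"skip_prepare_dataset": True} is an existing SFTConfig option; predict_with_generate, generation_config and eval_data_collator assume this PR plus the TRL follow-ups discussed later in the thread):

from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    output_dir="tmp",
    dataset_kwargs={"skip_prepare_dataset": True},  # keep raw examples; the collator tokenizes them
    remove_unused_columns=False,
    predict_with_generate=True,
    generation_config=gen_config,
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForGeneration(tokenizer),
    eval_data_collator=DataCollatorForGeneration(tokenizer, eval_mode=True),
    compute_metrics=custom_metrics,
)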

qiuosier avatar Sep 04 '24 20:09 qiuosier

@qiuosier oke, cool, then it might work out-of-the-box. Didn't really test it yet

zucchini-nlp avatar Sep 05 '24 08:09 zucchini-nlp

Any chance this can be merged?

pdufour avatar Nov 01 '24 21:11 pdufour

@zucchini-nlp I've been testing your branch and there were a couple of issues that I fixed regarding its use with the SFTTrainer:

  1. https://github.com/huggingface/trl/pull/2311
  2. https://github.com/huggingface/trl/pull/2310

These can land after the transformers PR lands and a new version is released. I've tried to help by re-merging main into this branch, but when I made a PR it was not what I thought it'd be: https://github.com/zucchini-nlp/transformers/pull/1. There weren't many changes though, just a couple of conflicts. Let me know if there's anything I can do to help.

pdufour avatar Nov 03 '24 17:11 pdufour

@pdufour hey! Thanks for opening PRs to integrate the feature in TRL. I agree that we should first get this merged in transformers.

As for progress, I am waiting for a review from @muellerzr. Can you review pls? 🤗 I will rebase on main in a few hours.

zucchini-nlp avatar Nov 04 '24 09:11 zucchini-nlp

Hopefully this gets merged soon. I thought it would be easy to implement some custom metrics, like MT-Bench scores, at every eval step, as previously metrics like BLEU/WER were easy to calculate using just the evaluate library.

R4ZZ3 avatar Nov 05 '24 20:11 R4ZZ3

So am I correct to assume that all existing collators, e.g. DataCollatorForCompletionOnlyLM, will need to be modified so that they produce generation_input_ids and generation_attention_mask?

If so, it's less than ideal.

tcz avatar Nov 26 '24 12:11 tcz

While I hope this PR gets merged, here's a messy workaround I currently use:

import torch
from tqdm import tqdm
from unsloth import FastLanguageModel  # provides the for_inference() / for_training() mode switches

# create_trainer, test_prediction, model, tokenizer and training_data are my own helpers/objects defined elsewhere
resume = False
for steps in tqdm(range(0, 1701, 100)):
    print(f"Steps: {steps}")

    if steps > 0:
        trainer = create_trainer(model, tokenizer, training_data['train'], steps)
        if resume:
            trainer.train(resume_from_checkpoint=True)
        else:
            trainer.train()
            resume = True
        
    model = FastLanguageModel.for_inference(model)

    test_prediction(model, training_data['test'], steps)

    model = FastLanguageModel.for_training(model)

I basically set max_steps on the trainer each time, set the model to eval mode, run my metrics and then resume_from_checkpoint with an increased number of steps. It's not very elegant, but it works for the time being.

tcz avatar Dec 12 '24 12:12 tcz

I really hope this PR or a similar PR can be merged. Is there still any chance this can be merged?

haotong-yang avatar Mar 21 '25 18:03 haotong-yang

Sorry, this is quite a low-priority PR for me, so if anyone wants to work on it, feel free to take it from here :)

zucchini-nlp avatar Mar 24 '25 09:03 zucchini-nlp

@zucchini-nlp I can see if I have some time. What outstanding work would need to get done?

pdufour avatar Apr 04 '25 15:04 pdufour

@pdufour thanks a lot! The most important parts are to address the comments under this PR, if any are still unaddressed, to run small experiments to make sure it works with model training, and maybe to support torch.compile (which might help with faster generation; otherwise eval takes too long). The last feedback also mentioned that generation params weren't logged in TensorBoard, so that's worth checking as well.

zucchini-nlp avatar Apr 07 '25 08:04 zucchini-nlp

I'm interested in this PR, so I was wondering if anyone is currently working on this.

nsbg avatar Apr 14 '25 15:04 nsbg

+1

hrshtkpr avatar Jun 23 '25 10:06 hrshtkpr

Waiting for this!

steveepreston avatar Oct 11 '25 03:10 steveepreston