Working very well on training but doing pretty poorly at inference

Open WaterKnight1998 opened this issue 3 years ago • 12 comments

Hi,

Thank you for publishing this model :)

I want to use this model for Document Parsing. I have annotations for two kinds of pdf, 20 images per type.

During training it achieves very good edit distances (0.08, 0.0, 0.1, 0.13...), but when I run inference the edit distance is very bad: 0.8.

In addition, several predictions come out as {"text_sequence": "word1 word1 word1 ... word1"}, where word1 is a random word repeated over and over.

Thanks in advance

WaterKnight1998 avatar Oct 04 '22 08:10 WaterKnight1998

What prompt are you using at inference? Most probably it is messing up the JSON output. If you have fine-tuned the base model for your task, try using <s_REPLACE_YOUR_DATASET_NAME> as the prompt at inference time.

satkatai avatar Oct 04 '22 11:10 satkatai

Hi @satheeshkatipomu, thank you for answering.

I was following this training tutorial.

If you check his lightning module:

def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores
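
The scoring at the end of this method divides an edit distance by the length of the longer string, so 0.0 means an exact match and 1.0 means nothing matches. A minimal sketch of that metric, assuming a plain Levenshtein distance; the `levenshtein` helper below is a pure-Python stand-in for the `edit_distance` call, whose import the snippet does not show:

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance counting insertions, deletions and substitutions."""
    if len(a) < len(b):
        a, b = b, a  # keep the inner row as short as possible
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(min(
                previous[j] + 1,         # delete ch_a
                current[j - 1] + 1,      # insert ch_b
                previous[j - 1] + cost,  # substitute (or keep) the character
            ))
        previous = current
    return previous[-1]


def normalized_edit_distance(pred: str, answer: str) -> float:
    """Edit distance divided by the longer length; 0.0 for two empty strings."""
    longest = max(len(pred), len(answer))
    if longest == 0:
        return 0.0  # guard that the original one-liner does not have
    return levenshtein(pred, answer) / longest


if __name__ == "__main__":
    print(normalized_edit_distance("abcd", "abce"))
```

Read this way, the 0.8 reported at inference means roughly four out of five characters would have to change to reach the ground truth.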

He builds decoder_input_ids from the decoder start token, decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device), instead of tokenizing a task prompt:

    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

If I also use decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device) at inference time, the performance improves.
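
One way to make this mismatch visible is to compare the IDs that the inference prompt tokenizes to against the start token ID used in training. The sketch below is only an illustration: `check_task_prompt` and `FakeTokenizer` are hypothetical names, and the fake tokenizer exists just so the example runs without downloading a model. It assumes the tokenizer returns a mapping with an "input_ids" list when called without return_tensors.

```python
def check_task_prompt(tokenizer, decoder_start_token_id, task_prompt):
    """Compare the tokenized inference prompt with the training start token."""
    encoded = tokenizer(task_prompt, add_special_tokens=False)
    prompt_ids = list(encoded["input_ids"])
    return {
        "prompt_ids": prompt_ids,
        # A prompt that was never added as a special token splits into pieces.
        "is_single_token": len(prompt_ids) == 1,
        "matches_start_token": prompt_ids == [decoder_start_token_id],
    }


class FakeTokenizer:
    """Hypothetical stand-in for processor.tokenizer, for illustration only."""

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text, add_special_tokens=False):
        if text in self.vocab:
            return {"input_ids": [self.vocab[text]]}
        # Unknown prompts are split per character, mimicking sub-token pieces.
        return {"input_ids": [0] * len(text)}


if __name__ == "__main__":
    tokenizer = FakeTokenizer({"<s_mydata>": 57525})  # made-up token and ID
    print(check_task_prompt(tokenizer, 57525, "<s_mydata>"))
    print(check_task_prompt(tokenizer, 57525, "<s_cord-v2>"))
```

With a real model, the arguments would presumably be processor.tokenizer and model.config.decoder_start_token_id. If matches_start_token is False, validation and inference are feeding the decoder different prompts.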

WaterKnight1998 avatar Oct 04 '22 12:10 WaterKnight1998

Did you fine-tune the model on your custom dataset, or are you trying to use the off-the-shelf cord-v2 model?

satkatai avatar Oct 04 '22 12:10 satkatai

Yes, I fine-tuned the model on my custom dataset, but I used that code @satheeshkatipomu

WaterKnight1998 avatar Oct 05 '22 07:10 WaterKnight1998

I see some changes compared to the code in this repo, but I still suggest checking your prompt at inference time if the edit distance is very low during training but very high at inference, assuming your test images are similar to the training data.

For example, in the fine-tuning notebook here, they hard-coded the start and end token as <s_cord-v2>.

Please check whether it is the same during training and inference. If it is, sorry, I am unable to help solve your problem.

satkatai avatar Oct 05 '22 10:10 satkatai

I had the same issue as you: the model performed very well in validation during training, but very poorly during inference (the model just output gibberish).

Try checking the added_tokens.json file in your trained model path. Your task prompt should be among the added tokens. For some reason, my task prompt was just <s_> instead of <s_{DATASET_NAME}>.
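
A minimal sketch of that check, assuming added_tokens.json is a JSON object that maps each added token string to its ID; `find_task_prompt` and the prompt name in the usage comment are hypothetical:

```python
import json
from pathlib import Path


def find_task_prompt(model_dir, expected_prompt):
    """Return (found, start_like_tokens) for the model's added_tokens.json."""
    path = Path(model_dir) / "added_tokens.json"
    added = json.loads(path.read_text(encoding="utf-8"))
    # Collect every "<s_..." token so a truncated prompt like "<s_>" stands out.
    start_like = sorted(token for token in added if token.startswith("<s_"))
    return expected_prompt in added, start_like


# Usage (path and prompt name are placeholders):
# found, candidates = find_task_prompt("path/to/trained_model", "<s_mydata>")
# if not found:
#     print("Task prompt missing; start-like tokens present:", candidates)
```

If the expected prompt is missing but a bare <s_> shows up among the candidates, the dataset name was probably lost when the prompt token was built.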

outday29 avatar Dec 04 '22 12:12 outday29

I am having the same issue: validation gives very good accuracy and predictions, but predicting on the same image gives garbage values. My added_tokens.json is fine. Kindly help. @WaterKnight1998 @satheeshkatipomu @outday29

Mohtadrao avatar Jan 22 '24 11:01 Mohtadrao

@Mohtadrao @WaterKnight1998 have you found a solution? I have the same error. Thanks

CarlosSerrano88 avatar Jun 15 '24 21:06 CarlosSerrano88

@CarlosSerrano88 yeah. My accuracy increased when I increased the number of epochs along with the amount of training data. You also need to adjust the config file accordingly.

Mohtadrao avatar Jul 20 '24 10:07 Mohtadrao