pytorch-summary
pytorch-summary copied to clipboard
Failing at the input stage when using an encoder-decoder model architecture (Donut model)
I am trying to get a model summary of the donut model but am unable to define the input for the torch summary. ########################################################### import argparse import gradio as gr import torch from PIL import Image from donut.donut.model import DonutModel from torchvision import models from torchsummary import summary
def demo_process_vqa(input_img, question):
    """Run question-answering inference on an image and return the first prediction.

    Reads the module-level ``task_prompt`` and ``pretrained_model``.

    Args:
        input_img: Image data accepted by ``Image.fromarray``
            (presumably a NumPy array coming from the UI -- confirm with caller).
        question: Text substituted for the ``{user_input}`` placeholder
            in ``task_prompt``.

    Returns:
        The first element of the ``"predictions"`` entry returned by
        ``pretrained_model.inference``.
    """
    pil_image = Image.fromarray(input_img)

    # Fill the question into the prompt template and echo it for debugging.
    prompt = task_prompt.replace("{user_input}", question)
    print(prompt)

    result = pretrained_model.inference(pil_image, prompt=prompt)
    answer = result["predictions"][0]
    print('inf_out', answer)
    return answer
def demo_process(input_img):
    """Run inference on an image with the module-level ``task_prompt``.

    Args:
        input_img: Image data accepted by ``Image.fromarray``
            (presumably a NumPy array -- confirm with caller).

    Returns:
        The first element of the ``"predictions"`` entry returned by
        ``pretrained_model.inference``.
    """
    pil_image = Image.fromarray(input_img)
    result = pretrained_model.inference(image=pil_image, prompt=task_prompt)
    return result["predictions"][0]
# Command-line options; arguments the parser does not recognise end up in left_argv.
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="docvqa")
parser.add_argument(
    "--pretrained_path",
    type=str,
    default="train_docvqa_for_all_atts/donut/result/train_docvqa/20220915_125713",
)
args, left_argv = parser.parse_known_args()

# The docvqa task uses a template with a {user_input} slot; every other task
# (rvlcdip, cord, ...) uses a bare start token built from the task name.
task_name = args.task
task_prompt = (
    "<s_taco_eiko_pdf_donut>{user_input}</s_question><s_answer>"
    if task_name == "docvqa"
    else f"<s_{task_name}>"
)

pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

# Move the whole model to the GPU when one is available; otherwise only the
# encoder is cast to bfloat16.
if not torch.cuda.is_available():
    pretrained_model.encoder.to(torch.bfloat16)
else:
    device = torch.device("cuda")
    pretrained_model.to(device)

# NOTE(review): torchsummary appears to build random float tensors from these
# sizes and to prepend its own batch dimension -- confirm that the leading 1s
# and the integer token-id inputs are compatible with that before relying on
# this call.
summary(pretrained_model, [(1, 3, 1280, 960), (1, 21), (1, 21)])
The input shapes for the encoder and decoder are as follows. Encoder: torch.Size([1, 3, 1280, 960]) Decoder: torch.Size([1, 21])
## The model's forward pass looks like this:
encoder_outputs = self.encoder(image_tensors)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs,
labels=decoder_labels,
)
return decoder_outputs
Can you please advise how to pass the model inputs to `summary`?