
Unofficial Training Code Sample

Meaw0415 opened this issue 10 months ago · 3 comments

I have implemented some code for understanding fine-tuning, using the sample in inference.py as a reference. Feedback and suggestions are welcome!

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images


accelerator = Accelerator(mixed_precision="bf16")  
device = accelerator.device


model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
).to(device)

vl_gpt.train()
for name, param in vl_gpt.named_parameters():
    if "gen_embed" in name:  # generation embeddings are not needed for understanding fine-tuning
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
        # freeze gen_embed parameters
        param.requires_grad = False
        # check that the parameter is now frozen
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")


lr = 1e-4 
optimizer = optim.AdamW(vl_gpt.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=-100)  
gradient_clip = 1.0


vl_gpt, optimizer = accelerator.prepare(vl_gpt, optimizer)


def train_step(model, optimizer, criterion):
    model.train()
    optimizer.zero_grad()


    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\nConvert the formula into latex code.",
            "images": ["images/equation.png"],
        },
        {"role": "Assistant", "content": ""},
    ]


    pil_images = load_pil_images(conversation)


    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(device) 


    # unwrap the accelerate/DDP wrapper so the model's own helpers are reachable
    model = model.module if hasattr(model, "module") else model
    model = model.to(torch.bfloat16)

    # run the image encoder and merge the image features into the text embeddings
    inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs)


    with accelerator.autocast():  
        outputs = model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
        )
        logits = outputs.logits  # (batch_size, seq_len, vocab_size)


    # build next-token labels: mask padding, then shift so position t predicts token t+1
    labels = prepare_inputs.input_ids.clone().detach()
    labels[labels == tokenizer.pad_token_id] = -100
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    accelerator.backward(loss)  
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)  
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()

# single sanity-check step
loss = train_step(vl_gpt, optimizer, criterion)
print(f"Training loss: {loss}")

# short training loop
for i in range(10):
    loss = train_step(vl_gpt, optimizer, criterion)
    print(f"Training loss: {loss}")

Meaw0415 · Feb 12 '25

I tried to write the generation fine-tuning code, but since I'm not very familiar with distributed training, I haven't finished debugging it yet.

Meaw0415 · Feb 12 '25

Let me share the method I used. It's a bit of legacy code, but here it is:

def gen_preprocess(self, images):
    # encode each image with the VQ generation model and keep its codebook indices
    gen_codebooks = []
    for image in images:
        image_tensor = self.gen_resize(image).unsqueeze(0)
        quant, emb_loss, info = self.gen_vision_model.encode(image_tensor)
        gen_codebooks.append(info[2])  # info[2] holds the quantized code indices
    return torch.stack(gen_codebooks, dim=0)

def process_train(
    self,
    question: str = None,
    answer: str = None,
    images: List[Image] = None,
    gen_images: List[Image] = None,
    **kwargs,
):
    """

    Args:
        prompt (str): the formatted prompt;
        conversations (List[Dict]): conversations with a list of messages;
        images (List[ImageType]): the list of images;
        **kwargs:

    Returns:
        outputs (BaseProcessorOutput): the output of the processor,
            - input_ids (torch.LongTensor): [N + image tokens]
            - target_ids (torch.LongTensor): [N + image tokens]
            - images (torch.FloatTensor): [n_images, 3, H, W]
            - image_id (int): the id of the image token
            - num_image_tokens (List[int]): the number of image tokens
    """
    # if self.image_gen_tag in answer:
    #     answer = answer.replace(self.image_gen_tag, self.image_gen_tag*576)
        
    sft_format = question + answer

    # tokenize
    input_ids = self.tokenizer.encode(sft_format)
    input_ids = torch.LongTensor(input_ids)

    # add image tokens to the input_ids
    image_token_mask: torch.BoolTensor = input_ids == self.image_id
    image_indices = image_token_mask.nonzero()
    input_ids, num_image_tokens = self.add_image_token(
        image_indices=image_indices,
        input_ids=input_ids,
    )
    
    gen_token_mask: torch.BoolTensor = input_ids == self.image_gen_id
    gen_indices = gen_token_mask.nonzero()
    input_ids, num_image_gen_tokens = self.add_image_gen_token(
        image_indices=gen_indices,
        input_ids=input_ids,
    )

    # load images
    images_outputs = self.image_processor(images, return_tensors="pt")
    images_gen_outputs = self.gen_preprocess(gen_images)
    
    question_input_ids = self.tokenizer.encode(question)
    question_input_ids = torch.LongTensor(question_input_ids)
    
    question_image_token_mask: torch.BoolTensor = question_input_ids == self.image_id
    question_image_indices = question_image_token_mask.nonzero()
    question_input_ids, _ = self.add_image_token(
        image_indices=question_image_indices,
        input_ids=question_input_ids,
    )
    
    target_input_ids = input_ids.clone()
    # mask the question tokens (and the <image_start_tag> that follows them)
    # so the loss only covers the answer
    target_input_ids[:len(question_input_ids)+1] = self.ignore_id
    
    target_gen_input_ids = torch.full((len(input_ids),), self.ignore_id)
    # legacy: sanity-check that the number of image-gen placeholders matches the encoded codebook length
    assert torch.sum(input_ids == self.image_gen_id) == len(images_gen_outputs[0])
    target_gen_input_ids[input_ids == self.image_gen_id] = images_gen_outputs[0]
    
    target_input_ids[input_ids == self.image_gen_id] = self.ignore_id
    
    
    prepare = VLChatProcessorTrainOutput(
        sft_format=sft_format,
        input_ids=input_ids,
        pixel_values=images_outputs.pixel_values,
        num_image_tokens=num_image_tokens,
        num_image_gen_tokens=num_image_gen_tokens,
        target_ids=target_input_ids,
        target_gen_ids=target_gen_input_ids,
        gen_codebooks=images_gen_outputs,
    )

    return prepare

I'm simply checking whether the model can take an image fed to the understanding encoder (SigLIP) and reproduce it through the generation encoder (VQ model), i.e. an image -> image task. So far, it doesn't seem to be working well, though... 😭
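
For what it's worth, here is how the target_gen_ids built above could be consumed: a plain next-token cross-entropy over the generation positions, with the ignore index masking everything else. This is a minimal self-contained sketch with dummy tensors; the hidden size, sequence length, 576-token image length, and 16384-entry codebook are my own assumptions, and gen_head here is just a stand-in Linear layer rather than the real Janus head.

import torch
import torch.nn.functional as F

# assumed sizes: 576 image tokens per image (as in the commented-out image_gen_tag*576
# above) and a 16384-entry codebook; neither value is read from the Janus config
batch, seq_len, hidden, codebook = 1, 700, 2048, 16384
hidden_states = torch.randn(batch, seq_len, hidden)   # stand-in for LM hidden states
gen_head = torch.nn.Linear(hidden, codebook)          # stand-in for the image-token head

# target_gen_ids: -100 everywhere except the generation positions,
# where it holds the VQ codebook index of the target image token
target_gen_ids = torch.full((batch, seq_len), -100, dtype=torch.long)
target_gen_ids[:, 100:676] = torch.randint(0, codebook, (batch, 576))

# next-token prediction: logits at position t predict the token at position t+1
logits = gen_head(hidden_states)
shift_logits = logits[:, :-1, :].reshape(-1, codebook)
shift_targets = target_gen_ids[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_targets, ignore_index=-100)
print(loss.item())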

top-yun · Feb 17 '25

I wrote image-generation fine-tuning code.

https://github.com/ladvu/Janus

I'm not sure whether there are any bugs; it was just my coursework, so take it as a reference only...

ladvu · Jun 14 '25