Unofficial Training Code sample.
I have implemented some code for fine-tuning on the understanding task, using the sample in inference.py as a reference. Feedback and suggestions are welcome!
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
accelerator = Accelerator(mixed_precision="bf16")
device = accelerator.device
model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
).to(device)
vl_gpt.train()
for name, param in vl_gpt.named_parameters():
    if "gen_embed" in name:
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
        # freeze gen_embed parameters
        param.requires_grad = False
        # check that the parameter is now frozen
        print(f"Parameter: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
lr = 1e-4
optimizer = optim.AdamW(vl_gpt.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
gradient_clip = 1.0
vl_gpt, optimizer = accelerator.prepare(vl_gpt, optimizer)
def train_step(model, optimizer, criterion):
    model.train()
    optimizer.zero_grad()

    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\nConvert the formula into latex code.",
            "images": ["images/equation.png"],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = load_pil_images(conversation)
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(device)

    # unwrap the DDP wrapper (if any) so the model's custom methods are reachable
    model = model.module if hasattr(model, "module") else model
    model = model.to(torch.bfloat16)
    inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs)

    with accelerator.autocast():
        outputs = model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
        )
    logits = outputs.logits  # (batch_size, seq_len, vocab_size)

    labels = prepare_inputs.input_ids.clone().detach()
    labels[labels == tokenizer.pad_token_id] = -100

    # next-token prediction: shift logits and labels by one position
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    accelerator.backward(loss)
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
loss = train_step(vl_gpt, optimizer, criterion)
print(f"Training loss: {loss}")
for i in range(10):
    loss = train_step(vl_gpt, optimizer, criterion)
    print(f"Training loss: {loss}")
I tried to write the generation fine-tuning code, but since I'm not very familiar with distributed training, I haven't finished debugging it yet.
Let me share the method I used. The code is a bit legacy-style, but here it is:
def gen_preprocess(self, images):
    # encode each image with the generation VQ encoder; info[2] holds the codebook indices
    gen_codebooks = []
    for image in images:
        image_tensor = self.gen_resize(image).unsqueeze(0)
        quant, emb_loss, info = self.gen_vision_model.encode(image_tensor)
        gen_codebooks.append(info[2])
    return torch.stack(gen_codebooks, dim=0)
def process_train(
    self,
    question: str = None,
    answer: str = None,
    images: List[Image] = None,
    gen_images: List[Image] = None,
    **kwargs,
):
    """
    Args:
        question (str): the formatted question/prompt;
        answer (str): the formatted answer;
        images (List[ImageType]): the images fed to the understanding encoder;
        gen_images (List[ImageType]): the target images for the generation encoder;
        **kwargs:
    Returns:
        outputs (VLChatProcessorTrainOutput): the output of the processor,
            - input_ids (torch.LongTensor): [N + image tokens]
            - target_ids (torch.LongTensor): [N + image tokens]
            - target_gen_ids (torch.LongTensor): [N + image tokens]
            - pixel_values (torch.FloatTensor): [n_images, 3, H, W]
            - num_image_tokens (List[int]): the number of image tokens
            - num_image_gen_tokens (List[int]): the number of image generation tokens
            - gen_codebooks (torch.LongTensor): VQ codebook indices of the gen_images
    """
    # if self.image_gen_tag in answer:
    #     answer = answer.replace(self.image_gen_tag, self.image_gen_tag * 576)
    sft_format = question + answer

    # tokenize
    input_ids = self.tokenizer.encode(sft_format)
    input_ids = torch.LongTensor(input_ids)

    # add (expand) understanding image tokens in the input_ids
    image_token_mask: torch.BoolTensor = input_ids == self.image_id
    image_indices = image_token_mask.nonzero()
    input_ids, num_image_tokens = self.add_image_token(
        image_indices=image_indices,
        input_ids=input_ids,
    )

    # add (expand) generation image tokens in the input_ids
    gen_token_mask: torch.BoolTensor = input_ids == self.image_gen_id
    gen_indices = gen_token_mask.nonzero()
    input_ids, num_image_gen_tokens = self.add_image_gen_token(
        image_indices=gen_indices,
        input_ids=input_ids,
    )

    # load images
    images_outputs = self.image_processor(images, return_tensors="pt")
    images_gen_outputs = self.gen_preprocess(gen_images)

    # re-tokenize the question alone to locate where the answer starts
    question_input_ids = self.tokenizer.encode(question)
    question_input_ids = torch.LongTensor(question_input_ids)
    question_image_token_mask: torch.BoolTensor = question_input_ids == self.image_id
    question_image_indices = question_image_token_mask.nonzero()
    question_input_ids, _ = self.add_image_token(
        image_indices=question_image_indices,
        input_ids=question_input_ids,
    )

    # mask the question tokens (plus one position for the <image_start_tag>) in the text targets
    target_input_ids = input_ids.clone()
    target_input_ids[: len(question_input_ids) + 1] = self.ignore_id

    # generation targets: VQ codebook indices at the gen-token positions, ignore_id elsewhere
    target_gen_input_ids = torch.full((len(input_ids),), self.ignore_id)
    # legacy code: assumes a single generation image per sample
    assert torch.sum(input_ids == self.image_gen_id) == len(images_gen_outputs[0])
    target_gen_input_ids[input_ids == self.image_gen_id] = images_gen_outputs[0]
    target_input_ids[input_ids == self.image_gen_id] = self.ignore_id

    prepare = VLChatProcessorTrainOutput(
        sft_format=sft_format,
        input_ids=input_ids,
        pixel_values=images_outputs.pixel_values,
        num_image_tokens=num_image_tokens,
        num_image_gen_tokens=num_image_gen_tokens,
        target_ids=target_input_ids,
        target_gen_ids=target_gen_input_ids,
        gen_codebooks=images_gen_outputs,
    )
    return prepare
I'm simply checking whether the model can map an image fed into the understanding encoder (SigLIP) onto the tokens of the generation encoder (the VQ model), i.e. an image -> image task. So far it doesn't seem to be working well, though... 😭
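For reference, here is a rough sketch of how the generation loss could be wired up from the target_gen_ids that process_train produces, using the gen_head and language_model attributes that appear in the official Janus generation example. This is only my guess at the wiring, not debugged code: it assumes the batch has already been batchified into the fields that prepare_inputs_embeds expects, and that the criterion uses ignore_index matching the processor's ignore_id (-100).

def gen_loss_step(model, batch, criterion):
    # batch is assumed to be batchified process_train output, with target_gen_ids of
    # shape (batch, seq_len) holding VQ codebook indices (ignore_id everywhere else)
    inputs_embeds = model.prepare_inputs_embeds(**batch)
    # run the base LLM (no text LM head) to get hidden states
    outputs = model.language_model.model(inputs_embeds=inputs_embeds)
    hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)

    # project hidden states onto the image codebook with the generation head
    gen_logits = model.gen_head(hidden_states)  # (batch, seq_len, codebook_size)

    # next-token shift, then supervise only the positions that hold codebook indices
    shift_logits = gen_logits[:, :-1, :].contiguous()
    shift_targets = batch.target_gen_ids[:, 1:].contiguous()
    return criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_targets.view(-1))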
I wrote image generation fine-tuning code:
https://github.com/ladvu/Janus
I'm not sure whether there are any bugs; it is just my coursework, so take it as a reference only...