GPU crashes during LayoutLMv3 document classification fine-tuning
I'm trying to fine-tune a LayoutLMv3 model for a document classification use case, but the GPU crashes during training. When I run in CPU mode to debug, the job has been running for more than 12 hours and still hasn't finished.
After going through some of the GitHub issues, it looks like the problem has something to do with the way I have used the tokenizer and the DataLoader. Here are the important sections of my code.
```python
import json

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3TokenizerFast,
    LayoutLMv3Processor,
    LayoutLMv3ForSequenceClassification,
)

# OCR was run offline, so the feature extractor only preprocesses the image;
# words and boxes are supplied manually below.
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
processor = LayoutLMv3Processor(feature_extractor, tokenizer)


# scale_bounding_box and DOCUMENT_CLASSES are defined elsewhere in my notebook.
class DocumentClassificationDataset(Dataset):
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        image_path = self.image_paths[item]
        image = Image.open(image_path).convert("RGB")
        width, height = image.size

        # The OCR output for each image sits in a JSON file next to it.
        json_path = image_path.with_suffix(".json")
        with open(json_path, "r") as f:
            ocr_result = json.load(f)

        # LayoutLMv3 expects bounding boxes normalized to a 0-1000 range.
        width_scale = 1000 / width
        height_scale = 1000 / height

        words = []
        boxes = []
        for row in ocr_result:
            boxes.append(scale_bounding_box(row["bounding_box"], width_scale, height_scale))
            words.append(row["word"])

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The class label is derived from the parent directory name.
        label = DOCUMENT_CLASSES.index(image_path.parent.name)

        # Drop the batch dimension that return_tensors="pt" adds.
        return dict(
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            bbox=encoding["bbox"].flatten(end_dim=1),
            pixel_values=encoding["pixel_values"].flatten(end_dim=1),
            labels=torch.tensor(label, dtype=torch.long),
        )
```
```python
train_dataset = DocumentClassificationDataset(train_images, processor)
test_dataset = DocumentClassificationDataset(test_images, processor)

train_data_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=6,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=6,
)
```
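For debugging, here is a quick shape check on one sample and one batch; the commented shapes are what I would expect with `max_length=512` and the processor's default 224x224 image size:

```python
# Sanity-check tensor shapes before training.
sample = train_dataset[0]
print(sample["input_ids"].shape)     # torch.Size([512])
print(sample["bbox"].shape)          # torch.Size([512, 4])
print(sample["pixel_values"].shape)  # torch.Size([3, 224, 224])

batch = next(iter(train_data_loader))
print(batch["pixel_values"].shape)   # torch.Size([4, 3, 224, 224])
```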
```python
class ModelModule(pl.LightningModule):
    def __init__(self, n_classes: int):
        super().__init__()
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes,
        )
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        # Passing labels makes the model compute the classification loss itself.
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels,
        )

    def training_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels,
        )
        loss = outputs.loss
        self.log("train_loss", loss)
        self.train_accuracy(outputs.logits, labels)
        self.log("train_acc", self.train_accuracy, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels,
        )
        loss = outputs.loss
        self.log("val_loss", loss)
        self.val_accuracy(outputs.logits, labels)
        self.log("val_acc", self.val_accuracy, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-5)
```
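Before the full run, a one-batch forward pass can confirm the wiring (a sketch reusing the loader above; since labels are passed, the output carries both loss and logits):

```python
# Smoke test: push a single batch through the model without gradients.
module = ModelModule(n_classes=len(DOCUMENT_CLASSES))
batch = next(iter(train_data_loader))
with torch.no_grad():
    outputs = module(
        batch["input_ids"],
        batch["attention_mask"],
        batch["bbox"],
        batch["pixel_values"],
        batch["labels"],
    )
print(outputs.loss, outputs.logits.shape)  # logits: torch.Size([4, n_classes])
```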
```python
model_module = ModelModule(len(DOCUMENT_CLASSES))

model_checkpoint = ModelCheckpoint(
    filename="{epoch}-{step}-{val_loss:.4f}",
    save_last=True,
    save_top_k=3,
    monitor="val_loss",
    mode="min",
)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    devices=1,
    max_epochs=4,
    callbacks=[model_checkpoint],
)

trainer.fit(model_module, train_data_loader, test_data_loader)
```
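To isolate the crash, I also plan to try a scaled-down run; the gradient accumulation and the batch limits below are guesses on my part, not anything the LayoutLMv3 docs prescribe:

```python
# Debugging variant: fewer batches per epoch, and gradient accumulation so a
# smaller per-step batch keeps the same effective batch size.
debug_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,
    max_epochs=1,
    accumulate_grad_batches=2,  # with batch_size=2 this keeps an effective batch of 4
    limit_train_batches=10,     # run only a handful of batches first
    limit_val_batches=2,
)
debug_trainer.fit(model_module, train_data_loader, test_data_loader)
```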
I have also tried running with `os.environ["TOKENIZERS_PARALLELISM"] = "false"`.
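Since the crash might come from the DataLoader workers rather than the tokenizer (the VM only has 6 cores and both loaders ask for 6 workers each), a single-process loader is another variant I intend to test; the `num_workers=0` setting is my own guess:

```python
# Load data in the main process to take multiprocessing out of the picture.
debug_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
```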
Environment: Python 3.8, Azure Machine Learning Notebook. VM size: Standard_NC6 (6 cores, 56 GB RAM, 380 GB disk).