
How to train on multiple GPUs

jiangliqin opened this issue 2 years ago · 2 comments

I can't find a parameter for multi-GPU training.

jiangliqin · Apr 13 '22 01:04

You can change the num_gpus here: https://github.com/Shivanandroy/simpleT5/blob/cb4d89299824312ca468b741298fd3cc391e29e2/simplet5/simplet5.py#L379
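For context: at that line, simpleT5 hard-codes the number of GPUs it passes to the PyTorch Lightning Trainer. If you install from source, the edit is a one-liner; the sketch below shows the idea (the exact surrounding code may differ slightly, so check the linked file):

# in simplet5/simplet5.py, inside SimpleT5.train()
# before (single GPU):
#     gpus = 1 if use_gpu else 0
# after (e.g. two GPUs):
gpus = 2 if use_gpu else 0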

Mahyar-Ali · Jan 19 '23 14:01

> You can change the num_gpus here:
> https://github.com/Shivanandroy/simpleT5/blob/cb4d89299824312ca468b741298fd3cc391e29e2/simplet5/simplet5.py#L379

To elaborate for anyone who wants to make this change without editing the installed package: you can monkey-patch it with this code:

# monkey-patching SimpleT5.train to request 2 GPUs from the Lightning Trainer
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from simplet5 import SimpleT5
from simplet5.simplet5 import LightningDataModule, LightningModel

def modified_train(
    self,
    train_df: pd.DataFrame,
    eval_df: pd.DataFrame,
    source_max_token_len: int = 512,
    target_max_token_len: int = 512,
    batch_size: int = 8,
    max_epochs: int = 5,
    use_gpu: bool = True,
    outputdir: str = "outputs",
    early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
    precision=32,
    logger="default",
    dataloader_num_workers: int = 2,
    save_only_last_epoch: bool = False,
):
    """
    trains T5/MT5 model on custom dataset
    Args:
        train_df (pd.DataFrame): training dataframe. Dataframe must have 2 columns: "source_text" and "target_text"
        eval_df (pd.DataFrame, optional): validation dataframe. Dataframe must have 2 columns: "source_text" and "target_text"
        source_max_token_len (int, optional): max token length of source text. Defaults to 512.
        target_max_token_len (int, optional): max token length of target text. Defaults to 512.
        batch_size (int, optional): batch size. Defaults to 8.
        max_epochs (int, optional): max number of epochs. Defaults to 5.
        use_gpu (bool, optional): if True, model uses gpu for training. Defaults to True.
        outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
        early_stopping_patience_epochs (int, optional): monitors val_loss at epoch end and stops training if val_loss does not improve after the specified number of epochs. Set 0 to disable early stopping. Defaults to 0 (disabled)
        precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
        logger (pytorch_lightning.loggers) : any logger supported by PyTorch Lightning. Defaults to "default". If "default", pytorch lightning default logger is used.
        dataloader_num_workers (int, optional): number of workers in train/test/val dataloader
        save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
    """
    self.data_module = LightningDataModule(
        train_df,
        eval_df,
        self.tokenizer,
        batch_size=batch_size,
        source_max_token_len=source_max_token_len,
        target_max_token_len=target_max_token_len,
        num_workers=dataloader_num_workers,
    )

    self.T5Model = LightningModel(
        tokenizer=self.tokenizer,
        model=self.model,
        outputdir=outputdir,
        save_only_last_epoch=save_only_last_epoch,
    )

    # add callbacks
    callbacks = [TQDMProgressBar(refresh_rate=5)]

    if early_stopping_patience_epochs > 0:
        early_stop_callback = EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=early_stopping_patience_epochs,
            verbose=True,
            mode="min",
        )
        callbacks.append(early_stop_callback)

    # add gpu support (hard-coded to 2 GPUs here; change to match your machine)
    gpus = 2 if use_gpu else 0

    # add logger
    loggers = True if logger == "default" else logger

    # prepare trainer
    # note: simpleT5 targets an older pytorch-lightning where Trainer still
    # accepts the `gpus` argument; passing both `gpus` and `devices` (as the
    # original snippet did) can conflict. On Lightning >= 2.0, use
    # accelerator="gpu", devices=gpus instead of gpus=gpus.
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=max_epochs,
        gpus=gpus,
        precision=precision,
        log_every_n_steps=1,
    )

    # fit trainer
    trainer.fit(self.T5Model, self.data_module)

SimpleT5.train = modified_train

Of course, you can set the gpus variable to however many GPUs you want (I was training on the 2x T4 GPUs in Kaggle), and the trainer should register them properly.
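For completeness, here's a minimal usage sketch with the patch applied (the model name, dataframes, and hyperparameters below are placeholder assumptions, not from this thread):

# usage sketch: run after the monkey-patch above has replaced SimpleT5.train
model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")
model.train(
    train_df=train_df,  # your dataframe with "source_text"/"target_text" columns
    eval_df=eval_df,
    max_epochs=3,
    batch_size=8,
    use_gpu=True,       # the patched train() now requests 2 GPUs
)

Note that with more than one GPU, Lightning picks a distributed strategy (typically some flavor of DDP, depending on your version and environment), so batch_size becomes the per-device batch size and the effective batch size is batch_size times the number of GPUs.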

JIBSIL · Mar 10 '24 18:03