bert-sentiment icon indicating copy to clipboard operation
bert-sentiment copied to clipboard

Loss doesn't decrease

Open KingS770234358 opened this issue 5 years ago • 0 comments

Hello: When I was running this code,my loss didn't decrease. I'm a newcomer to PyTorch. I will appreciate it very much if you could give me some advise.

coding:utf8

import os

import sys

sys.setrecursionlimit(1000000)

from tqdm import tqdm from loguru import logger

import torch from pytorch_transformers import BertConfig, BertForSequenceClassification

from .data import SSTDataset

os.environ["cuda_visible_devices"] = "0" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device = torch.device("cpu")

class BertModel(BertForSequenceClassification): def init(self,dataset=None,epochs=30,batch_size=16,root=True,binary=False,bert="bert-large-uncased",save=False):

    logger.info(f"设置迭代轮数:{epochs}")
    self.epochs = epochs

    logger.info(f"设置批大小:{batch_size}")
    self.batch_size = batch_size

    logger.info(f"设置是否仅使用数据的根:{root}")
    self.root = root

    logger.info(f"设置模型保存模式:{save}")
    self.save = save

    logger.info(f"设置bert模型配置,初始化bert对象:{bert}")
    self.bert = bert
    self.config = BertConfig.from_pretrained(bert)
    logger.info(f"设置分类数:{binary}")
    self.binary = binary
    if not self.binary:
        self.config.num_labels = 5
    super(BertModel, self).__init__(config=self.config)

    logger.info(f"选择设备:{device}")
    self.device = device
    self.to(self.device)

    logger.info(f"设置数据集...")
    self.dataset = dataset
    if self.dataset == None:
        trainset = SSTDataset("train", root=True, binary=False)
        devset = SSTDataset("dev", root=True, binary=False)
        testset = SSTDataset("test", root=True, binary=False)
        self.dataset = {'trainset':trainset,'devset':devset,'testset':testset}

    logger.info("定义损失函数")
    self.lossfn = torch.nn.CrossEntropyLoss()
    logger.info("定义优化器")
    self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)

def train(self):

    for epoch in range(1, self.epochs):

        logger.info(f"训练阶段-{epoch}")
        train_loss, train_acc = self.train_one_epoch()
        logger.info(f"验证阶段-{epoch}")
        val_loss, val_acc = self.evaluate_one_epoch(dataset='devset')
        logger.info(f"测试阶段-{epoch}")
        test_loss, test_acc = self.evaluate_one_epoch(dataset='testset')
        logger.info(f"epoch={epoch}")
        logger.info(
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, test_loss={test_loss:.4f}"
        )
        logger.info(
            f"train_acc={train_acc:.3f}, val_acc={val_acc:.3f}, test_acc={test_acc:.3f}"
        )
        if self.save:
            label = "binary" if self.binary else "fine"
            nodes = "root" if self.root else "all"
            torch.save(self, f"{self.bert}__{nodes}__{label}__e{epoch}.pickle")

    logger.success("Done!")

def train_one_epoch(self):

    generator = torch.utils.data.DataLoader(
        self.dataset['trainset'], batch_size=self.batch_size, shuffle=True
    )
    # self.train()
    train_loss, train_acc = 0.0, 0.0
    for batch, labels in generator:

        batch, labels = batch.to(self.device), labels.to(self.device)

        loss, logits = self(batch, labels=labels)

        err = self.lossfn(logits, labels)
        self.optimizer.zero_grad()
        err.backward()
        self.optimizer.step()
        print(err)

        train_loss += loss.item()
        pred_labels = torch.argmax(logits, axis=1)
        train_acc += (pred_labels == labels).sum().item()
    train_loss /= len(self.dataset['trainset'])
    train_acc /= len(self.dataset['trainset'])

    return train_loss, train_acc

def evaluate_one_epoch(self,dataset='testset'):
    generator = torch.utils.data.DataLoader(
        self.dataset[dataset], batch_size=self.batch_size, shuffle=True
    )
    # self.eval()
    loss, acc = 0.0, 0.0
    with torch.no_grad():
        for batch, labels in tqdm(generator):
            batch, labels = batch.to(device), labels.to(device)
            logits = self(batch)[0]
            error = self.lossfn(logits, labels)
            loss += error.item()
            pred_labels = torch.argmax(logits, axis=1)
            acc += (pred_labels == labels).sum().item()
    loss /= len(self.dataset[dataset])
    acc /= len(self.dataset[dataset])
    return loss, acc

`

KingS770234358 avatar Dec 03 '19 16:12 KingS770234358