
Memory issue

Open vvuonghn opened this issue 1 year ago • 3 comments

Hi @samleoqh

Thank you for your release source code. It helps me a lot.

During the training process, I ran into a problem related to memory. [screenshot]

The process consumes a lot of memory, over 150 GB of RAM. I think the problem is in the validate function, because you append all the input/output data to inputs_all, gts_all, and predictions_all:

def validate(net, val_set, val_loader, criterion, optimizer, epoch, new_ep):
    net.eval()
    val_loss = AverageMeter()
    inputs_all, gts_all, predictions_all = [], [], []

    with torch.no_grad():
        for vi, (inputs, gts) in enumerate(val_loader):
            inputs, gts = inputs.cuda(), gts.cuda()
            N = inputs.size(0) * inputs.size(2) * inputs.size(3)
            outputs = net(inputs)

            val_loss.update(criterion(outputs, gts).item(), N)
            # val_loss.update(criterion(gts, outputs).item(), N)
            if random.random() > train_args.save_rate:
                inputs_all.append(None)
            else:
                inputs_all.append(inputs.data.squeeze(0).cpu())

            gts_all.append(gts.data.squeeze(0).cpu().numpy())
            predictions = outputs.data.max(1)[1].squeeze(1).squeeze(0).cpu().numpy()
            predictions_all.append(predictions)

    update_ckpt(net, optimizer, epoch, new_ep, val_loss,
                inputs_all, gts_all, predictions_all)

    net.train()
    return val_loss, inputs_all, gts_all, predictions_all

vvuonghn · Aug 14 '23 21:08

Ah, there is a bug: those three lines should be moved inside the else: branch:

gts_all.append(gts.data.squeeze(0).cpu().numpy())
predictions = outputs.data.max(1)[1].squeeze(1).squeeze(0).cpu().numpy()
predictions_all.append(predictions)

The value of save_rate controls what percentage of val images is appended for later visualization. It defaults to 0.1, and can be lowered further, e.g. to 0.001, if the number of val images is very large.
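For illustration, the sampling behaviour of that fix can be sketched with plain-Python stand-ins for the tensors (validate_sampling and the string placeholders are hypothetical names for this sketch, not the repo's code):

```python
import random

def validate_sampling(num_images, save_rate=0.1, seed=0):
    """Mimic the fixed validation loop: only ~save_rate of the images
    keep their data; the rest leave a None placeholder for inputs."""
    random.seed(seed)
    inputs_all, gts_all, predictions_all = [], [], []
    for i in range(num_images):
        if random.random() > save_rate:
            inputs_all.append(None)  # placeholder only, no tensor retained
        else:
            # with the fix, all three appends live in the same branch
            inputs_all.append(f"input_{i}")      # stands in for inputs.cpu()
            gts_all.append(f"gt_{i}")            # stands in for gts.cpu().numpy()
            predictions_all.append(f"pred_{i}")  # stands in for the argmax output
    return inputs_all, gts_all, predictions_all

inputs_all, gts_all, predictions_all = validate_sampling(1000)
```

With save_rate=0.1, only roughly 10% of the 1000 images keep their arrays, so memory stays bounded regardless of the val-set size.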

samleoqh · Aug 14 '23 22:08

Hi

I think if you move those lines inside the else: branch as above, the evaluation may no longer run, because it needs the full predictions_all and gts_all (plus train_args.nb_classes):

acc, acc_cls, mean_iu, fwavacc, f1 = evaluate(predictions_all, gts_all, train_args.nb_classes)

I think the best way to fix this is to evaluate each sample individually rather than the whole val set at once.

vvuonghn · Aug 15 '23 20:08

Yes, you are right. I'll refactor the code a bit when I get some free time. The reason I appended all predictions together before computing metrics is that each test image contains only one or two of the 9 classes, so aggregating them gives a complete/stable confusion matrix across all val images. It's also fine to evaluate images one by one and then average, like the val loss.
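A middle ground that keeps the global statistics without storing every prediction is to accumulate a running confusion matrix batch by batch and derive the metrics from it at the end. A minimal numpy sketch (update_hist and mean_iou are illustrative names, not the repo's API):

```python
import numpy as np

def update_hist(hist, gt, pred, nb_classes):
    """Add one image's pixel counts to a running confusion matrix.

    gt and pred are flat integer arrays of the same length; rows of
    hist index the ground-truth class, columns the predicted class.
    """
    mask = (gt >= 0) & (gt < nb_classes)
    hist += np.bincount(
        nb_classes * gt[mask].astype(np.int64) + pred[mask],
        minlength=nb_classes ** 2,
    ).reshape(nb_classes, nb_classes)
    return hist

def mean_iou(hist):
    """Per-class IoU from the accumulated confusion matrix, averaged."""
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    return np.nanmean(iu)

# usage: call update_hist once per batch instead of appending arrays
nb_classes = 3
hist = np.zeros((nb_classes, nb_classes), dtype=np.int64)
gt = np.array([[0, 1], [2, 2]])
pred = np.array([[0, 1], [2, 1]])
hist = update_hist(hist, gt.ravel(), pred.ravel(), nb_classes)
```

This keeps memory constant (a nb_classes × nb_classes array) while still producing the same global confusion matrix as concatenating all predictions would.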

samleoqh · Aug 16 '23 08:08