I2U-Net icon indicating copy to clipboard operation
I2U-Net copied to clipboard

我无法复现论文中的结果。

Open Aurora2580 opened this issue 7 months ago • 0 comments

我用下面的脚本在CVC-CLINICDB上做测试,但是最后dice只有0.85左右。

from Models.I2U_Net import I2U_Net_L
from tqdm import tqdm
import torch
from torch.optim import lr_scheduler

from utils.dice_loss import get_soft_label
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND

from Dataset.Dataset import Dataset
from torch.utils.data import DataLoader

from torchmetrics.segmentation import DiceScore
import matplotlib.pyplot as plt

def one_loss(out_c, target, num_classes=2):

    soft_dice_loss2 = SoftDiceLoss_git(batch_dice=False, dc_log=False)
    CE_loss_F = CrossentropyND()
    
    target_soft_a = get_soft_label(target, num_classes) 
    target_soft = target_soft_a.permute(0, 3, 1, 2)
    dice_loss_f = 1 + soft_dice_loss2(out_c, target_soft)
    ce_loss_f = CE_loss_F(out_c, target)
    loss_f = dice_loss_f + ce_loss_f
    return loss_f



def train(model,data_loader,val_daloader,epoch):

    interval = 40
    dice = []
    
    # Define optimizers and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-8)    
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min = 0.00001)    # lr_3


    for i in tqdm(range(epoch)):
        scheduler.step()
        temp = 0
        num = 0

        for x, y in data_loader:
            image = x.float().cuda()                                   
            target = y.float().cuda()                                  

            out_f = model(image)

            loss = one_loss(out_f, target)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            temp += loss.item()
            num += 1
        
        print("loss", temp/num)
        
        if i % interval == 0 and i != 0:
            print("start eval")
            model.eval()
            dice_score = DiceScore(num_classes=1)
            with torch.no_grad():
                for x, y in tqdm(val_daloader):

                    image = x.float().cuda()
                                            
                    target = y.float().cuda()

                    out_f = model(image)
                                                
                    out_2 = out_f[:,1:2,:,:]

                    out_2 = (out_2 > 0.5).float()

                    dice_score.update(out_2, target)

            dice.append(dice_score.compute().item())
            print("dice", dice[-1])
            model.train()
    
    #绘制dice
    plt.plot(dice)
    maxdice = max(dice)
    plt.axhline(y=maxdice, color='r', linestyle='--')
    plt.text(0, maxdice, f'maxdice={maxdice:.4f}')
    plt.xlabel('epoch')
    plt.ylabel('dice')
    plt.savefig('dice.png')
    

if __name__ == '__main__':

    batch_size = 16
    model = I2U_Net_L().cuda()

    dataset = Dataset(is_train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    val_dataset = Dataset(is_train=False)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

    train(model,dataloader,val_dataloader,250)

    torch.save(model.state_dict(), "model.pth")



    

Aurora2580 avatar Jun 07 '25 10:06 Aurora2580