I2U-Net
I2U-Net copied to clipboard
我无法复现论文中的结果。
我用下面的脚本在 CVC-ClinicDB 上做测试，但是最终的 Dice 只有 0.85 左右。
from Models.I2U_Net import I2U_Net_L
from tqdm import tqdm
import torch
from torch.optim import lr_scheduler
from utils.dice_loss import get_soft_label
from utils.dice_loss_github import SoftDiceLoss_git, CrossentropyND
from Dataset.Dataset import Dataset
from torch.utils.data import DataLoader
from torchmetrics.segmentation import DiceScore
import matplotlib.pyplot as plt
def one_loss(out_c, target, num_classes=2):
    """Return the sum of a soft-Dice term and a cross-entropy term.

    Args:
        out_c: Network output, passed unchanged to both loss criteria.
        target: Label tensor. It is passed as-is to the cross-entropy
            criterion and converted with ``get_soft_label`` for the Dice
            criterion.
        num_classes: Number of classes handed to ``get_soft_label``.

    Returns:
        ``(1 + soft_dice) + cross_entropy`` as produced by the two criteria.
    """
    # The permute moves the last axis of the soft label to position 1
    # (presumably channel-last one-hot -> channel-first; confirm against
    # get_soft_label).
    one_hot_target = get_soft_label(target, num_classes).permute(0, 3, 1, 2)

    dice_criterion = SoftDiceLoss_git(batch_dice=False, dc_log=False)
    ce_criterion = CrossentropyND()

    # NOTE(review): the "+ 1" suggests SoftDiceLoss_git returns the negated
    # Dice coefficient, so this term would lie in [0, 1] -- TODO confirm.
    dice_term = 1 + dice_criterion(out_c, one_hot_target)
    ce_term = ce_criterion(out_c, target)
    return dice_term + ce_term
def _evaluate_dice(model, val_daloader):
    """Return the foreground Dice score of ``model`` over ``val_daloader``.

    The model is switched to eval mode for the pass and back to train mode
    before returning.
    """
    model.eval()
    # torchmetrics metrics are nn.Modules; keep the metric on the same device
    # as the tensors it is updated with.
    dice_score = DiceScore(num_classes=1).cuda()
    with torch.no_grad():
        for x, y in tqdm(val_daloader):
            image = x.float().cuda()
            target = y.float().cuda()
            out_f = model(image)
            # The network is trained with cross-entropy over the channel
            # axis, so a pixel is foreground when class 1 wins the argmax.
            # Thresholding channel 1 at 0.5 is only equivalent if the output
            # is already a 2-class softmax; on raw logits it mislabels pixels.
            pred = (out_f.argmax(dim=1, keepdim=True) == 1).float()
            dice_score.update(pred, target)
    score = dice_score.compute().item()
    model.train()
    return score


def train(model, data_loader, val_daloader, epoch):
    """Train ``model`` and save a plot of its validation Dice to 'dice.png'.

    Uses Adam with a cosine-annealing-with-warm-restarts schedule and the
    ``one_loss`` objective. The validation Dice is measured every 40 epochs
    (skipping epoch 0) and additionally after the last epoch, so the weights
    the caller ends up with always have a reported score.

    Args:
        model: Network to optimise; inputs are moved to CUDA before the call.
        data_loader: Iterable of ``(image, target)`` training batches.
        val_daloader: Iterable of ``(image, target)`` validation batches.
        epoch: Number of training epochs to run.
    """
    interval = 40
    dice = []
    eval_epochs = []  # epoch index of every entry in ``dice``
    # Define optimizers and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-8)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=0.00001
    )  # lr_3
    for i in tqdm(range(epoch)):
        temp = 0
        num = 0
        for x, y in data_loader:
            image = x.float().cuda()
            target = y.float().cuda()
            out_f = model(image)
            loss = one_loss(out_f, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            temp += loss.item()
            num += 1
        # Step the schedule AFTER the epoch's optimizer updates; stepping
        # first skips the initial learning rate of the schedule.
        scheduler.step()
        if num:  # avoid ZeroDivisionError on an empty loader
            print("loss", temp / num)
        is_last_epoch = i == epoch - 1
        if (i % interval == 0 and i != 0) or is_last_epoch:
            print("start eval")
            dice.append(_evaluate_dice(model, val_daloader))
            eval_epochs.append(i)
            print("dice", dice[-1])
    if not dice:
        # No evaluation ran (epoch == 0): nothing to plot, and max() of an
        # empty list would raise.
        return
    # Plot the Dice history against the epochs it was measured at
    plt.figure()
    plt.plot(eval_epochs, dice, marker='o')
    maxdice = max(dice)
    plt.axhline(y=maxdice, color='r', linestyle='--')
    plt.text(eval_epochs[0], maxdice, f'maxdice={maxdice:.4f}')
    plt.xlabel('epoch')
    plt.ylabel('dice')
    plt.savefig('dice.png')
    plt.close()
if __name__ == '__main__':
    # Script entry point: build the model and both data loaders, train for a
    # fixed number of epochs, then write the final weights to disk.
    batch_size = 16
    total_epochs = 250
    loader_options = {"batch_size": batch_size, "num_workers": 4}

    net = I2U_Net_L().cuda()

    # Only the training loader shuffles its samples.
    train_set = Dataset(is_train=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)

    val_set = Dataset(is_train=False)
    val_loader = DataLoader(val_set, **loader_options)

    train(net, train_loader, val_loader, total_epochs)
    torch.save(net.state_dict(), "model.pth")