boundary-loss
How to apply the boundary loss to 3D images both efficiently and correctly?
Hi, thanks for sharing your code. I am trying to use the boundary loss for 3D (very high-resolution) image segmentation, but I have trouble implementing the loss function both efficiently and correctly. For 3D image segmentation, a popular approach is to train the networks on image patches. Often, the training samples include image patches that belong entirely to the background. For these samples, a naive generalization of your implementation may give SDMs that are all zeros (using the Keras version of the loss function). To me, this does not make sense: even if these samples do not contain any foreground voxels, the SDM should not be all zeros in reality. I think it makes more sense to calculate the SDM from the entire image rather than from image patches. What do you think about this problem?
Also, I found it quite time-consuming to calculate the SDM in 3D. How can the time efficiency be improved?
Thanks
The solution is to pre-compute the distance maps offline in 3D and save them into a .npy with the axes kxyz, k being the class axis. Pay attention to the spatial resolution at this step -- the scipy function has an extra, optional parameter for that.
Then, in the dataloader, you load the 3D distmap as is (no extra transform besides converting to a tensor), and sub-patch it at the same time as the original image.
You can then do the usual multiplication between distance map and softmaxes.
This is what I implemented for the extension, but didn't have time to put it in the repo yet. I will do so soon and then point to the exact code parts doing that, but this should already give you a rough idea of how to proceed. Let me know if you need other details in the meantime.
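A rough sketch of that offline step could look like this (the nibabel loader, the folder layout, and the number of classes are placeholders to adapt to your data; the signed-distance computation is the same one shown further down this thread):

from pathlib import Path

import numpy as np
import nibabel as nib                                          # or whatever loader fits your data
from scipy.ndimage import distance_transform_edt as eucl_distance

K = 2                                                          # number of classes
resolution = (1.0, 1.0, 1.0)                                   # voxel spacing (x, y, z) in mm

out_dir = Path("data/dist")
out_dir.mkdir(parents=True, exist_ok=True)

for gt_path in Path("data/gt").glob("*.nii.gz"):
    gt = nib.load(str(gt_path)).get_fdata().astype(np.uint8)      # xyz label volume
    one_hot_gt = np.stack([gt == k for k in range(K)], axis=0)    # kxyz, one-hot

    dist = np.zeros_like(one_hot_gt, dtype=np.float32)
    for k in range(K):
        posmask = one_hot_gt[k]
        if posmask.any():
            negmask = ~posmask
            dist[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask

    np.save(out_dir / gt_path.name.replace(".nii.gz", ".npy"), dist)   # kxyz distance map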
Hoel
Yes, this is a solution that solves part of the problem, because for training data obtained via data augmentation, it seems there is no better choice than computing the SDM on the fly. In that case, training the network efficiently can be a big issue.
BTW, I want to know if the segmentation performance is sensitive to the value of loss weight for boundary loss?
Thanks
With respect to the data augmentation:
there is no better choice than computing the SDM on the fly
When you refer to "on the fly", do you mean computing the distance map inside the loss function?
The way I see it, the pre-computed distance map can be augmented as well, just like we perform the augmentation on the original ground truth. The overall code would look like this:
from pathlib import Path
from typing import Dict, List, Tuple

from torch import Tensor
from torch.utils.data import Dataset


class DistDataset(Dataset):
    def __init__(self, *args, **kwargs):
        ...
        self.files: List[Tuple[Path, Path, Path]]

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        img_path, gt_path, dist_path = self.files[index]

        # ... load the three files, then perform the transforms here
        aug_img, aug_gt, aug_dist = augment(img, gt, dist)
        del img, gt, dist  # Avoid returning those by accident

        return {"img": aug_img,        # CWH shape
                "gt": aug_gt,          # KWH shape
                "distmap": aug_dist}   # KWH shape


# Then, in the training loop
α = 0.01
for data in train_loader:
    imgs = data["img"].to(device)        # BCWH shape
    gts = data["gt"].to(device)          # BKWH shape
    dists = data["distmap"].to(device)   # BKWH shape

    optimizer.zero_grad()

    pred_probs = softmax(net(imgs))

    dsc_loss = DiceLoss(gts, pred_probs)
    bl_loss = BoundaryLoss(dists, pred_probs)
    total_loss = dsc_loss + α * bl_loss

    total_loss.backward()
    optimizer.step()
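For reference, the BoundaryLoss used above essentially boils down to a point-wise product between the distance maps and the softmax probabilities, averaged over the supervised classes. A minimal sketch, where the idc argument is an assumption (class 1 is taken as the only foreground class):

from typing import Sequence

import torch
from torch import Tensor


def BoundaryLoss(dist_maps: Tensor, probs: Tensor, idc: Sequence[int] = (1,)) -> Tensor:
    # dist_maps and probs share the same shape: BKWH in 2D, BKxyz in 3D.
    # Only the classes listed in idc are supervised by this term
    # (typically the foreground classes).
    pc = probs[:, list(idc), ...].type(torch.float32)
    dc = dist_maps[:, list(idc), ...].type(torch.float32)
    return (pc * dc).mean()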
BTW, I want to know if the segmentation performance is sensitive to the value of loss weight for boundary loss?
In our experiments, it was somewhat sensitive, but it still consistently gave some improvement even with a sub-optimal value -- see Table 3 in our extension (https://arxiv.org/pdf/1812.07032.pdf#page=17).
Increasing and rebalancing the values was not only better performance-wise, but also much simpler to tune -- to me, that is their main advantage.
Thanks for your response. I tried your loss function on our dataset; however, I have not seen improved performance so far. I want to know whether the simplification from the differential form to the integral form holds in 3D, as in the 2D example you show in your paper. The reason I ask is that I think the simplification may not hold for a 3D boundary and a 3D surface. If it does still hold, could you please clarify and point me to some reference articles showing that? Thanks!
Also, what does ‘rebalance’ mean and how to implement it? Thanks
Thanks for your response. I tried your loss function on our dataset; however, I have not seen improved performance so far. I want to know whether the simplification from the differential form to the integral form holds in 3D, as in the 2D example you show in your paper. The reason I ask is that I think the simplification may not hold for a 3D boundary and a 3D surface. If it does still hold, could you please clarify and point me to some reference articles showing that? Thanks!
Yes the result still holds, though in 3D you need to take into account the spatial resolution of each axis, as it might differ. The updated distance computation function now looks like this:
from typing import Tuple

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as eucl_distance


def one_hot2dist(seg: np.ndarray, resolution: Tuple[float, float, float] = None,
                 dtype=None) -> np.ndarray:
    assert one_hot(torch.tensor(seg), axis=0)  # one_hot() is this repo's sanity-check helper
    K: int = len(seg)

    res = np.zeros_like(seg, dtype=dtype)
    for k in range(K):
        posmask = seg[k].astype(bool)

        if posmask.any():
            negmask = ~posmask
            res[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        # The idea is to leave blank the negative classes
        # since this is one-hot encoded, another class will supervise that pixel
    return res
resolution = None corresponds to sampling = (1, 1, 1).
Another thing to take into account: if the spacing between slices becomes too big (like 1 cm on the z axis while it is 1 mm on the x and y axes), then the 3D distance might not make much sense anymore. It will depend on your application.
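For instance, with 1 mm in-plane spacing and 3 mm between slices, the call could look like this (the toy volume and the spacing values are only illustrative):

import numpy as np

K = 2
gt = np.zeros((64, 64, 32), dtype=np.uint8)                    # dummy xyz label volume
gt[20:40, 20:40, 10:20] = 1                                    # a foreground block
one_hot_gt = np.stack([gt == k for k in range(K)], axis=0)     # kxyz, one-hot

dist_map = one_hot2dist(one_hot_gt, resolution=(1.0, 1.0, 3.0),  # (x, y, z) spacing in mm
                        dtype=np.float32)
print(dist_map.shape)                                          # (2, 64, 64, 32)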
Also, what does ‘rebalance’ mean and how to implement it? Thanks
Rebalancing corresponds to starting with a high weight on the DSC loss and a smaller one on the boundary loss, then slowly shifting them:

α = 0.01
for e in range(epochs):
    for data in train_loader:
        ...
        total_loss = (1 - α) * dsc_loss + α * bl_loss
        total_loss.backward()
        optimizer.step()

    α = min(α + 0.01, 0.99)  # increase the boundary loss weight each epoch, capped at 0.99
I tried to understand the mathematics in your paper. It is interesting to see the beautiful connection between Eq. 2 and Eq. 3. However, I found it difficult to follow the derivation connecting the two. Specifically, in the paper you mention that the two can be connected using the following:
To me, it is not obvious why the first two expressions are equivalent, because dD_G/dq is not constant, and the second term after the minus sign is also related to q and, I think, cannot be easily formulated.
Could you please explain this in more detail?
That one fell through the cracks (sorry for that); please feel free to re-open or reply if it is still relevant.