SegLossBias
SegLossBias copied to clipboard
Why are the tests/directories differently structured?
(base) :~/loss/SegLossBias/tests$ python test_compound_losses.py
Traceback (most recent call last):
  File "~/SegLossBias/tests/test_compound_losses.py", line 5, in <module>
    from seglossbias.modeling.compound_losses import CrossEntropyWithL1, CrossEntropyWithKL
ModuleNotFoundError: No module named 'seglossbias'
I simply cloned the repository and tried to run the tests, but test_compound_losses.py
throws the error shown above.
If I move the test file to the top-level directory, it throws the following instead:
Traceback (most recent call last):
  File "~/SegLossBias/test_compound_losses.py", line 5, in <module>
    from seglossbias.modeling.compound_losses import CrossEntropyWithL1, CrossEntropyWithKL
  File "~/SegLossBias/seglossbias/modeling/__init__.py", line 1, in <module>
    from .network import build_model
  File "~/SegLossBias/seglossbias/modeling/network.py", line 6, in <module>
    import segmentation_models_pytorch as smp
ModuleNotFoundError: No module named 'segmentation_models_pytorch'
Am I doing something wrong? Is there some other way to run the tests?
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import matplotlib.pyplot as plt
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(101)
# Names of the supported task modes; CompoundLoss asserts that its `mode`
# argument is one of these three strings.
BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"
# Small constant added to both numerator and denominator in
# get_region_proportion so the ratio stays defined when the denominator is 0.
EPS: float = 1e-10
def expand_onehot_labels(labels, target_shape, ignore_index):
    """Expand integer class labels into a one-hot tensor shaped like the prediction.

    Args:
        labels: Integer class-index tensor of shape ``(N, *rest)``, for
            example ``(N,)`` or ``(N, H, W)``.
        target_shape: Shape of the one-hot output, ``(N, C, *rest)``.
        ignore_index: Label value marking elements that must be left out.

    Returns:
        A tuple ``(bin_labels, valid_mask)``. ``bin_labels`` has shape
        ``target_shape`` and the dtype/device of ``labels``, with a 1 in the
        class channel of every valid element and 0 elsewhere. ``valid_mask``
        is a boolean tensor shaped like ``labels`` that is False wherever the
        label is negative or equals ``ignore_index``.
    """
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        # The class index is inserted on dim 1, right after the batch index;
        # every remaining coordinate follows unchanged. Building the index
        # tuple this way handles any label rank, where the previous code only
        # placed the trailing coordinates for 3-D labels.
        bin_labels[(inds[0], labels[valid_mask]) + tuple(inds[1:])] = 1

    return bin_labels, valid_mask
def get_region_proportion(x: torch.Tensor, valid_mask: torch.Tensor = None) -> torch.Tensor:
    """Return, per sample and channel, the share of considered elements covered by ``x``.

    Args:
        x: 4-D tensor indexed as ``(b, c, w, h)``; a one-hot label map or a
            per-pixel probability map.
        valid_mask: Optional mask of the elements to consider. Either 4-D,
            indexed like ``x``, or 3-D ``(b, w, h)``, in which case the same
            mask is applied to every channel. When omitted, all ``w * h``
            elements of each channel are counted.

    Returns:
        Tensor of shape ``(b, c)`` holding
        ``(sum of x over considered elements + EPS) / (number considered + EPS)``.
    """
    if valid_mask is not None:
        if valid_mask.dim() != 4:
            # Give the 3-D mask a singleton channel axis so it broadcasts
            # across every channel of ``x``.
            valid_mask = valid_mask.unsqueeze(dim=1)
        # Plain multiplication promotes a boolean mask to the dtype of ``x``;
        # this avoids handing operands of different dtypes to ``einsum``.
        x = x * valid_mask
        cardinality = valid_mask.sum(dim=(2, 3))
    else:
        cardinality = x.shape[2] * x.shape[3]

    # EPS on both sides keeps the ratio defined when nothing is considered.
    region_proportion = (x.sum(dim=(2, 3)) + EPS) / (cardinality + EPS)

    return region_proportion
class CompoundLoss(nn.Module):
    """Base class for a compound loss of the form ``l = l_1 + alpha * l_2``.

    It supplies the cross-entropy term, helpers that measure ground-truth and
    predicted region proportions, and a step schedule for ``alpha``. No
    ``forward`` is defined here.

    Args:
        mode: One of ``BINARY_MODE``, ``MULTILABEL_MODE`` or ``MULTICLASS_MODE``.
        alpha: Initial trade-off weight of the second term.
        factor: Multiplier applied to ``alpha`` by :meth:`adjust_alpha`.
        step_size: ``alpha`` is updated when ``(epoch + 1)`` is a multiple of
            this value; ``0`` disables the schedule.
        max_alpha: Upper bound for ``alpha``.
        temp: Temperature value stored on the instance.
        ignore_index: Label value excluded from the cross-entropy term.
        background_index: Stored on the instance; not read by this class.
        weight: Passed as ``weight`` to ``F.cross_entropy`` in multiclass
            mode; not used in the other modes.
    """

    def __init__(self, mode: str,
                 alpha: float = 1.,
                 factor: float = 1.,
                 step_size: int = 0,
                 max_alpha: float = 100.,
                 temp: float = 1.,
                 ignore_index: int = 255,
                 background_index: int = -1,
                 weight: Optional[torch.Tensor] = None) -> None:
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()
        self.mode = mode
        self.alpha = alpha
        self.max_alpha = max_alpha
        self.factor = factor
        self.step_size = step_size
        self.temp = temp
        self.ignore_index = ignore_index
        self.background_index = background_index
        self.weight = weight

    def cross_entropy(self, inputs: torch.Tensor, labels: torch.Tensor):
        """Compute the cross-entropy term for the configured mode.

        Args:
            inputs: Raw logits.
            labels: Class indices in multiclass mode; otherwise 0/1 targets.
                A 3-D label tensor gets a singleton axis inserted at dim 1.

        Returns:
            Scalar loss tensor. Elements whose label is negative or equals
            ``self.ignore_index`` do not contribute in binary and multilabel
            mode; in multiclass mode ``ignore_index`` is handled by
            ``F.cross_entropy``.
        """
        if self.mode == MULTICLASS_MODE:
            return F.cross_entropy(
                inputs, labels, weight=self.weight, ignore_index=self.ignore_index)

        if labels.dim() == 3:
            labels = labels.unsqueeze(dim=1)
        # Same validity rule as get_gt_proportion, so both loss terms look at
        # the same set of elements.
        valid = ((labels >= 0) & (labels != self.ignore_index)).type(torch.float32)
        # Invalid targets are zeroed and given zero weight, so a raw
        # ignore_index value is never used as a regression target.
        targets = labels.type(torch.float32) * valid
        loss_sum = F.binary_cross_entropy_with_logits(
            inputs, targets, weight=valid, reduction="sum")
        # Mean over valid elements only; the clamp yields a zero loss instead
        # of a division by zero when every element is ignored.
        return loss_sum / valid.sum().clamp(min=1.0)

    def adjust_alpha(self, epoch: int) -> None:
        """Scale ``alpha`` by ``factor`` every ``step_size`` epochs, capped at ``max_alpha``.

        Args:
            epoch: Zero-based epoch index; the update fires when
                ``(epoch + 1) % step_size == 0``.
        """
        if self.step_size == 0:
            return
        if (epoch + 1) % self.step_size == 0:
            self.alpha = min(self.alpha * self.factor, self.max_alpha)

    def get_gt_proportion(self, mode: str,
                          labels: torch.Tensor,
                          target_shape,
                          ignore_index: int = 255):
        """Measure the ground-truth region proportion per sample and channel.

        Args:
            mode: Task mode; multiclass labels are expanded to one-hot first.
            labels: Ground-truth labels.
            target_shape: Shape of the prediction, used for the one-hot
                expansion in multiclass mode.
            ignore_index: Label value treated as invalid.

        Returns:
            Tuple ``(gt_proportion, valid_mask)`` where ``valid_mask`` is
            False for labels that are negative or equal to ``ignore_index``.
        """
        if mode == MULTICLASS_MODE:
            bin_labels, valid_mask = expand_onehot_labels(labels, target_shape, ignore_index)
        else:
            # The mask is taken before the channel axis is added, so for 3-D
            # labels it stays 3-D.
            valid_mask = (labels >= 0) & (labels != ignore_index)
            if labels.dim() == 3:
                labels = labels.unsqueeze(dim=1)
            bin_labels = labels
        gt_proportion = get_region_proportion(bin_labels, valid_mask)
        return gt_proportion, valid_mask

    def get_pred_proportion(self, mode: str,
                            logits: torch.Tensor,
                            temp: float = 1.0,
                            valid_mask=None):
        """Measure the predicted region proportion per sample and channel.

        Args:
            mode: Task mode; selects softmax over dim 1 (multiclass) or an
                element-wise sigmoid (otherwise).
            logits: Raw network outputs.
            temp: Factor the logits are multiplied by before normalisation.
            valid_mask: Optional mask forwarded to ``get_region_proportion``.

        Returns:
            Region proportion of the resulting probabilities.
        """
        if mode == MULTICLASS_MODE:
            preds = F.log_softmax(temp * logits, dim=1).exp()
        else:
            preds = F.logsigmoid(temp * logits).exp()
        pred_proportion = get_region_proportion(preds, valid_mask)
        return pred_proportion
class CrossEntropyWithL1(CompoundLoss):
    """Cross-entropy loss with a region-size prior measured by the L1 distance.

    The loss is::

        l = CE(X, Y) + alpha * |gt_region - prob_region|
    """

    def forward(self, inputs: torch.Tensor, labels: torch.Tensor):
        """Compute the compound loss.

        Args:
            inputs: Raw logits.
            labels: Ground-truth labels.

        Returns:
            Tuple ``(loss, loss_ce, loss_reg)``: the combined loss, the
            cross-entropy term and the L1 region-proportion term.
        """
        # Cross-entropy term.
        loss_ce = self.cross_entropy(inputs, labels)
        # Regularisation term. The configured ignore_index is passed on
        # explicitly; otherwise get_gt_proportion falls back to its own
        # default of 255 regardless of how this loss was constructed.
        gt_proportion, valid_mask = self.get_gt_proportion(
            self.mode, labels, inputs.shape, ignore_index=self.ignore_index)
        pred_proportion = self.get_pred_proportion(
            self.mode, inputs, temp=self.temp, valid_mask=valid_mask)
        loss_reg = (pred_proportion - gt_proportion).abs().mean()

        loss = loss_ce + self.alpha * loss_reg
        return loss, loss_ce, loss_reg
# Quick manual check of CrossEntropyWithL1 on a synthetic prediction/label pair.
# Initialize tensors
network_outputs = torch.zeros(1, 55, 55) # shape (1, 55, 55); the extra axis is added further down
binary_label = torch.zeros(1, 55, 55) # shape (1, 55, 55), same as network_outputs at this point
''' 100% overlap '''
# Set the same 10x10 patch to 1 in both tensors so prediction and label match exactly.
network_outputs[:, 1:11, 35:45] = 1
binary_label[:, 1:11, 35:45] = 1
# Optional visual check of both masks (uses matplotlib):
# fig, ax = plt.subplots(1, 2)
# ax[0].imshow(network_outputs[0, :, :].detach().numpy())
# ax[0].set_title('Network Outputs')
# ax[1].imshow(binary_label[0, :, :].detach().numpy())
# ax[1].set_title('Binary Label')
# plt.show()
# Reference metrics computed directly on the 0/1 masks.
print('Dice Coefficient:', (2 * (network_outputs * binary_label).sum()) / (network_outputs.sum() + binary_label.sum()))
print('Cross Entropy:', F.binary_cross_entropy(network_outputs, binary_label))
loss = CrossEntropyWithL1(mode='binary', alpha=1.0, temp=1.0)
# Insert a leading axis so the prediction is 4-D: (1, 1, 55, 55).
network_outputs = network_outputs.unsqueeze(0)
# NOTE(review): in binary mode the loss feeds `inputs` to
# binary_cross_entropy_with_logits and logsigmoid, so these hard 0/1 values
# are interpreted as logits, not probabilities -- confirm that is intended.
# Prints the tuple (loss, loss_ce, loss_reg).
print(loss(network_outputs, binary_label))
Although crude, this works as a tester :)