TopologyLayer
Help with Topological Regularization
This is a very interesting paper. Thank you for sharing your code.
I am trying to add topological regularization to a two-class U-Net model. Currently, I am training this model with the cross-entropy loss. I know that the correct output of my U-Net model is a single connected component without holes. However, in some cases, the output of my model has holes and islands (see the image below for a conceptual example).
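To check a prediction for these defects, one can threshold it (e.g. `prob > 0.5`) and count foreground components (islands) and enclosed background regions (holes). Below is a minimal NumPy sketch. The function names, the boolean-mask input, and the choice of 4-connectivity for the foreground with 8-connectivity for the background are assumptions for illustration and are not part of TopologyLayer. A mask with the desired topology returns `(1, 0)`.

```python
from collections import deque

import numpy as np

FG_NEIGHBORS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # assumed 4-connectivity for foreground
BG_NEIGHBORS = FG_NEIGHBORS + [(-1, -1), (-1, 1), (1, -1), (1, 1)]  # assumed 8-connectivity for background


def count_components(mask, neighbors):
    """Count connected components of True pixels in a 2D boolean array (BFS flood fill)."""
    h, w = mask.shape
    seen = np.zeros((h, w), dtype=bool)
    count = 0
    for i in range(h):
        for j in range(w):
            if not mask[i, j] or seen[i, j]:
                continue
            count += 1
            seen[i, j] = True
            queue = deque([(i, j)])
            while queue:
                y, x = queue.popleft()
                for dy, dx in neighbors:
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
    return count


def islands_and_holes(mask):
    """Return (foreground components, holes) for a binary segmentation mask (hypothetical helper)."""
    mask = np.asarray(mask, dtype=bool)
    components = count_components(mask, FG_NEIGHBORS)
    # Pad with a background border so all background touching the image edge forms
    # one "outside" component; every other background component is a hole.
    background = np.pad(~mask, 1, mode="constant", constant_values=True)
    holes = count_components(background, BG_NEIGHBORS) - 1
    return components, holes


ring = np.ones((5, 5), dtype=bool)
ring[2, 2] = False  # a single enclosed background pixel
print(islands_and_holes(ring))  # (1, 1)
```

Tracking these two counts on validation predictions gives a direct measure of whether a topological term is helping.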
I thought that adding a topological regularizer might solve this issue. The following code is my best guess given the examples and #17:
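import torch.nn as nn
from topologylayer.nn import LevelSetLayer2D, SumBarcodeLengths, PartialSumBarcodeLengths

class TopoLoss(nn.Module):
    def __init__(self, size):
        super(TopoLoss, self).__init__()
        # Persistence diagrams of the superlevel-set filtration of a 2D image of shape `size`
        self.pdfn = LevelSetLayer2D(size=size, sublevel=False)
        # Sum of dim-0 bar lengths after skipping the `skip` longest bars (skip=0 skips none)
        self.topfn = PartialSumBarcodeLengths(dim=0, skip=0)
        # Sum of all dim-0 bar lengths
        self.topfn2 = SumBarcodeLengths(dim=0)

    def forward(self, beta):
        dgminfo = self.pdfn(beta)  # persistence diagrams of beta
        return self.topfn(dgminfo) + self.topfn2(dgminfo)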
# Define the U-Net model and the cross entropy loss
# ...
tloss = TopoLoss((50, 50)) # image width and height
# blend tloss with weight `lam` (`lambda` is a reserved word in Python) and add it to the cross entropy
loss = lam * tloss(likelihoods) + ce_loss(likelihoods, ground_truth)
Is this roughly the correct implementation?
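As a rough illustration of what the dim-0 terms measure, here is a plain NumPy sketch (non-differentiable, not TopologyLayer code; the function names are hypothetical) that computes the 0-dimensional barcode of the superlevel sets of an image on an assumed 4-connected pixel grid, using union-find and the elder rule. Each finite bar is a component that is born at a local maximum and dies when it merges into an older component; the one component that survives is given death `-inf`. Summing the finite bar lengths gives zero exactly when no superlevel set has more than one component, i.e. the "no islands" condition. `LevelSetLayer2D` builds its own complex on the grid, so its bar values may differ from this sketch, and the sketch does not claim to reproduce how the layer scores the infinite bar.

```python
import numpy as np


def superlevel_h0_barcode(img):
    """0-dim persistence of the superlevel sets of a 2D array (4-connected grid).

    Returns (birth, death) pairs with birth > death. The single component that
    never dies is reported with death = -inf.
    """
    img = np.asarray(img, dtype=float)
    parent = {}  # union-find forest over pixels added so far
    birth = {}   # root pixel -> value at which its component appeared

    def find(p):
        while parent[p] != p:
            parent[p] = parent[parent[p]]  # path halving
            p = parent[p]
        return p

    bars = []
    # Sweep the threshold from high to low: brightest pixels enter first.
    order = sorted(np.ndindex(*img.shape), key=lambda p: -img[p])
    for p in order:
        v = img[p]
        parent[p] = p
        birth[p] = v
        y, x = p
        for q in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
            if q not in parent:  # neighbor is off-grid or not yet above threshold
                continue
            rp, rq = find(p), find(q)
            if rp == rq:
                continue
            # Elder rule: the younger component (lower birth value) dies at v.
            old, young = (rp, rq) if birth[rp] >= birth[rq] else (rq, rp)
            if birth[young] > v:  # drop zero-length bars
                bars.append((birth[young], v))
            parent[young] = old
    bars.append((birth[find(order[0])], -np.inf))  # the surviving component
    return bars


def extra_component_penalty(img):
    """Sum of finite dim-0 bar lengths; zero iff no superlevel set has 2+ components."""
    return sum(b - d for b, d in superlevel_h0_barcode(img) if np.isfinite(d))
```

For the row `[[0.9, 0.1, 0.6]]` the sketch yields the bars `(0.6, 0.1)` and `(0.9, -inf)`, so the penalty is about 0.5: the weaker peak is an island that persists from threshold 0.6 down to 0.1. Penalizing every finite bar pushes such secondary peaks down or the valleys between them up, while the single surviving component is left alone.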
Hi! Sorry for the late reply. Did you figure this out?
No, but I am still interested in the question. I don't have a background in computational topology and would appreciate help.
Hi! I have the same question about this. Have you figured it out?
@SilenKZYoung Unfortunately, no. I do not think the authors monitor the issue tracker.