Felix Petersen

10 comments by Felix Petersen

Hi, sorry for the late response. Here is an implementation of it. Regarding the supplementary material (SM) of the paper, after publishing the paper and for the publication of the code we...

Yes, `diffsort` is actually a diffargsort. More concretely, the module returns a tuple `(sorted_vectors, permutation_matrices)`, where `sorted_vectors` is the output of the sorting operation and `permutation_matrices` is the...

Hi, yes, that should be possible. In the `forward` (https://github.com/Felix-Petersen/difftopk/blob/76ef96db648058a73571628f1db5e6a9f4478bfd/difftopk/losses.py#L86), it would primarily require replacing the losses of this style

```python
torch.nn.functional.nll_loss(torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels)
```

...

In this case, it would be something like

```python
- (torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).mean(0).sum(-1) / labels.sum(-1)
```

where `labels` is a k-hot FloatTensor of shape...
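To make the relationship between the single-label and k-hot losses concrete, here is a minimal sketch. The helper names (`clamp_log`, `single_label_loss`, `multi_label_loss`) are hypothetical, the `(batch, classes)` shape of `labels` is an assumption, and the k-hot version normalizes per sample before averaging over the batch, which is one plausible reading of the expression above rather than the library's implementation.

```python
import torch
import torch.nn.functional as F


def clamp_log(p):
    # Same stabilisation as in the thread: maps [0, 1] to [1e-7, 1 - 1e-7] before log().
    return torch.log(p * (1 - 2e-7) + 1e-7)


def single_label_loss(topk_distribution, labels):
    # labels: LongTensor of class indices, shape (batch,)
    return F.nll_loss(clamp_log(topk_distribution), labels)


def multi_label_loss(topk_distribution, labels):
    # labels: k-hot FloatTensor, assumed shape (batch, classes)
    per_sample = -(clamp_log(topk_distribution) * labels).sum(-1) / labels.sum(-1)
    return per_sample.mean(0)


# Toy top-k probabilities for 2 samples and 3 classes.
topk_distribution = torch.tensor([[0.7, 0.2, 0.1],
                                  [0.1, 0.3, 0.6]])
class_ids = torch.tensor([0, 2])
one_hot = F.one_hot(class_ids, num_classes=3).float()

loss_single = single_label_loss(topk_distribution, class_ids)
loss_multi = multi_label_loss(topk_distribution, one_hot)
```

With one-hot labels the two losses coincide, which is a quick sanity check before moving to genuinely k-hot targets.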

Hi, this looks like you are using NeuralSort or SoftSort. I recommend Cauchy Odd-Even Differentiable Sorting Networks for better performance. In `soft_permutation[:, :(k+1)].sum(1)`, you are summing over the first entries,...

No, the probability of something being top-i, i.e., among the top i elements, is: `topk_distribution = P_topk[:, :, -K:].sum(-1)`. In your convention, it is probably `topk_distribution = soft_permutation[:, :K, :].sum(-2)`. It's important...
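The two indexing conventions can be checked against each other with a hard permutation matrix. The sketch below assumes that `P_topk[b, i, j]` marks input element `i` landing at ascending sorted position `j`, and that `soft_permutation[b, r, i]` marks element `i` as the `r`-th largest; both layouts are assumptions inferred from the slices above, not confirmed library behavior.

```python
import torch
import torch.nn.functional as F

K = 2
scores = torch.tensor([[0.1, 0.9, 0.5, 0.3]])
n = scores.shape[-1]

# Ascending rank of every element (0 = smallest).
ranks = scores.argsort(dim=-1).argsort(dim=-1)

# Assumed layout: (batch, element, ascending sorted position).
P_topk = F.one_hot(ranks, num_classes=n).float()
topk_a = P_topk[:, :, -K:].sum(-1)

# Assumed alternative layout: (batch, descending sorted position, element).
soft_permutation = P_topk.flip(-1).transpose(1, 2)
topk_b = soft_permutation[:, :K, :].sum(-2)
```

For a hard permutation both slices give a 0/1 indicator of the K largest elements; with a relaxed permutation matrix the same slices yield per-element top-K probabilities.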

This part

```python
topk_distribution = soft_permutation[:, :K].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)
```

seems correct, assuming you pass in a correspondingly correct `soft_permutation`. Again, I'd recommend...
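A self-contained way to try this loss is sketched below. Since the thread does not show how `soft_permutation` was produced, the sketch substitutes a small SoftSort-style relaxation; the `softsort` helper and its `tau` parameter are illustrative and not part of `difftopk` or `diffsort`. It assumes rows index descending sorted positions and columns index input elements.

```python
import torch


def softsort(scores, tau=0.1):
    # Hypothetical stand-in relaxation: row r is a softmax over input
    # elements for the r-th largest position. scores: (batch, n).
    sorted_scores = scores.sort(dim=-1, descending=True).values
    dist = (sorted_scores.unsqueeze(-1) - scores.unsqueeze(-2)).abs()
    return torch.softmax(-dist / tau, dim=-1)  # (batch, n, n)


K = 2
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]], requires_grad=True)
labels = torch.tensor([[1.0, 0.0, 0.0, 1.0]])  # k-hot targets

soft_permutation = softsort(scores)

# Probability mass each element receives from the top-K rows.
topk_distribution = soft_permutation[:, :K].sum(1)
neg_log = -torch.log(topk_distribution + 1e-8)
loss = (neg_log * labels).sum(-1).mean(0)
loss.backward()  # gradients flow back to the scores
```

Note that this stand-in only normalizes rows, so column sums (and hence `topk_distribution`) are not guaranteed to stay at or below 1 in general.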

In your application, what's your number of classes, and what is K?

In this case, I'd strongly recommend DSNs, and setting `m` to something like 32, 50, or 64, which empirically stabilizes training drastically.

Sorry for the delay in response. Differentiable Sorting Networks (DSNs) are a differentiable relaxation of the classic sorting algorithm called "Sorting Networks". Monotonic DSNs in particular (like the Cauchy DSN) provide an improved...