difftopk
Multi Label
Hi, is it possible to use the top-k entropy loss for a multi-label classification problem, where each of the ground-truth labels can be in the top-K?
Hi, yes, that should be possible.
In the forward (https://github.com/Felix-Petersen/difftopk/blob/76ef96db648058a73571628f1db5e6a9f4478bfd/difftopk/losses.py#L86), it would primarily require replacing losses of this style
torch.nn.functional.nll_loss(torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels)
with something like:
l = 0
for label_idx in range(num_labels):
    l = l + torch.nn.functional.nll_loss(torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels[:, label_idx])
l = l / num_labels
(This is for num_labels labels, represented using labels being a batch_size x num_labels LongTensor.)
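For illustration, here is a minimal self-contained version of the loop above with dummy tensors; the shapes and the random stand-in for topk_distribution are assumptions for the sketch, not the actual values computed in the forward:

import torch

# Minimal runnable sketch of the loop above; topk_distribution here is only a
# random placeholder for the relaxed top-k probabilities from the forward.
batch_size, num_classes, num_labels = 4, 10, 3
topk_distribution = torch.rand(batch_size, num_classes).softmax(-1)
labels = torch.randint(0, num_classes, (batch_size, num_labels))  # batch_size x num_labels LongTensor of class indices

l = 0
for label_idx in range(num_labels):
    l = l + torch.nn.functional.nll_loss(
        torch.log(topk_distribution * (1 - 2e-7) + 1e-7), labels[:, label_idx]
    )
l = l / num_labels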
For the optional self.m is not None part of the forward, keep in mind that this would also require some adjustment.
I'm interested in your results; if you can, please share them. I'm happy to help.
Thanks for your explanation. Here is my case: I actually have a K-hot label. When K=1, I do softmax and cross-entropy for training and multinomial sampling at inference. When K > 1, the softmax operator suppresses the top-K probabilities (Pr(Top-k) <= 1/K), so I think of this as a problem of making the K-hot labels the top-k prediction. But I cannot tell whether that is what topk_distribution captures, since we want the probability sum_i^k sum_j^k P[i, j] to be maximized.
In this case, it would be something like
- ((torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).sum(-1) / labels.sum(-1)).mean(0)
where labels is a k-hot FloatTensor of shape batch_size x num_classes. The above equation supports a different k for different elements of the batch and reweighs them correspondingly. In the extended case of using the m (which is helpful for large numbers of classes / recommended for >100 classes), the respective selection will require additional care to ensure that the "reduced m classes" always contain the k ground-truth classes as well as all remaining m-k top-scoring predicted classes.
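To make the shapes concrete, here is a minimal sketch of this k-hot loss with made-up tensors (the random topk_distribution is only a placeholder for the relaxed top-k probabilities):

import torch

batch_size, num_classes = 4, 10
# Stand-in for the relaxed per-class "probability of being among the top-k".
topk_distribution = torch.rand(batch_size, num_classes)
# k-hot FloatTensor; k may differ per row (here 2, 3, 1, and 2 positives).
labels = torch.zeros(batch_size, num_classes)
labels[0, [1, 4]] = 1.
labels[1, [0, 2, 7]] = 1.
labels[2, [5]] = 1.
labels[3, [3, 9]] = 1.

log_p = torch.log(topk_distribution * (1 - 2e-7) + 1e-7)
# Average the negative log-probabilities over each sample's own k positives,
# then average over the batch.
loss = -((log_p * labels).sum(-1) / labels.sum(-1)).mean(0)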
Is your k the same for all elements of the batch? If not, you could get topk_distribution (assuming p_k = [0s..., 1, 0s...] with the 1 at index k for each element) via
topk_distribution = (P_topk * labels.sort(dim=-1, descending=False)[0].unsqueeze(1)[:, :, -self.k:]).sum(-1)
where self.k is the maximum k to be considered (as in the code). Here, the sort has the purpose of producing k-hot vectors that each have their last k entries equal to 1.
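A tiny example of what the sort produces (with made-up labels):

import torch

# Toy illustration: 3 samples with k = 1, 2, 3 positives out of 4 classes.
labels = torch.tensor([[0., 1., 0., 0.],
                       [1., 0., 0., 1.],
                       [1., 1., 1., 0.]])
k_max = 3  # plays the role of self.k
# Sorting ascending pushes the ones to the end, so the last k entries are 1.
mask = labels.sort(dim=-1, descending=False)[0].unsqueeze(1)[:, :, -k_max:]
# mask: tensor([[[0., 0., 1.]],
#               [[0., 1., 1.]],
#               [[1., 1., 1.]]])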
If k is the same for each element in the batch, you can do
topk_distribution = P_topk[:, :, -self.k:].sum(-1)
Yes, here is my implementation.
marginal_probs = [1/K] * K
scores = scores.unsqueeze(-1)
sorted_scores = scores.topk(k=K, dim=1, largest=True)[0]
soft_permutation = (scores.transpose(1, 2) - sorted_scores).abs().pow(power).neg() / tau
soft_permutation = soft_permutation.softmax(-1)  # [B, K, M]
topk_distribution = 0
for k, p_k in enumerate(marginal_probs):
    if p_k == 0:
        continue
    topk_distribution += p_k * soft_permutation[:, :(k+1)].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)
Please correct me if I made a mistake. Also, if this is a generation problem, how should I do stochastic sampling during inference?
Hi, this looks like you are using NeuralSort or SoftSort. I recommend Cauchy Odd-Even Differentiable Sorting Networks for better performance.
In soft_permutation[:, :(k+1)].sum(1), you are summing over the first entries, which seems wrong unless you use a convention opposite to the one in the difftopk library (opposite in 2 ways: reversed order, and transposition).
Considering that your K is constant, it should simply be
topk_distribution = P_topk[:, :, -K:].sum(-1)
- (torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).mean(0).sum(-1)
where P_topk is computed as in this line: https://github.com/Felix-Petersen/difftopk/blob/76ef96db648058a73571628f1db5e6a9f4478bfd/difftopk/losses.py#L120
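As a rough, self-contained sketch (the random P_topk is only a placeholder; it assumes P_topk is batch_size x num_classes x num_ranks, with entry [b, c, r] being the relaxed probability that class c lands at ascending sorted position r, consistent with the slicing above):

import torch

batch_size, num_classes, K = 4, 10, 3
# Placeholder for P_topk; the real matrix comes from the relaxed sorting operator.
P_topk = torch.rand(batch_size, num_classes, num_classes).softmax(-1)
labels = torch.zeros(batch_size, num_classes)
labels[:, :K] = 1.  # dummy K-hot targets

# Probability of being among the top-K = sum over the last K rank columns.
topk_distribution = P_topk[:, :, -K:].sum(-1)
loss = -(torch.log(topk_distribution * (1 - 2e-7) + 1e-7) * labels).mean(0).sum(-1)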
I think I used the convention opposite to the one in the difftopk library. For example, the shape of soft_permutation would be [B, K, N], and soft_permutation[:, i] would be the probability of being top-i? Is that correct?
No, the probability of something being top-i, i.e., among the top i elements, is: topk_distribution = P_topk[:, :, -K:].sum(-1). In your convention, probably topk_distribution = soft_permutation[:, :K, :].sum(-2).
It's important to use "among the top-k". Otherwise it would be "exactly the k-th largest", which implies an ordering, and then the loss doesn't make sense.
I am still a little confused. In my implementation, I used a for loop from top-1 to top-K, where in each iteration k I apply topk_distribution += p_k * soft_permutation[:, :(k+1)].sum(1) to accumulate the probability of being 'among the top-k', and p_k is the probability of being among the top-k.
So should I directly use topk_distribution = soft_permutation[:, :K].sum(1) without the for loop? Please correct me if I misunderstood.
Is the following updated code snippet correct?
scores = scores.unsqueeze(-1)
sorted_scores = scores.topk(k=K, dim=1, largest=True)[0]
soft_permutation = (scores.transpose(1, 2) - sorted_scores).abs().pow(power).neg() / tau
soft_permutation = soft_permutation.softmax(-1) # [B, K, M]
topk_distribution = soft_permutation[:, :K].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)
This part
topk_distribution = soft_permutation[:, :K].sum(1)
topk_distribution = -torch.log(topk_distribution + 1e-8)
loss = (topk_distribution * labels).sum(-1).mean(0)
seems correct, assuming you input a correspondingly correct soft_permutation. Again, I'd recommend going with DSNs and integrating the "m" trick for best performance.
Yes, without a loop, the sum is correct. If you did use a loop, you would still need the sum inside it, but the loop itself is unnecessary: your p_k weights should be 0 for the first K-1 places and 1 for place K (because you care about top-K and not about some percentage top-1, some percentage top-2, etc.), which collapses to the direct sum.
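A short check of this equivalence with random placeholder tensors (the shapes are made up for the example):

import torch

# Hypothetical shapes (B=2 samples, K=3 relaxed rank rows, M=5 classes), with a
# random stand-in for soft_permutation in the user's convention
# (row k = relaxed indicator of the (k+1)-th largest element).
B, K, M = 2, 3, 5
soft_permutation = torch.rand(B, K, M).softmax(-1)

# Direct version: probability of being among the top-K.
direct = soft_permutation[:, :K].sum(1)

# Loop version with weights 0 for the first K-1 places and 1 for place K.
marginal_probs = [0.0] * (K - 1) + [1.0]
looped = torch.zeros(B, M)
for k, p_k in enumerate(marginal_probs):
    looped = looped + p_k * soft_permutation[:, :(k + 1)].sum(1)

assert torch.allclose(direct, looped)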
In your application, what's your number of classes, and what is K?
The number of classes depends on the vocab size (i.e., 1k, 4k, 16k, or 64k), and K is usually one of [1, 2, 4, 8], since my goal is to predict an unordered set of size K from the classes.
In this case, I'd strongly recommend DSNs, and setting m to something like 32, 50, or 64, which empirically stabilizes training drastically.
I don't know much about DSNs. I prefer simplicity in the loss function, which is why I chose SoftSort. Would you mind explaining what DSNs are and how they work better than SoftSort?
Sorry for the delay in response.
Differentiable Sorting Networks are a differentiable relaxation of classic sorting networks. In particular, monotonic DSNs (like the Cauchy DSN) provide improved gradient quality and better optimization behavior compared to algorithms like SoftSort. I think it will be easiest to understand via my videos on the topic:
https://www.youtube.com/watch?v=38dvqdYEs1o (original DSNs)
https://www.youtube.com/watch?v=Rl-sFaE1z4M (monotonic DSNs extension, animated)
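As a rough illustration of the building block (a toy sketch of one relaxed compare-and-swap step, not the difftopk implementation), the Cauchy CDF can serve as the relaxed comparison; a full DSN applies a fixed grid of such steps (e.g., odd-even layers) to sort a whole vector differentiably:

import math
import torch

def soft_compare_swap(a, b, tau=1.0):
    # Relaxed probability that a <= b, using the Cauchy CDF ("Cauchy DSN" flavor).
    p = 0.5 + torch.atan((b - a) / tau) / math.pi
    soft_min = p * a + (1 - p) * b
    soft_max = p * b + (1 - p) * a
    return soft_min, soft_max

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(-1.0, requires_grad=True)
lo, hi = soft_compare_swap(a, b)
hi.backward()  # gradients flow through both inputs, unlike a hard comparison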
Feel free to ask questions about it here if anything is unclear.