
PMSN Loss

Open jafarinia opened this issue 10 months ago • 3 comments

Hi,

Thank you for the wonderful repository. I truly appreciate your implementation of the PMSN loss, something even the original author did not provide.

As you can see, the loss function for MSN is illustrated in this image:

[Image: MSN loss function from the paper]

It consists of two components, each implemented as follows:

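```python
# Step 3: cross-entropy H(targets, probs); note log(p ** (-t)) == -t * log(p)
loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1))

# Step 4: compute me-max regularizer (log K minus the entropy of the mean prediction)
rloss = 0.
if me_max:
    avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
    rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
```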
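A quick single-process check of what these two components compute may help. It is a minimal sketch assuming PyTorch, with random softmax outputs standing in for `probs` and the sharpened `targets`, and with `AllReduce` left out because it only averages across processes. It checks that the first term is the cross-entropy H(targets, probs) averaged over the batch, and that the me-max term equals log K minus the entropy of the mean prediction, so it is non-negative.

```python
import math

import torch

torch.manual_seed(0)
B, K = 8, 16  # illustrative batch size and number of prototypes

# Stand-ins for the anchor predictions and the (sharpened) targets.
probs = torch.softmax(torch.randn(B, K), dim=1)
targets = torch.softmax(torch.randn(B, K) / 0.25, dim=1)

# Step 3 as written in the MSN code.
loss = torch.mean(torch.sum(torch.log(probs ** (-targets)), dim=1))
# The same quantity in the usual cross-entropy form.
ce = torch.mean(torch.sum(-targets * torch.log(probs), dim=1))

# Step 4 without AllReduce: log K - H(mean prediction).
avg_probs = torch.mean(probs, dim=0)
rloss = -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(float(len(avg_probs)))
entropy_of_mean = -torch.sum(avg_probs * torch.log(avg_probs))

print(torch.allclose(loss, ce, atol=1e-6))
print(torch.allclose(rloss, math.log(K) - entropy_of_mean, atol=1e-6))
print(rloss.item())
```

The `p ** (-t)` form and the `-t * log(p)` form agree mathematically; the latter avoids the intermediate power and is the more common way to write it.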

However, the author/implementer also added the following term:

```python
# Optional term: entropy of each anchor prediction, averaged over the batch
sloss = 0.
if use_entropy:
    sloss = torch.mean(torch.sum(torch.log(probs**(-probs)), dim=1))
```

This additional term is not mentioned anywhere in the paper. However, it is actively used in their configuration file (msn_vits16.yaml), where it is set to true and included in the loss function.
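Reading the snippets together: `probs` is the tensor scored against `targets` in the cross-entropy, i.e. the anchor predictions, and `log(p ** (-p))` equals `-p * log(p)`, so `sloss` is the mean entropy of the anchor predictions and is never negative. Below is a minimal single-process sketch of the combined objective under that reading, assuming PyTorch. `msn_style_loss`, `me_max_weight` and `entropy_weight` are hypothetical names, and how the original training script weights `sloss` is not shown in the snippets above.

```python
import math

import torch


def msn_style_loss(
    probs: torch.Tensor,
    targets: torch.Tensor,
    me_max_weight: float = 1.0,   # hypothetical name for lambda
    entropy_weight: float = 0.0,  # hypothetical name for the sloss coefficient
    me_max: bool = True,
    use_entropy: bool = False,
) -> torch.Tensor:
    """Sketch of the MSN objective with the optional sloss term (no AllReduce).

    probs:   anchor predictions, shape (N, K), rows sum to 1.
    targets: sharpened target predictions, shape (N, K), rows sum to 1.
    """
    # Cross-entropy H(targets, probs), averaged over the N anchor rows.
    loss = torch.mean(torch.sum(-targets * torch.log(probs), dim=1))

    # me-max regularizer: log K - H(mean anchor prediction).
    rloss = probs.new_zeros(())
    if me_max:
        avg_probs = torch.mean(probs, dim=0)
        rloss = torch.sum(avg_probs * torch.log(avg_probs)) + math.log(avg_probs.numel())

    # sloss: mean entropy of the individual anchor predictions.
    sloss = probs.new_zeros(())
    if use_entropy:
        sloss = torch.mean(torch.sum(-probs * torch.log(probs), dim=1))

    return loss + me_max_weight * rloss + entropy_weight * sloss
```

With a positive weight, `sloss` pushes each individual anchor prediction toward low entropy (confident prototype assignments), while me-max pushes the *mean* prediction toward high entropy, so the two terms act at different levels rather than cancelling.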

In your implementations of `MSNLoss` and `PMSNLoss` (in `msn_loss.py` and `pmsn_loss.py`), this `sloss` term does not appear at all. I would like to understand why it was omitted. What was your reasoning behind this decision?

Do you think incorporating it could have improved the final results?

Finally, my main question: if we want to follow the approach taken by the author of PMSN (who unfortunately does not respond to emails), what would be the correct choice? Should we simply replace the `rloss` term with the KL term you provided and remove `sloss`, or should we keep `sloss`?
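One way to keep both options open is to treat the prior regularizer and `sloss` as independent switches. The sketch below, assuming PyTorch, swaps me-max for a KL divergence between a power-law prior π_k ∝ k^(-τ) and the mean anchor prediction, as PMSN proposes. The KL direction, the default τ = 0.25, and all function and parameter names are assumptions for illustration, so compare against lightly's `PMSNLoss` before relying on them.

```python
import torch


def power_law_prior(num_prototypes: int, exponent: float = 0.25) -> torch.Tensor:
    """pi_k proportional to k ** (-exponent) for k = 1..K, normalized to sum to 1."""
    ranks = torch.arange(1, num_prototypes + 1, dtype=torch.float32)
    prior = ranks ** (-exponent)
    return prior / prior.sum()


def pmsn_style_loss(
    probs: torch.Tensor,
    targets: torch.Tensor,
    prior_weight: float = 1.0,    # hypothetical name for the KL coefficient
    entropy_weight: float = 0.0,  # hypothetical name for the sloss coefficient
    use_entropy: bool = False,
    exponent: float = 0.25,       # assumed power-law exponent tau
) -> torch.Tensor:
    # Cross-entropy H(targets, probs), as in MSN.
    loss = torch.mean(torch.sum(-targets * torch.log(probs), dim=1))

    # PMSN-style regularizer replacing me-max: KL(prior || mean anchor prediction).
    avg_probs = torch.mean(probs, dim=0)
    prior = power_law_prior(avg_probs.numel(), exponent).to(avg_probs)
    kl = torch.sum(prior * (torch.log(prior) - torch.log(avg_probs)))

    # Optional sloss: mean entropy of the individual anchor predictions.
    sloss = probs.new_zeros(())
    if use_entropy:
        sloss = torch.mean(torch.sum(-probs * torch.log(probs), dim=1))

    return loss + prior_weight * kl + entropy_weight * sloss
```

`use_entropy=False` corresponds to "replace `rloss` with the KL term and drop `sloss`", and `use_entropy=True` keeps it; which of the two the PMSN runs actually used cannot be determined from the code quoted here.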

Looking forward to your insights.

jafarinia, Feb 03 '25 12:02

Hi @jafarinia 👋 glad you are enjoying lightly!

I will try to dissect this a bit, which also helps me figure out whether we have an issue at hand here. The original MSN loss implementation comes from #895, which does not give any hints about this term, and the loss from the paper looks as follows:

$$\frac{1}{MB}\sum_{i=1}^{B}\sum_{m=1}^{M}H(p_{i}^{+},p_{i,m})-\lambda H(\overline{p}), \quad \text{where} \quad \overline{p}:=\frac{1}{MB}\sum_{i=1}^{B}\sum_{m=1}^{M}p_{i,m}$$

The first term of this loss is implemented in https://github.com/lightly-ai/lightly/blob/9e1d2c63b1653771a8f21fc7e83efd181b42024f/lightly/loss/msn_loss.py#L252

and the second one in https://github.com/lightly-ai/lightly/blob/9e1d2c63b1653771a8f21fc7e83efd181b42024f/lightly/loss/msn_loss.py#L257

So far so good.

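However, you are right that the `sloss` term from the original implementation is missing. Since `torch.log(probs**(-probs))` equals `-probs * torch.log(probs)`, and `probs` are the anchor predictions $p_{i,m}$ (the same tensor that is scored against the targets in the cross-entropy term), it seems that with `sloss` they add another entropy term $c \cdot H(p_{i,m})$, averaged over all anchors, giving the total loss function:

$$\frac{1}{MB}\sum_{i=1}^{B}\sum_{m=1}^{M}\left[H(p_{i}^{+},p_{i,m}) + c\,H(p_{i,m})\right]-\lambda H(\overline{p})$$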

Tbh, without the paper mentioning this loss term it is really difficult to draw conclusions as to why they decided to go this route and/or if this has significant influence on the model performance. Maybe @guarin knows more. However, if you would like to investigate more (which is of course very welcome from our side), I would rather approach the authors behind MSN than the ones from PMSN, since that seems to be the place where the issue arises. We might consider benchmarking MSN and PMSN with/without this additional loss term in the future and compare the performance to get a better idea.

liopeer, Feb 03 '25 15:02


Thank you very much for the fast reply. The first author of both papers is Mahmoud Assran. He archived the entire repository and, in general, did not address any questions in the issues. For example, in one issue someone asks for the model weights of both models, and he replies that he will provide them, but he never did. Even after I emailed them, I received no reply.

If you have any way to contact them and get answers, I would really appreciate it. However, if you plan to benchmark with and without the `sloss` term, I would greatly appreciate it if the evaluation were conducted on a dataset like CLEVR/Dist. As the PMSN paper suggests, ImageNet's class distribution is uniform, and pretraining with PMSN does not provide any gains on that dataset.

By the way, I do not have enough resources to run such an experiment. However, if there is any way I can help, please feel free to ask; I would be happy to assist. For example, I could submit a PR that supports both options.

jafarinia, Feb 03 '25 16:02

@jafarinia A PR that allows dispatching to both implementations would of course be highly welcome. However, in that case we should also update `tests/loss/test_msn_loss.py` and try to test against the original implementation from the MSN repo. When it comes to benchmarking, I don't expect us to have the resources in the next few weeks, but in the mid-term it might be very realistic. I also appreciate the hint about datasets that might be interesting.
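A test against the original implementation could look roughly like the sketch below, a plain pytest-style test function assuming PyTorch. `reference_msn_terms` ports the expressions quoted from the original MSN code at the top of this issue (single process, so no `AllReduce`), and `candidate_msn_loss` is a hypothetical stand-in for an implementation with an entropy option. Neither name is part of lightly.

```python
import math

import torch


def reference_msn_terms(probs, targets):
    """Port of the quoted MSN loss terms (single process, so no AllReduce)."""
    loss = torch.mean(torch.sum(torch.log(probs ** (-targets)), dim=1))
    avg_probs = torch.mean(probs, dim=0)
    rloss = -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(float(len(avg_probs)))
    sloss = torch.mean(torch.sum(torch.log(probs ** (-probs)), dim=1))
    return loss, rloss, sloss


def candidate_msn_loss(probs, targets, me_max_weight=1.0, entropy_weight=0.0):
    """Hypothetical implementation under test, written in the -t * log(p) form."""
    loss = torch.mean(torch.sum(-targets * torch.log(probs), dim=1))
    avg_probs = torch.mean(probs, dim=0)
    rloss = torch.sum(avg_probs * torch.log(avg_probs)) + math.log(avg_probs.numel())
    sloss = torch.mean(torch.sum(-probs * torch.log(probs), dim=1))
    return loss + me_max_weight * rloss + entropy_weight * sloss


def test_candidate_matches_reference():
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(16, 32), dim=1)
    targets = torch.softmax(torch.randn(16, 32) / 0.25, dim=1)
    loss, rloss, sloss = reference_msn_terms(probs, targets)
    # Check every combination of "with/without me-max" and "with/without sloss".
    for me_max_weight in (0.0, 1.0):
        for entropy_weight in (0.0, 0.5):
            expected = loss + me_max_weight * rloss + entropy_weight * sloss
            actual = candidate_msn_loss(probs, targets, me_max_weight, entropy_weight)
            assert torch.allclose(actual, expected, atol=1e-5)
```

Comparing each term separately against a verbatim port keeps the test meaningful even if the implementation under test rewrites the expressions in a numerically different but equivalent form.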

liopeer, Feb 03 '25 16:02