mlx

Additional losses

Open: Jyun1998 opened this issue 1 year ago • 3 comments

Proposed changes

Added commonly used losses with tests

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

Jyun1998, Jan 01 '24 17:01

Hi @awni, I rebased and kept the cosine loss in this PR.

I also commented on #324 based on my implementation.

Thanks :)

Jyun1998, Jan 02 '24 14:01

Hey @Jyun1998, sorry, but somehow my main comment did not get included. I must have forgotten to press the save button.

Basically, I'm wondering what reference you used for this loss, as it looks quite different from the similarly named cosine similarity loss in PyTorch.

For example, there is a target and a margin, neither of which I would expect. Unless there is a good reason for the difference, I would suggest we follow the PyTorch implementation as a reference.

awni, Jan 03 '24 14:01

@Jyun1998 are you still planning to follow up on this?

awni, Jan 05 '24 03:01

https://github.com/keras-team/keras/blob/v2.14.0/keras/losses.py#L1162-L1236

Hi awni, according to the common TensorFlow and PyTorch implementations, the function L2-normalizes each embedding and returns the negative of the dot product of the two embeddings.

My code does the same, and applies a margin-based loss only when needed :)
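For reference, the Keras-style computation described above (normalize both embeddings, then negate their dot product) can be sketched in plain Python. This is a minimal sketch, assuming 1-D embeddings given as lists of floats; the function names are hypothetical, the epsilon clamp mirrors the default of 1e-12 used by tf.math.l2_normalize, and averaging over the batch is an assumed reduction.

```python
import math
from typing import List, Sequence


def l2_normalize(x: Sequence[float], eps: float = 1e-12) -> List[float]:
    # Divide by the L2 norm; eps keeps an all-zero vector from dividing by zero
    # (assumption: same clamp style as tf.math.l2_normalize).
    norm = math.sqrt(max(sum(v * v for v in x), eps))
    return [v / norm for v in x]


def negative_cosine_loss(a: Sequence[float], b: Sequence[float]) -> float:
    # Hypothetical helper: L2-normalize both embeddings, then negate their dot product.
    a_n = l2_normalize(a)
    b_n = l2_normalize(b)
    return -sum(x * y for x, y in zip(a_n, b_n))


def batch_negative_cosine_loss(
    batch_a: Sequence[Sequence[float]], batch_b: Sequence[Sequence[float]]
) -> float:
    # Hypothetical helper: average the per-sample losses over the batch.
    losses = [negative_cosine_loss(a, b) for a, b in zip(batch_a, batch_b)]
    return sum(losses) / len(losses)
```

Identical directions give -1, orthogonal ones give 0, and opposite ones give 1, so minimizing this loss pulls the embeddings toward the same direction.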

Jyun1998, Jan 07 '24 14:01

@Jyun1998 got it. We should keep it simple until we see that we need more features. Could you follow the PyTorch cosine similarity loss? I think that one covers the most common case (margin and targets are niche and possibly never needed, so we don't want to add them to the API until we are sure they are necessary).
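A plain cosine similarity in the spirit of PyTorch's F.cosine_similarity, with no margin and no targets, might look like the sketch below. It assumes inputs given as lists of float rows; cosine_similarity_loss and its reduction argument are hypothetical names rather than the MLX or PyTorch API, and the per-norm eps clamp used here may differ from exactly how PyTorch applies eps.

```python
import math
from typing import List, Sequence, Union


def cosine_similarity(x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8) -> float:
    # Dot product divided by the product of the L2 norms. Each norm is clamped
    # to at least eps so an all-zero input does not cause a division by zero.
    dot = sum(a * b for a, b in zip(x1, x2))
    n1 = max(math.sqrt(sum(a * a for a in x1)), eps)
    n2 = max(math.sqrt(sum(b * b for b in x2)), eps)
    return dot / (n1 * n2)


def cosine_similarity_loss(
    batch1: Sequence[Sequence[float]],
    batch2: Sequence[Sequence[float]],
    eps: float = 1e-8,
    reduction: str = "none",
) -> Union[List[float], float]:
    # Hypothetical API: one similarity per row, then an optional reduction.
    sims = [cosine_similarity(a, b, eps) for a, b in zip(batch1, batch2)]
    if reduction == "none":
        return sims
    if reduction == "mean":
        return sum(sims) / len(sims)
    if reduction == "sum":
        return sum(sims)
    raise ValueError(f"Invalid reduction: {reduction}")
```

With reduction="none" the caller gets one value per row; for example, cosine_similarity_loss([[1.0, 0.0]], [[1.0, 0.0]]) returns [1.0].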

awni, Jan 07 '24 14:01

I agree. Even though PyTorch's F.cosine_similarity also has a margin, it does not use it by default.

Only a slight change is needed. I will test the function and commit ASAP. Thanks for the review!

Jyun1998, Jan 07 '24 14:01

> Even though there's also margin for pytorch F.cosine_similarity

I don't see the margin in the docs? Is it in the source code?

[Screenshot: 2024-01-07 at 6:30:32 AM]

awni, Jan 07 '24 14:01

https://pytorch.org/docs/stable/generated/torch.nn.CosineEmbeddingLoss.html#torch.nn.CosineEmbeddingLoss
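The margin and target discussed here belong to torch.nn.CosineEmbeddingLoss, whose documented per-pair loss is 1 - cos(x1, x2) when y = 1 and max(0, cos(x1, x2) - margin) when y = -1 (margin defaults to 0). A minimal pure-Python sketch of that rule, for comparison only: the function names are hypothetical, and batching and reduction are omitted.

```python
import math
from typing import Sequence


def _cosine(x1: Sequence[float], x2: Sequence[float], eps: float = 1e-8) -> float:
    # Cosine similarity with each norm clamped to at least eps.
    dot = sum(a * b for a, b in zip(x1, x2))
    n1 = max(math.sqrt(sum(a * a for a in x1)), eps)
    n2 = max(math.sqrt(sum(b * b for b in x2)), eps)
    return dot / (n1 * n2)


def cosine_embedding_loss(
    x1: Sequence[float], x2: Sequence[float], target: int, margin: float = 0.0
) -> float:
    cos = _cosine(x1, x2)
    if target == 1:
        # Similar pair: penalize any deviation from cosine 1.
        return 1.0 - cos
    if target == -1:
        # Dissimilar pair: penalize only when the cosine exceeds the margin.
        return max(0.0, cos - margin)
    raise ValueError("target must be 1 or -1")
```

The target switches between a "pull together" and a "push apart" objective, which is the extra API surface that the plain cosine similarity avoids.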

Jyun1998, Jan 07 '24 14:01

I see, thanks. Yes let's go with the plain cosine similarity for now. Thank you!

awni, Jan 07 '24 14:01

Also could you rebase and resolve conflicts?

awni, Jan 07 '24 14:01

Am I correct that the loss tests are gone?

Never mind, I found the new losses test file.

Jyun1998, Jan 07 '24 15:01

Could you check? Thanks :)

Jyun1998, Jan 07 '24 15:01