
[Fix] Gradient scaling in Partial GW solver

Open yikun-baio opened this issue 1 year ago • 4 comments

Types of changes

I modified the code of ot.partial.partial_gromov_wasserstein.

Motivation and context / Related issue

There seems to be an inconsistency between ot.partial.partial_gromov_wasserstein and the line-search section of the paper [29], and I fixed this section. In addition, I made a minor change to the initial guess in the partial-GW solver, since the original initial guess, np.outer(p, q), might not be suitable for the unbalanced case, i.e. |p| \neq |q|.
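A quick NumPy sketch of why np.outer(p, q) can be a poor starting point, assuming the partial-OT constraints G 1 <= p, G^T 1 <= q and a transported mass m <= min(|p|, |q|). The marginals and the helper is_admissible are hypothetical, chosen only for illustration. The column sums of np.outer(p, q) are q |p| and its total mass is |p| |q|, so the plan leaves the feasible set as soon as |p| > 1 or |p| |q| differs from m.

```python
import numpy as np

def is_admissible(G, p, q, m, tol=1e-10):
    # Partial-OT constraints: G >= 0, G 1 <= p, G^T 1 <= q, total mass = m
    return bool(
        np.all(G >= -tol)
        and np.all(G.sum(axis=1) <= p + tol)
        and np.all(G.sum(axis=0) <= q + tol)
        and abs(G.sum() - m) <= tol
    )

# Hypothetical unbalanced marginals: |p| = 1.5, |q| = 0.6
p = np.array([0.8, 0.4, 0.3])
q = np.array([0.4, 0.2])
m = 0.5  # mass to transport, m <= min(|p|, |q|)

G0 = np.outer(p, q)
print(G0.sum(axis=0), q)           # column sums are q * |p|, which exceed q here
print(G0.sum(), m)                 # total mass is |p| * |q| (about 0.9), not m
print(is_admissible(G0, p, q, m))  # False
```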

N/A

N/A

How has this been tested (if it applies)

PR checklist

  • [ ] I have read the CONTRIBUTING document.
  • [ ] The documentation is up-to-date with the changes I made (check build artifacts).
  • [ ] All tests passed, and additional code has been covered with new tests.
  • [ ] I have added the PR and Issue fix to the RELEASES.md file.

yikun-baio · Feb 01 '24 02:02

Hello @yikun-baio and thanks for the PR.

You should not propose a new file but instead implement your fix directly in the existing files. This is important because we have many tests that check that nothing else breaks, and we can easily compare your modification with the old implementation. We will do a code review with @lchapel when this is done.

rflamary · Feb 01 '24 07:02

Hello @rflamary,

Thank you for your feedback. I apologize for proposing a new file instead of implementing the fix directly in the existing file. This is my first time contributing to a public project via a pull request.

I just realized that my email settings were inadvertently blocking emails from GitHub, which caused me to delay seeing your message.

I'm not sure whether it's still needed, but I'll implement the fix as suggested, directly in the existing file, so that @lchapel can review it. Thank you very much for your guidance; I look forward to contributing more effectively in the future.

Thank you for your understanding and patience.

Sincerest regards, Yikun

yikun-baio · Feb 15 '24 22:02

Hello @yikun-baio, this is a friendly reminder to implement your fix directly in the code if possible. I reacted with a rocket emoji to your previous message to say that we are interested, but maybe you did not receive a notification.

rflamary · Mar 29 '24 16:03

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.68%. Comparing base (628a089) to head (b862248).

@@           Coverage Diff           @@
##           master     #602   +/-   ##
=======================================
  Coverage   96.68%   96.68%           
=======================================
  Files          85       85           
  Lines       16890    16890           
=======================================
  Hits        16330    16330           
  Misses        560      560           

codecov[bot] · Mar 29 '24 16:03

Hello @yikun-baio,

Thank you for your PR. In order to proceed with the review, could you first implement @rflamary's corrections and make sure that all tests pass? I also have some doubts about the current implementation.

Could you specify which inconsistency you spotted in the paper [29]? According to Eq. 2 in the paper, the authors include a factor 1/2 in the GW cost, which is omitted in our implementation of GW (cf. the docs). As long as the documentation of the PGW solver is clear and states the loss explicitly (which is currently lacking), I don't see a major problem with having these differences within POT.
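For reference, the two conventions can be written as follows, assuming the square loss L_{i,j,k,l} = (C^1_{ik} - C^2_{jl})^2 and the tensor-matrix product notation \mathcal{M}\circ G used later in this thread. This is a sketch of the convention, not a quotation of Eq. 2 in [29]. With symmetric C^1 and C^2 the tensor satisfies L_{i,j,k,l} = L_{k,l,i,j}, which is what makes the gradient of the halved cost exactly \mathcal{M}\circ G.

```latex
(\mathcal{M}\circ G)_{ij} = \sum_{k,l} L_{i,j,k,l}\, G_{kl},
\qquad
\underbrace{\tfrac{1}{2}\,\langle \mathcal{M}\circ G,\, G\rangle}_{\text{PGW cost as in [29]}}
\quad\text{vs.}\quad
\underbrace{\langle \mathcal{M}\circ G,\, G\rangle}_{\text{GW cost in POT}},
\qquad
\nabla_G\, \tfrac{1}{2}\,\langle \mathcal{M}\circ G,\, G\rangle = \mathcal{M}\circ G .
```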

This difference implies that the computation of the gradient in gwgrad_partial had to be adapted, and that implementation is correct. The same holds for gwloss_partial. However, I also believe the current line-search implementation is wrong, but your fix does not seem to match:

M = gwgrad_partial(C1, C2, G0) # Here we want the 4D-tensor product to match calculus -> missing *0.5 
...
a = gwloss_partial(C1, C2, deltaG) # correct
b = 2 * np.sum(M * deltaG) # correct
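As a sanity check of the factor bookkeeping, this NumPy sketch compares a closed form of \mathcal{M}\circ G with the brute-force 4D contraction, then verifies with a central difference that the gradient of the halved cost 1/2 <\mathcal{M}\circ G, G> is \mathcal{M}\circ G with no extra factor. The helpers tensor_product, tensor_product_bruteforce and pgw_cost are hypothetical and are not POT's internal functions; the square loss and symmetric C1, C2 are assumed.

```python
import numpy as np

def tensor_product_bruteforce(C1, C2, G):
    # Materialize L[i, j, k, l] = (C1[i, k] - C2[j, l]) ** 2, then contract over (k, l)
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return np.einsum("ijkl,kl->ij", L, G)

def tensor_product(C1, C2, G):
    # Same M o G without the 4D tensor, using
    # (C1_ik - C2_jl)^2 = C1_ik^2 + C2_jl^2 - 2 C1_ik C2_jl
    rows = (C1 ** 2) @ G.sum(axis=1)  # shape (n,)
    cols = (C2 ** 2) @ G.sum(axis=0)  # shape (m,)
    return rows[:, None] + cols[None, :] - 2 * C1 @ G @ C2.T

def pgw_cost(C1, C2, G):
    # Halved convention: 1/2 <M o G, G>
    return 0.5 * np.sum(tensor_product(C1, C2, G) * G)

rng = np.random.default_rng(0)
X, Y = rng.random((4, 2)), rng.random((3, 2))
C1 = np.linalg.norm(X[:, None] - X[None], axis=-1)  # symmetric 4 x 4
C2 = np.linalg.norm(Y[:, None] - Y[None], axis=-1)  # symmetric 3 x 3
G = rng.random((4, 3)) / 20
D = rng.random((4, 3)) - 0.5  # arbitrary direction

M = tensor_product(C1, C2, G)
# A central difference is exact for a quadratic, so it should equal <M o G, D>
t = 0.1
fd = (pgw_cost(C1, C2, G + t * D) - pgw_cost(C1, C2, G - t * D)) / (2 * t)
print(np.allclose(M, tensor_product_bruteforce(C1, C2, G)), np.isclose(fd, np.sum(M * D)))
```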

I agree that the initial transport plan should be admissible, such as the one you proposed. I would suggest storing p.sum() and q.sum() early in the function and removing the redundancies in the current implementation.
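A minimal sketch of such an admissible initialization that computes p.sum() and q.sum() once: the rescaled outer product m/(|p| |q|) p q^T has total mass m, row sums (m/|p|) p <= p and column sums (m/|q|) q <= q whenever m <= min(|p|, |q|). The function name initial_plan and the default m = min(|p|, |q|) are assumptions of this example; the plan actually proposed in the PR is not reproduced here.

```python
import numpy as np

def initial_plan(p, q, m=None):
    """Admissible starting plan for partial GW (hypothetical sketch, not POT's code)."""
    # Compute the total masses once, early, and reuse them below
    p_mass, q_mass = p.sum(), q.sum()
    if m is None:
        m = min(p_mass, q_mass)  # assumed default: transport as much mass as possible
    if not 0 < m <= min(p_mass, q_mass):
        raise ValueError("m must lie in (0, min(|p|, |q|)]")
    # Rescaled outer product: total mass m, row sums (m / |p|) p <= p,
    # column sums (m / |q|) q <= q
    return np.outer(p, q) * (m / (p_mass * q_mass))

p = np.array([0.8, 0.4, 0.3])
q = np.array([0.4, 0.2])
G0 = initial_plan(p, q, m=0.5)
print(G0.sum())  # 0.5 up to rounding
```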

Best, Cédric

cedricvincentcuaz · Jun 12 '24 22:06

Hello, Cédric

> Could you specify which inconsistency you spotted in the paper [29]? According to Eq. 2 in the paper, the authors include a factor 1/2 in the GW cost, which is omitted in our implementation of GW (cf. the docs). As long as the documentation of the PGW solver is clear and states the loss explicitly (which is currently lacking), I don't see a major problem with having these differences within POT.

Please refer to the attached PDF. In Part 1, I explain the line-search problem and its solution; the solution I derived is consistent with [29]. In Part 2, I go through the code related to the line search, with the important parts highlighted in red for your review.

explaination of the linear search.pdf

> M = gwgrad_partial(C1, C2, G0) # Here we want the 4D-tensor product to match calculus -> missing *0.5
> ...
> a = gwloss_partial(C1, C2, deltaG) # correct
> b = 2 * np.sum(M * deltaG) # correct

Based on the PDF, I think it should be changed as follows:

M = gwgrad_partial(C1, C2, G0)  # M = \mathcal{M}\circ G; there is no 1/2 factor

old_a = gwloss_partial(C1, C2, deltaG)  # a = 1/2 <\mathcal{M}\circ deltaG, deltaG>: there is a 1/2 factor
old_b = 2 * np.sum(M * deltaG)  # b = 2 <\mathcal{M}\circ G, deltaG>: a and b are not consistent

# Option 1 (no 1/2 factor in either a or b)
a = 2 * gwloss_partial(C1, C2, deltaG)  # a = <\mathcal{M}\circ deltaG, deltaG>
b = 2 * np.sum(M * deltaG)  # b = 2 <\mathcal{M}\circ G, deltaG>

# Option 2 (1/2 factor applied to both a and b)
a = gwloss_partial(C1, C2, deltaG)  # a = 1/2 <\mathcal{M}\circ deltaG, deltaG>
b = np.sum(M * deltaG)  # b = <\mathcal{M}\circ G, deltaG>
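The following NumPy sketch compares the three (a, b) pairs above on random data. The helpers tensor_product, pgw_cost and optimal_step are hypothetical stand-ins written for this example (not POT's gwgrad_partial and gwloss_partial), the update is assumed to be G + gamma * deltaG with gamma in [0, 1], and C1, C2 are symmetric distance matrices. It checks that options 1 and 2 give the same step and that the option-2 coefficients reproduce the cost along the segment exactly; the old pair is only printed for comparison.

```python
import numpy as np

def tensor_product(C1, C2, G):
    # (M o G)_ij = sum_{k,l} (C1_ik - C2_jl)^2 G_kl, via the explicit 4D tensor
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return np.einsum("ijkl,kl->ij", L, G)

def pgw_cost(C1, C2, G):
    # Halved convention: cost = 1/2 <M o G, G>
    return 0.5 * np.sum(tensor_product(C1, C2, G) * G)

def optimal_step(a, b):
    # Minimize a * gamma**2 + b * gamma over gamma in [0, 1]
    if a > 0:
        return float(np.clip(-b / (2 * a), 0.0, 1.0))
    # Non-convex (or linear) case: the minimum is at an endpoint
    return 1.0 if a + b < 0 else 0.0

rng = np.random.default_rng(0)
X, Y = rng.random((4, 2)), rng.random((3, 2))
C1 = np.linalg.norm(X[:, None] - X[None], axis=-1)  # symmetric distance matrices
C2 = np.linalg.norm(Y[:, None] - Y[None], axis=-1)
G = rng.random((4, 3)) / 20      # current iterate (hypothetical)
G_new = rng.random((4, 3)) / 20  # target of the update (hypothetical)
deltaG = G_new - G

M = tensor_product(C1, C2, G)
half_quad = pgw_cost(C1, C2, deltaG)  # 1/2 <M o deltaG, deltaG>
lin = np.sum(M * deltaG)              # <M o G, deltaG>

step_old = optimal_step(half_quad, 2 * lin)       # mixed conventions
step_opt1 = optimal_step(2 * half_quad, 2 * lin)  # option 1
step_opt2 = optimal_step(half_quad, lin)          # option 2

# Option-2 coefficients reproduce the cost along the segment
for gamma in (0.25, 0.5, 1.0):
    exact = pgw_cost(C1, C2, G + gamma * deltaG)
    model = pgw_cost(C1, C2, G) + half_quad * gamma**2 + lin * gamma
    assert np.isclose(exact, model)

print(step_old, step_opt1, step_opt2)
```

Since option 1 only multiplies both coefficients by 2, it cannot change the step; the old pair changes only b, so it generally does.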

Could you check whether my understanding is correct?

Thanks, Yikun Bai

yikun-baio · Jun 13 '24 05:06

Thank you for these details.

I think we are saying the same thing: you leverage the fact that the step size only depends on the ratio between a and b, so it does not matter which of the two rescalings you choose. It is just preferable to code the exact formulas for these coefficients instead of using tricks that we might forget ;)
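Spelled out under the halved convention (writing \Delta G for deltaG and assuming symmetric C^1, C^2), the cost along the segment is a quadratic in \gamma whose unconstrained minimizer depends only on the ratio b/a. This is a sketch of the algebra, not a transcription of the screenshot or of the POT code.

```latex
f(G + \gamma\,\Delta G) = f(G) + b\,\gamma + a\,\gamma^{2},
\qquad
a = \tfrac{1}{2}\,\langle \mathcal{M}\circ \Delta G,\, \Delta G\rangle,
\qquad
b = \langle \mathcal{M}\circ G,\, \Delta G\rangle,

\gamma^{\star} = \min\!\Big(1,\ \max\!\Big(0,\ -\frac{b}{2a}\Big)\Big) \quad (a > 0),
\qquad
-\frac{\lambda b}{2\lambda a} = -\frac{b}{2a} \quad \text{for all } \lambda > 0 .
```

Option 1 corresponds to \lambda = 2. The old pair (a, 2b) instead gives -b/a before clipping, i.e. twice the unclipped step.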

Cf. the image below:

[Image: Screenshot 2024-06-13 at 6:28:40 PM]

cedricvincentcuaz · Jun 13 '24 16:06

Hello @yikun-baio, as we are rushing for a release to support numpy >= 2.0, I implemented the modifications mentioned above, along with some small changes needed for the tests to pass. I will merge once the tests have finished.

Thank you for your contribution to POT :)

cedricvincentcuaz · Jun 21 '24 09:06