
Possible one-line solution for Runtime error (variables modified in-place)

Open williantrevizan opened this issue 4 years ago • 8 comments

Hi, thanks for the repository and this amazing work!

I opened this issue because it might provide a solution for the runtime error reported by cdjameson in another topic, which happens in newer versions of torch ('one of the variables needed for gradient computation has been modified by an inplace operation...'). It seems to be more straightforward than the solution that Clefspear99 is proposing as a pull request.

The problem happens in the function train_single_scale() in training.py. This function is basically composed of two sequential loops, one for optimizing the discriminator D and the other for optimizing the generator G. At the end of the first loop, a fake image is generated by the generator. As soon as the second loop starts, this fake image is passed through the discriminator, which generates a patch discrimination map, which in turn is used to calculate the loss errG. The call errG.backward() computes the gradients that are then used to optimize netG's weights via optimizerG.step().

The first time we go through this second loop everything runs smoothly and the optimizer changes netG's weights in place. However, the second time through the loop, the same fake image is used to calculate the loss (that is, the fake image that had been generated with the previous set of netG weights). Therefore, once we call backward(), the computational graph points back to netG weights in their original version, from before the optimization step. Newer versions of torch are able to catch this inconsistency, and that seems to be why the error occurs.
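The failure pattern can be reproduced outside SinGAN with a few lines of PyTorch. Here netG and netD are hypothetical toy stand-ins (tiny linear stacks, not the real SinGAN architectures); the point is only the loop structure: a fake generated once, then backward() called again after optimizerG.step() has modified netG's weights in place.

```python
import torch
import torch.nn as nn

# Toy stand-ins for SinGAN's generator and discriminator (hypothetical,
# chosen only to reproduce the error pattern, not the real models).
netG = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
netD = nn.Linear(4, 1)
optG = torch.optim.SGD(netG.parameters(), lr=0.1)

noise = torch.randn(1, 4)
fake = netG(noise)  # generated ONCE, before the G loop (the buggy pattern)

err = None
for step in range(2):
    errG = -netD(fake).mean()
    optG.zero_grad()
    try:
        # retain_graph=True because the same graph is reused, as in SinGAN
        errG.backward(retain_graph=True)
    except RuntimeError as e:
        err = e  # second pass: saved weights were modified in place
        break
    optG.step()  # SGD updates netG's weights in place

print(type(err).__name__)  # → RuntimeError
```

On the second iteration, backward() must unpack the netG weights saved in the graph, notices their version counters were bumped by optG.step(), and raises the in-place modification error.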

So, instead of downgrading torch, a simple solution would be to add the line,

fake = netG(noise.detach(), prev.detach())

right at the beginning of the second loop, so the fake image is always recalculated with the current weights.
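With the same toy stand-ins as above (hypothetical single-input netG, not SinGAN's real two-input generator), moving the fake computation inside the loop makes every backward() pass see a graph built from the current weights, and the error disappears:

```python
import torch
import torch.nn as nn

# Hypothetical toy stand-ins for netG/netD; the real SinGAN generator
# takes (noise, prev), e.g. fake = netG(noise.detach(), prev.detach()).
netG = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
netD = nn.Linear(4, 1)
optG = torch.optim.SGD(netG.parameters(), lr=0.1)
noise = torch.randn(1, 4)

for step in range(3):
    # Recompute the fake image with the CURRENT netG weights each pass.
    fake = netG(noise.detach())
    errG = -netD(fake).mean()
    optG.zero_grad()
    errG.backward()  # fresh graph each time, no retain_graph needed
    optG.step()

completed = True
print("trains without the in-place error")
```

Note that detaching the inputs only stops gradients flowing into them; the key change is that the graph from fake back to netG's weights is rebuilt after every optimizer step.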

tamarott, I think this might solve this problem. If you allow, I will submit a pull request with this modification.

williantrevizan avatar Apr 20 '21 16:04 williantrevizan

This is a possible solution, but pay attention: it changes the optimization process and therefore might change performance. So the results won't necessarily be identical to the original version.

tamarott avatar Apr 20 '21 19:04 tamarott

You are right, I'll pay attention to that! I ran a few tests with the application I'm working on, and it seems to be doing fine with this modification, but I didn't stress these tests too much.

About the optimization process: when I first thought about your paper and code, it made sense to me that, conceptually, the fake image should be recalculated at every step of that loop (for optimizing G). However, what seems to be going on is that the adversarial loss is kept fixed (because the same fake image is used 3 times) and only the reconstruction loss is updated inside the loop. Is there a reason why that should work better?

williantrevizan avatar Apr 21 '21 22:04 williantrevizan

We found it to work better empirically. But other solutions might also work. Just be careful and make sure performance is the same.

tamarott avatar Apr 22 '21 10:04 tamarott

Nice, thanks a lot!!

williantrevizan avatar Apr 22 '21 14:04 williantrevizan

Thanks @williantrevizan, Your fix worked for me

ariel415el avatar Aug 05 '21 11:08 ariel415el

It works well for me too. You saved my time!! Thanks a lot!

JasonBournePark avatar Jan 19 '22 02:01 JasonBournePark

Thanks. You really saved my time!

WZLHQ avatar Aug 25 '22 03:08 WZLHQ

Thank you @williantrevizan! Confirmed that this solution works on torch==1.12.0.

jethrolam avatar Sep 27 '22 15:09 jethrolam