
🐛[BUG] CorrDiff: noise level inconsistent between regression model training and inference

Open nbren12 opened this issue 1 year ago • 1 comment

Version

main

On which installation method(s) does this occur?

No response

Describe the issue

There appears to be an inconsistency in the sigma passed to the regression model at training and inference time: training uses a random sigma, while inference uses sigma = 1. I'm not sure whether it has any impact; I imagine the regression model quickly learns to ignore the randomized sigma during training.

We should probably still fix it. Training should have a constant sigma.
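
A minimal sketch of what a fixed-sigma variant might look like. The class name, constructor, and loss body here are hypothetical (the actual `RegressionLoss` in physicsnemo differs); the point is only that sigma is held constant at the same value inference uses, instead of being sampled per batch:

```python
import torch

# Hypothetical sketch: a regression loss that conditions the network on a
# fixed sigma (matching the sigma = 1 used at inference) rather than a
# randomly sampled one. Not the actual physicsnemo implementation.
class FixedSigmaRegressionLoss:
    def __init__(self, sigma: float = 1.0):
        self.sigma = sigma

    def __call__(self, net, img_clean, y_lr):
        # Constant sigma for every batch element, consistent with inference.
        sigma = torch.full(
            [img_clean.shape[0], 1, 1, 1], self.sigma, device=img_clean.device
        )
        # Regression input carries no added noise; the net predicts the
        # clean image from the low-res conditioning.
        zeros = torch.zeros_like(img_clean)
        D_yn = net(zeros, y_lr, sigma)
        return (D_yn - img_clean) ** 2
```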

Minimum reproducible example

No response

Relevant log output

No response

Environment details

No response

nbren12 avatar Apr 13 '24 03:04 nbren12

cc @mnabian @MortezaMardani

nbren12 avatar Apr 13 '24 03:04 nbren12

Hi @nbren12 ,

I am working on that. The line in the inference script that you provided is for the U-Net regression, which is purely deterministic, so I don't think it should use any sigma.

Did you find this bug somewhere else as well, or can I close this issue?

CharlelieLrt avatar Feb 04 '25 04:02 CharlelieLrt

Yes, it is for the regression model. It should indeed be deterministic, but the RegressionLoss generates a random sigma, while the inference code uses sigma = 1. I think this is a bug.

Here are the relevant lines of code:

Training

        rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()

        # irrelevant lines redacted

        D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels)
        # arg 3:  sigma is a random number here

Inference:

    t_hat = torch.tensor(1.0).to(torch.float64).cuda()

    # Run regression on just a single batch element and then repeat
    x_next = net(x_hat[0:1], x_lr, t_hat, class_labels).to(torch.float64)
    # arg 3 t_hat = 1
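
To make the mismatch concrete, here is a small standalone sketch of the two sigma values the net sees. The `P_mean`/`P_std` values below are illustrative placeholders (the actual config may use different numbers); the training-side sampling mirrors the snippet above:

```python
import torch

# Illustration of the inconsistency: training samples sigma log-normally,
# while inference always passes a constant sigma = 1.
P_mean, P_std = -1.2, 1.2  # hypothetical values; actual config may differ

torch.manual_seed(0)
rnd_normal = torch.randn([4, 1, 1, 1])
train_sigma = (rnd_normal * P_std + P_mean).exp()  # varies per batch element

infer_sigma = torch.tensor(1.0)  # constant, as in the inference snippet

print(train_sigma.flatten())  # random positive values, generally != 1
print(infer_sigma)
```

So the regression net is conditioned on a different (and random) sigma at training time than the fixed value it receives at inference time, which is the inconsistency this issue is about.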

nbren12 avatar Feb 04 '25 04:02 nbren12