score_sde_pytorch
Added sliced score matching; does not learn (DO NOT MERGE)
Hello,
Do you have any idea why this does not work? I have updated get_sde_loss_fn to use sliced score matching, as you did in your sibling repo (https://github.com/ermongroup/sliced_score_matching). The eventual goal is to implement numerical sampling for an SDE whose perturbation kernel has no closed form.
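For reference, here is a minimal PyTorch sketch of a sliced score matching loss on perturbed samples. It uses the variance-reduced form (SSM-VR), where each sample's loss is v^T (∇_x s(x, t)) v + 0.5 * ||s(x, t)||^2 with v ~ N(0, I). This is a sketch under stated assumptions, not the code in this PR or in either repo. The function name sliced_score_matching_loss and the score_fn(x, t) signature are hypothetical. The sketch also assumes that x already holds samples from the SDE marginal at times t, and it leaves any time weighting λ(t) to the caller. Because it differentiates a sum over the batch, it further assumes the model has no cross-sample operations, such as BatchNorm in training mode.

```python
import torch


def sliced_score_matching_loss(score_fn, x, t, n_projections=1):
    """Per-sample SSM-VR loss (hypothetical helper, not the repo's API).

    score_fn: assumed callable (x, t) -> tensor shaped like x.
    x: perturbed samples x(t), shape (batch, ...).
    t: diffusion times, shape (batch,).
    """
    x = x.detach().requires_grad_(True)
    score = score_fn(x, t)
    dims = tuple(range(1, x.dim()))  # all non-batch dimensions

    # SSM-VR keeps the exact squared norm instead of projecting it onto v.
    norm_term = 0.5 * torch.sum(score ** 2, dim=dims)

    trace_terms = []
    for _ in range(n_projections):
        v = torch.randn_like(x)  # random projection directions
        # Summing v_i^T s(x_i) over the batch is valid only if each score
        # depends on its own input alone (assumption stated above).
        sv = torch.sum(score * v)
        # d(sum_i v_i^T s(x_i)) / dx gives v_i^T J_i for every sample i;
        # create_graph=True keeps the result differentiable for training.
        vjp = torch.autograd.grad(sv, x, create_graph=True)[0]
        # v^T J v is an unbiased estimate of the trace of the Jacobian.
        trace_terms.append(torch.sum(vjp * v, dim=dims))

    trace_term = torch.stack(trace_terms).mean(dim=0)
    return trace_term + norm_term  # shape (batch,)
```

As a sanity check, on standard normal data with the exact score s(x) = -x, the per-sample losses should average close to -d/2 (about -1.5 for d = 3). A zero score gives a loss of exactly 0, so the true score should come out lower.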
The best I can obtain are patterns that look like this:

This is on the CIFAR dataset. I know 2000 iterations is far fewer than needed to get "good" results, but when I use get_sde_loss_fn as it exists in your code, I get acceptable results.

Do you have any idea where I could be going wrong? Thanks in advance.