evidential-deep-learning
PyTorch implementation
Hi all, great talk and paper.
I did the preliminary work of porting this to PyTorch. There are a few niceties that could still be added, like specifying the batch dimension, some customization of the loss reduction, and making the tf and torch implementations optional dependencies the way huggingface/transformers supports both without requiring both.
Otherwise it's all there for NIG. I didn't implement the Dirichlet_SOS loss since it wasn't clear where it would be used. I'll work on porting the NeurIPS examples but since that will take a while, I figured it would be useful to give the base torch code for now.
Of note: I found some numerical instabilities/issues with the Student-t distribution when the model makes very confident regressions. Just to test the torch version, I ran SGD on a simple contrived linear regression and found that as the model achieved a strong fit, its probabilities via the Student-t went > 1 (and the NLL went < 0). It obviously doesn't hinder training, but it seems a bit off to have a calculation produce probabilities > 1.
PyTorch has its own implementation of StudentT, which suffers from the same instabilities but gives numerically different results. I went with directly porting the TF code for numerical consistency.
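To be clear about where this comes from: the NIG predictive is a Student-t density, and a continuous density can legitimately exceed 1 once its scale gets small, which is what drives the NLL negative. A tiny standalone illustration with torch.distributions.StudentT (the numbers are just made up for demonstration):

import torch
from torch.distributions import StudentT

# A very confident prediction: a small scale pushes the density at the mode above 1,
# so the negative log-likelihood of a target near the mode becomes negative.
dist = StudentT(df=5.0, loc=0.0, scale=0.01)
nll = -dist.log_prob(torch.tensor(0.0))
print(nll.item())  # negative, since pdf(0) >> 1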
Something like the below will recreate this instability.
import torch
import evidential_deep_learning as edl
import numpy as np

# Contrived linear regression: y is a noiseless function of x, so the model can
# become arbitrarily confident as it fits.
xs = [torch.rand(8) for i in range(10_000)]
ys = [x.mean() * np.pi for x in xs]

class BasicNetwork(torch.nn.Module):
    def __init__(self, n_in=8, n_tasks=1):
        super(BasicNetwork, self).__init__()
        self.l1 = torch.nn.Linear(n_in, 6)
        self.l2 = edl.pytorch.layers.DenseNormalGamma(6, 1)

    def forward(self, x):
        x = self.l1(x)
        x = torch.nn.functional.relu(x)
        x = self.l2(x)
        return x

model = BasicNetwork()
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

for i in range(1):
    for x, y in zip(xs, ys):
        output = model(x)
        nll, reg = edl.pytorch.losses.EvidentialRegression(y, output)
        with torch.no_grad():
            # As the fit tightens, exp(-nll) climbs above 1 and nll goes negative.
            print((-1 * nll).exp().item())
            print(nll.item())
            print(reg.item())
            print(y)
            print(output)
            print((y - output[0]).abs().item())
        loss = nll + reg
        print(loss.item())
        print()
        loss.backward()
        optim.step()
        optim.zero_grad()
First of all, thank you for contributing the new pytorch implementation! We really appreciate this initiative on your part.
I would like to take some time to review your updates before merging into master / PyPI. I can provide some comments on the PR and (if it's okay with you) commit some suggestions to your branch before merging. In the meantime, I believe your code will serve as a great starting point for others who would like to try the method in pytorch.
Hi! I'd actually suggest you make a pytorch dev branch and I'll re-PR into that. (I always wonder why that isn't an option on GitHub, since it seems more logical.) That way I can also revert all of the neurips2020 import refactoring and leave the original code untouched.
Please do give suggestions - I'm especially not happy at the moment with how the two implementations handle the namespace. I think there's potentially a much cleaner way (for the user) to go about this and it's definitely not ready to merge into main as is.
Good point, I just created the pytorch dev branch (dev/pytorch) and changed the target branch of your PR to point to it instead of master.
I agree with the namespace issue. One way to handle this cleanly is how keras used to handle multiple backends, reading the backend from an OS environment variable (doc). A similar approach is adopted by pyrender. Alternatively, we could adopt an approach similar to how matplotlib works (it has a base method that lets you switch backends).
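Something along these lines for the keras-style option (the EDL_BACKEND variable name and the module paths below are just placeholders, not anything that exists in the package yet):

import os

# Hypothetical: choose the implementation at import time from an environment
# variable, e.g. EDL_BACKEND=torch (name is a placeholder).
_backend = os.environ.get("EDL_BACKEND", "tf").lower()

if _backend == "torch":
    from evidential_deep_learning.pytorch import layers, losses  # placeholder path
else:
    from evidential_deep_learning.tf import layers, losses  # placeholder path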
- Reverted everything in the NeurIPS directory
- See commit ce9f606. This is a bit messy but uses the matplotlib 'backend' idea. It checks if torch or tf can be imported and then manipulates the edl.layers and edl.losses names to point to the correct implementation. I'm sure that there is a cleaner way.
- I implemented Dirichlet_SOS since @benjiachong seems to want it. It's a very 1:1 port and I'm not sure if it's valid. I'm unfamiliar with tf, so I'm not sure what dimension axis=1 really refers to and what that would conventionally be in pytorch (which usually orders dimensions as batch, [channel], data, [data2], ...); see the sketch below this list for my current reading.
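For what it's worth, here is roughly how I read the sum-of-squares part of that loss (following Eq. 5 of Sensoy et al. 2018, with the KL regularizer omitted), assuming y and alpha are [batch, num_classes] so that tf's axis=1 is the class dimension. Treat this as a sketch rather than the committed code:

import torch

def dirichlet_sos_sketch(y, alpha):
    # y: one-hot targets, alpha: Dirichlet parameters, both [batch, num_classes]
    S = alpha.sum(dim=-1, keepdim=True)        # Dirichlet strength per example
    p = alpha / S                              # expected class probabilities
    err = ((y - p) ** 2).sum(dim=-1)           # squared error term
    var = (p * (1 - p) / (S + 1)).sum(dim=-1)  # predictive variance term
    return (err + var).mean()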
I think the next steps would be first validating the pytorch code and then finding a cleaner way to handle the namespace.
Dirichlet UQ for discrete classification is implemented and validated @benjiachong
Hi, @Dariusrussellkish. As you commented on 14 Dec 2020, I have also encountered the problem of the NLL loss going negative. I found it occurs when the evidence values (nu, alpha) get high and so the variance becomes small.
Did you solve this problem in the pytorch implementation?
Thank you.