Differentiating Through Marginals of Dependency CRF
Hi,
I tried using the DependencyCRF in a learning setting that required me to differentiate through the marginals. This turned out to be really difficult to achieve. I noticed that the gradients computed for the marginals tended to have high variance and to be larger than I would expect (though I haven't dug deeply into the Eisner algorithm yet).
I wonder whether this is a feature of the Eisner algorithm or might hint at a bug. Below is a minimal example showing that the maximum gradient returned for the arc scores can be quite large, even when the arc scores themselves are on a reasonable scale.
```python
import torch
from torch_struct import DependencyCRF

torch.manual_seed(99)

maxlen = 50
# Random arc scores for a batch of one sentence.
vals = torch.randn((1, maxlen, maxlen), requires_grad=True)
# Upstream gradient to push through the marginals.
grad_output = torch.rand(1, maxlen, maxlen)

dist = DependencyCRF(vals)
marginals = dist.marginals
marginals.backward(grad_output)

print(vals.max().item())         # 3.5494842529296875
print(marginals.max().item())    # 0.8289076089859009
print(grad_output.max().item())  # 0.9995625615119934
print(vals.grad.max().item())    # 19.625778198242188
```
Hi, sorry for the long delay here. I'm going to try to add some tests to make sure it is returning the right values. I don't have a good sense of whether this is a bug, underflow, or correct behavior in this case.
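One way to sanity-check the analytic gradients is `torch.autograd.gradcheck`, which compares them against finite differences. A minimal sketch, assuming DependencyCRF accepts double-precision potentials (gradcheck needs float64, and a small sentence length keeps it tractable):

```python
import torch
from torch_struct import DependencyCRF

torch.manual_seed(0)

def marginals_fn(log_potentials):
    # Wrap marginal computation as a plain function for gradcheck.
    return DependencyCRF(log_potentials).marginals

# gradcheck compares autograd gradients to finite differences.
vals = torch.randn(1, 5, 5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(marginals_fn, (vals,), eps=1e-6, atol=1e-4))
```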
I am pretty sure the reason is the "Chart" class: one should set cache=False if one wants to reuse the computation graph.
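For context, reusing the graph comes up whenever the marginals (or their gradient) feed a downstream differentiable loss, i.e. a double backward. A minimal sketch of that pattern, with a hypothetical loss just to show the second backward pass; this is the setting where the cached chart gets in the way:

```python
import torch
from torch_struct import DependencyCRF

vals = torch.randn(1, 10, 10, requires_grad=True)
marginals = DependencyCRF(vals).marginals

# create_graph=True keeps the graph of the backward pass itself,
# so the gradient can be used inside a differentiable loss.
grad, = torch.autograd.grad(marginals.sum(), vals, create_graph=True)
loss = grad.pow(2).sum()  # hypothetical loss on the gradient
loss.backward()           # second backward through the marginal computation
print(vals.grad.max().item())
```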
That sounds right. I will turn off chart caching by default.
Also, the backward-on-marginals approach now works with FastLogSemiring.
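A hypothetical sketch of that path, assuming the internal DepTree struct and FastLogSemiring are importable as shown (names and module paths may differ across versions, so check your install):

```python
import torch
from torch_struct import DepTree  # assumption: exported at the top level
from torch_struct.semirings import FastLogSemiring  # assumption: module path

vals = torch.randn(1, 10, 10, requires_grad=True)
# Compute marginals under the fast log semiring, then backprop through them.
marginals = DepTree(FastLogSemiring).marginals(vals)
marginals.sum().backward()
print(vals.grad.shape)
```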