pytorch-struct
add incremental linear-chain CRF
Is there a reason you closed this? It looks cool.
Oh, I accidentally committed it to the original main branch. If you find it interesting, maybe I can reopen it?
I use the forward-backward algorithm to calculate the marginals of a linear-chain CRF over each prefix of the sequence (in order to support autoregressive (AR) models). With a modified parallel scan, the computation can be done in O(log N) parallel steps. I call it an interval parallel scan :). The procedure looks like this:
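As a rough point of reference, and not the interval parallel scan itself, here is a minimal sketch of how an ordinary parallel prefix scan reaches O(log N) depth for a linear-chain CRF. It is a standard Hillis-Steele inclusive scan over log-semiring matrix products, and it only yields the log-partition of every prefix, not the marginals. The names `init` and `edges` and the convention that `edges[i][a][b]` scores label `a` at position `i` followed by label `b` at position `i + 1` (with the unary score of position `i + 1` folded in) are assumptions made for illustration.

```python
import math
import random

NEG_INF = float("-inf")


def logsumexp(xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(A, B):
    # log-semiring product: (A ⊗ B)[a][c] = logsumexp_b (A[a][b] + B[b][c])
    return [[logsumexp([A[a][b] + B[b][c] for b in range(len(B))]) for c in range(len(B[0]))]
            for a in range(len(A))]


def prefix_log_partitions_sequential(init, edges):
    # plain forward algorithm: N - 1 dependent steps
    alpha = list(init)
    out = [logsumexp(alpha)]
    for M in edges:
        alpha = log_matmul([alpha], M)[0]
        out.append(logsumexp(alpha))
    return out


def prefix_log_partitions_scan(init, edges):
    # Hillis-Steele inclusive scan over the edge matrices.
    # Invariant before each round: P[i] = M_j ⊗ ... ⊗ M_i with j = max(0, i - d + 1).
    P = list(edges)
    d = 1
    while d < len(P):
        # every update in a round reads only the previous P, so a round is one parallel step
        P = [P[i] if i < d else log_matmul(P[i - d], P[i]) for i in range(len(P))]
        d *= 2
    # now P[t] = M_0 ⊗ ... ⊗ M_t, so init ⊗ P[t] is the forward vector at position t + 1
    out = [logsumexp(init)]
    for Pt in P:
        out.append(logsumexp(log_matmul([list(init)], Pt)[0]))
    return out


if __name__ == "__main__":
    random.seed(0)
    N, C = 9, 3
    init = [random.uniform(-1, 1) for _ in range(C)]
    edges = [[[random.uniform(-1, 1) for _ in range(C)] for _ in range(C)] for _ in range(N - 1)]
    seq = prefix_log_partitions_sequential(init, edges)
    par = prefix_log_partitions_scan(init, edges)
    print(max(abs(x - y) for x, y in zip(seq, par)))
```

Each `while` round is one parallel step, so the depth is ceil(log2(N - 1)) rounds plus one batched vector-matrix product. The price is O(N log N) matrix products of total work instead of the forward algorithm's N - 1.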

Very neat. I'll take a read. Do you find that this gives a speedup? Seems hard to parallelize.
You mean the speedup compared to using the gradient identity? At first I tried computing only the prefix sums (the partition function of each prefix) and using back-propagation to get the marginals of each prefix. But that requires O(N) separate back-propagation passes, one per prefix. Otherwise we have to replicate the whole graph N times to get O(1) parallel complexity, but that easily hits the GPU memory limit. The current version requires O(log N) parallel operations, but I haven't tested its speed against other implementations yet.
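The gradient-identity baseline can be made concrete with a small sketch. For the CRF restricted to positions 0..t, the marginal P(y_i = a) equals the derivative of log Z_t with respect to the unary score of label `a` at position `i`, so getting marginals for every prefix takes one gradient computation per prefix, N in total. Because this is plain Python without autograd, central finite differences stand in for back-propagation. The names `unary`, `trans`, and all helper functions are hypothetical, and the forward-backward function is a sequential per-prefix reference, not the parallel version from this PR.

```python
import math
import random

NEG_INF = float("-inf")


def logsumexp(xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward(unary, trans, t):
    # alphas[i][a]: log-sum of scores of all label paths over positions 0..i ending in label a
    C = len(unary[0])
    alphas = [list(unary[0])]
    for i in range(1, t + 1):
        prev = alphas[-1]
        alphas.append([unary[i][b] + logsumexp([prev[a] + trans[i - 1][a][b] for a in range(C)])
                       for b in range(C)])
    return alphas


def log_partition(unary, trans, t):
    # log Z_t of the CRF restricted to the prefix 0..t
    return logsumexp(forward(unary, trans, t)[-1])


def prefix_marginals_fb(unary, trans, t):
    # forward-backward on the prefix 0..t; mu[i][a] = P(y_i = a) under the prefix-t CRF
    C = len(unary[0])
    alphas = forward(unary, trans, t)
    log_z = logsumexp(alphas[-1])
    beta = [0.0] * C  # backward messages start at 0 at the last prefix position t
    mu = [None] * (t + 1)
    for i in range(t, -1, -1):
        mu[i] = [math.exp(alphas[i][a] + beta[a] - log_z) for a in range(C)]
        if i > 0:
            # beta_{i-1}[a] = logsumexp_b (trans[i-1][a][b] + unary[i][b] + beta_i[b])
            beta = [logsumexp([trans[i - 1][a][b] + unary[i][b] + beta[b] for b in range(C)])
                    for a in range(C)]
    return mu


def prefix_marginals_grad(unary, trans, t, eps=1e-5):
    # gradient identity: mu[i][a] = d log Z_t / d unary[i][a]
    # central differences (many forward evaluations) stand in for one backward pass through log Z_t
    C = len(unary[0])
    mu = []
    for i in range(t + 1):
        row = []
        for a in range(C):
            up = [r[:] for r in unary]
            up[i][a] += eps
            down = [r[:] for r in unary]
            down[i][a] -= eps
            row.append((log_partition(up, trans, t) - log_partition(down, trans, t)) / (2 * eps))
        mu.append(row)
    return mu


if __name__ == "__main__":
    random.seed(0)
    N, C = 6, 3
    unary = [[random.uniform(-1, 1) for _ in range(C)] for _ in range(N)]
    trans = [[[random.uniform(-1, 1) for _ in range(C)] for _ in range(C)] for _ in range(N - 1)]
    # one gradient computation per prefix: N of them in total
    for t in range(N):
        fb = prefix_marginals_fb(unary, trans, t)
        gd = prefix_marginals_grad(unary, trans, t)
        print(t, max(abs(x - y) for rf, rg in zip(fb, gd) for x, y in zip(rf, rg)))
```

The loop in the demo makes the cost visible: each prefix `t` has its own log Z_t, so with real autograd the baseline needs N backward passes, or N copies of the graph if they are batched, which is the memory problem mentioned above.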