pytorch-struct

add incremental linear-chain CRF

Open • haozheji opened this issue 3 years ago • 5 comments

haozheji • Nov 04 '21 08:11

Is there a reason you closed this? It looks cool.

srush • Nov 04 '21 14:11

Oh, I just accidentally committed it to the original main branch. If you find it interesting, maybe I can reopen it?

haozheji • Nov 15 '21 13:11

I use the forward-backward algorithm to calculate the marginals of a linear-chain CRF over every prefix of the sequence (in order to support autoregressive (AR) models). The parallel computation can be done in O(log N) parallel steps with a modified parallel scan, which I call an interval parallel scan :). The procedure looks like this:

[Figure: the interval parallel scan procedure]
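The interval parallel scan figure isn't reproduced here, so the following is only a sketch of the standard building block such a scan modifies, not the author's method: an inclusive Hillis-Steele scan over the per-step transition matrices in the log semiring. All products within one round are independent, and only ceil(log2 N) rounds are needed. The prefix products give the forward scores, and hence log Z, of every prefix at once. The per-prefix marginals additionally need backward messages over intervals, which this sketch does not compute. The parametrization (`init[a]` for the first label, `trans[t][a][b]` for y_t = a followed by y_{t+1} = b) and every function name are hypothetical. This is plain Python, so each "parallel" round actually runs sequentially.

```python
import math
import random

NEG_INF = float("-inf")


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(A, B):
    """Log-semiring matrix product: C[i][j] = logsumexp_k(A[i][k] + B[k][j])."""
    inner = len(B)
    return [[logsumexp([A[i][k] + B[k][j] for k in range(inner)])
             for j in range(len(B[0]))]
            for i in range(len(A))]


def hillis_steele_scan(mats):
    """Inclusive scan: out[t] = mats[0] ⊗ mats[1] ⊗ ... ⊗ mats[t].

    Every product in a round reads only the previous round's values,
    so a round could run in parallel; there are ceil(log2(len(mats))) rounds.
    """
    out = list(mats)
    step, rounds = 1, 0
    while step < len(out):
        out = [out[t] if t < step else log_matmul(out[t - step], out[t])
               for t in range(len(out))]
        step *= 2
        rounds += 1
    return out, rounds


def prefix_log_partitions_scan(init, trans):
    """log Z of every prefix y_0..y_t (t = 0..N-1), via the scan (hypothetical helper)."""
    log_z = [logsumexp(init)]
    if not trans:
        return log_z, 0
    prefix, rounds = hillis_steele_scan(trans)
    for P in prefix:
        alpha = log_matmul([init], P)[0]   # forward scores: init ⊗ P
        log_z.append(logsumexp(alpha))
    return log_z, rounds


def prefix_log_partitions_sequential(init, trans):
    """Reference: the ordinary left-to-right forward recursion."""
    alpha = list(init)
    log_z = [logsumexp(alpha)]
    for M in trans:
        alpha = [logsumexp([alpha[a] + M[a][b] for a in range(len(alpha))])
                 for b in range(len(M[0]))]
        log_z.append(logsumexp(alpha))
    return log_z


# Usage: K labels, N positions, random log-potentials.
random.seed(0)
K, N = 3, 6
init = [random.gauss(0, 1) for _ in range(K)]
trans = [[[random.gauss(0, 1) for _ in range(K)] for _ in range(K)]
         for _ in range(N - 1)]
scan_log_z, rounds = prefix_log_partitions_scan(init, trans)
seq_log_z = prefix_log_partitions_sequential(init, trans)
print(rounds)  # 3, i.e. ceil(log2(5)) rounds for 5 transition matrices
print(max(abs(a - b) for a, b in zip(scan_log_z, seq_log_z)) < 1e-9)  # True
```

The scan only reorders the (associative, non-commutative) log-semiring products, which is why it agrees with the sequential recursion up to floating-point error while needing logarithmically many rounds.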

haozheji • Nov 15 '21 13:11

Very neat. I'll take a read. Do you find that this gives a speedup? Seems hard to parallelize.

srush • Nov 15 '21 22:11

You mean the speedup compared to using the gradient identity? At first I tried calculating only the prefix sums and using back-propagation to get the marginals of each prefix. But that requires O(N) separate backward passes. Otherwise, we would have to replicate the whole graph N times to get O(1) parallel complexity, which easily hits the GPU memory limit. The current version requires O(log N) parallel operations, but I haven't tested its speed against other implementations yet.
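To make the baseline concrete: the gradient identity says that the marginal p(y_i = a) of the prefix y_0..y_t is the derivative of log Z_t with respect to a score added to every path with y_i = a. The sketch below computes those marginals directly, with one backward sweep per prefix. This is where the N separate passes come from. It then checks one entry against a central finite difference of log Z_t. It is a plain-Python illustration, not pytorch-struct's API. The `init`/`trans` parametrization and the helpers (`forward`, `prefix_marginals`, `bump`) are hypothetical.

```python
import math
import random

NEG_INF = float("-inf")


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward(init, trans):
    """alphas[t][a]: log-sum of scores of all label prefixes y_0..y_t with y_t = a."""
    K = len(init)
    alphas = [list(init)]
    for M in trans:
        prev = alphas[-1]
        alphas.append([logsumexp([prev[a] + M[a][b] for a in range(K)])
                       for b in range(K)])
    return alphas


def log_partition(init, trans):
    return logsumexp(forward(init, trans)[-1])


def prefix_marginals(init, trans):
    """marg[t][i][a] = p(y_i = a) for the CRF restricted to prefix y_0..y_t.

    The forward pass is shared, but each prefix needs its own backward sweep,
    so N prefixes cost N sweeps (O(N^2 K^2) work in total).
    """
    K, N = len(init), len(trans) + 1
    alphas = forward(init, trans)
    marg = []
    for t in range(N):
        log_z = logsumexp(alphas[t])
        beta = [0.0] * K                     # backward message at the prefix's last position
        rows = [None] * (t + 1)
        for i in range(t, -1, -1):
            rows[i] = [math.exp(alphas[i][a] + beta[a] - log_z) for a in range(K)]
            if i > 0:
                M = trans[i - 1]             # scores for y_{i-1} = a, y_i = b
                beta = [logsumexp([M[a][b] + beta[b] for b in range(K)])
                        for a in range(K)]
        marg.append(rows)
    return marg


def bump(init, trans, i, a, eps):
    """Copy of the potentials with +eps on every path whose label y_i equals a."""
    init2 = list(init)
    trans2 = [[row[:] for row in M] for M in trans]
    if i == 0:
        init2[a] += eps
    else:
        for row in trans2[i - 1]:
            row[a] += eps
    return init2, trans2


# Usage: compare one marginal with d log Z_t / d eps (central difference).
random.seed(0)
K, N = 3, 5
init = [random.gauss(0, 1) for _ in range(K)]
trans = [[[random.gauss(0, 1) for _ in range(K)] for _ in range(K)]
         for _ in range(N - 1)]
marg = prefix_marginals(init, trans)
t, i, a, eps = 3, 1, 2, 1e-5
up = log_partition(*bump(init, trans[:t], i, a, eps))
down = log_partition(*bump(init, trans[:t], i, a, -eps))
print(abs((up - down) / (2 * eps) - marg[t][i][a]) < 1e-6)  # True
```

Each prefix's backward messages depend on where that prefix ends, so the prefix marginals can't be read off a single backward pass over the full sequence.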

haozheji • Nov 16 '21 07:11