seqlearn

Linear Chain CRF with SGD method

Open • chyikwei opened this issue 11 years ago • 4 comments

Hi,

I tried to implement a linear-chain CRF trained with SGD, based on these two papers:

C. Sutton, "An Introduction to Conditional Random Fields for Relational Learning"

N. Schraudolph, "Accelerated Training of Conditional Random Fields with Stochastic Gradient Methods"

To evaluate it, I tested it on the CoNLL 2000 shared task (chunking) and got:

| Model | Iterations | Training F1 | Test F1 | Training time |
|---|---|---|---|---|
| CRF | 5 | 0.894 | 0.844 | 68 min |
| Structured Perceptron | 10 | 0.892 | 0.834 | 1 min |

I think the biggest problem with my implementation is the training time. Next, I will try parallelizing the training or improving the Cython code.

Could you review it and give me some feedback? Any suggestions are welcome. Thanks!

chyikwei, Apr 11 '14 19:04

Sweet! As for the time, have you tried profiling with kernprof.py?
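kernprof (part of the line_profiler package) times every line of functions decorated with `@profile` when a script is run as `kernprof -l -v script.py`. A minimal sketch, with a hypothetical toy loop (`toy_fit`) standing in for `LinearChainCRF.fit`; the `try/except` fallback is a common idiom that lets the same file also run under plain `python`:

```python
# `toy_fit` is a hypothetical stand-in for a training loop, not seqlearn code.
try:
    profile  # `kernprof -l` injects this decorator into builtins
except NameError:
    def profile(func):
        """No-op fallback so the script also runs under plain `python`."""
        return func


@profile
def toy_fit(rows, n_epochs=3, lr=0.01):
    """Toy per-sample update loop; each line gets its own timing row."""
    w = [0.0] * len(rows[0])
    for _ in range(n_epochs):
        for x in rows:
            for j, xj in enumerate(x):
                w[j] += lr * (xj - w[j])
    return w


if __name__ == "__main__":
    toy_fit([[1.0, 2.0, 3.0]] * 100)
```

With `-v`, kernprof prints a per-line table (Hits, Time, Per Hit, % Time, Line Contents) after the run, which is where figures such as "Time Per Hit" and "% Time" come from.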

I suspect that calling logsumexp from inside Cython also contributes to the training time.
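For context on that cost: state posteriors for a linear-chain CRF come from a forward and a backward pass in log space, and every position needs a logsumexp over states. If that logsumexp is a Python-level call made once per position from Cython, call overhead can dominate. The sketch below keeps each step vectorized in NumPy. It is a hypothetical illustration, not seqlearn's `_posterior`; it assumes `score` is an (n_samples, n_states) array of per-position state scores, `trans[i, j]` scores a move from state i to state j, and `init`/`final` are per-state start/end scores.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis` (max-shift trick)."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


def forward_backward(score, trans, init, final):
    """Log-space forward-backward for a linear-chain CRF (hypothetical sketch).

    Returns per-position state posteriors, shape (T, K), and log Z.
    """
    T, K = score.shape
    alpha = np.empty((T, K))
    beta = np.empty((T, K))

    alpha[0] = init + score[0]
    for t in range(1, T):
        # alpha[t, j] = score[t, j] + logsumexp_i(alpha[t-1, i] + trans[i, j])
        alpha[t] = score[t] + logsumexp(alpha[t - 1][:, None] + trans, axis=0)

    beta[-1] = final
    for t in range(T - 2, -1, -1):
        # beta[t, i] = logsumexp_j(trans[i, j] + score[t+1, j] + beta[t+1, j])
        beta[t] = logsumexp(trans + (score[t + 1] + beta[t + 1])[None, :], axis=1)

    log_z = logsumexp(alpha[-1] + final, axis=0)
    post_state = np.exp(alpha + beta - log_z)
    return post_state, log_z
```

Each row of `post_state` sums to 1. `log_z` would correspond to `ll` in the objective on line 153 if `ll` is the log partition function; the thread does not define it, so that is an assumption.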

larsmans, Apr 12 '14 09:04

I have not tried it yet. Will do this next week.

chyikwei, Apr 13 '14 02:04

I just profiled LinearChainCRF.fit(); here are some results:

  1. Computing the posterior takes about 20% of the total time:

Line 146 (Time Per Hit: 12021.4, % Time: 19.2): `post_state, post_trans, ll = _posterior(score, None, b_trans, b_init, b_final)`

  2. Computing w_update takes about 42% of the total time:

Line 162 (Time Per Hit: 26131.2, % Time: 41.7): `w_update = lr * (safe_sparse_dot(y_t_i.T, X_i) - safe_sparse_dot(post_state.T, X_i) - (reg * w))`

  3. Computing the objective function takes about 20% of the total time (two lines combined):

Line 149 (Time Per Hit: 7178.1, % Time: 11.5): `feature_val = np.sum(w_true * w)`

Line 153 (Time Per Hit: 6169.9, % Time: 9.9): `sum_obj_val += feature_val + trans_val + init_val + final_val - ll - (0.5 * reg * np.sum(w * w))`
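Lines 153 and 162 appear to compute the per-sequence L2-regularized CRF log-likelihood and its gradient. The following reading is an assumption, since the thread does not define these variables: `w_true` is the empirical feature-count matrix $Y_i^\top X_i$ (with $Y_i$ the one-hot label matrix for sequence $i$), `ll` is $\log Z(X_i)$, `trans_val`, `init_val`, `final_val` are the transition, start and end scores of the true label path, $A$ is `b_trans`, $P_i$ is `post_state`, $\lambda$ is `reg`, and $\eta$ is `lr`:

$$
\ell_i(W) = \operatorname{tr}\!\left(W^\top Y_i^\top X_i\right) + \sum_{t=2}^{T} A_{y_{t-1},\,y_t} + b^{\mathrm{init}}_{y_1} + b^{\mathrm{final}}_{y_T} - \log Z(X_i) - \frac{\lambda}{2}\lVert W\rVert_F^2
$$

$$
\nabla_W \ell_i = Y_i^\top X_i - P_i^\top X_i - \lambda W, \qquad \Delta W = \eta\,\nabla_W \ell_i
$$

Here $\operatorname{tr}(W^\top Y_i^\top X_i)$ equals `np.sum(w_true * w)`, and the $-P_i^\top X_i$ term is the derivative of $-\log Z(X_i)$, because $\partial \log Z / \partial W_k = \sum_t p(y_t = k \mid X_i)\, x_t$.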

chyikwei, Apr 22 '14 01:04

I saved some time on computing w_update. Here is the new profiling result for that line:

Line 162 (Time Per Hit: 10535.2, % Time: 25.1): `w_update = lr * (safe_sparse_dot(y_t_i.T, X_i) - safe_sparse_dot(post_state.T, X_i) - (reg * w))`
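Both products on line 162 multiply by `X_i`, so by linearity $Y^\top X - P^\top X = (Y - P)^\top X$, and one matrix product can replace two. The thread does not say what change produced the speedup, so this is only one possible reduction. A minimal dense-NumPy sketch that reuses the original variable names as arguments; the two helper functions are hypothetical:

```python
import numpy as np


def w_update_two_products(y_t_i, post_state, X_i, w, lr, reg):
    """Same expression as line 162, written with dense arrays."""
    return lr * (y_t_i.T @ X_i - post_state.T @ X_i - reg * w)


def w_update_one_product(y_t_i, post_state, X_i, w, lr, reg):
    """(Y - P)^T X == Y^T X - P^T X, so a single product suffices."""
    return lr * ((y_t_i - post_state).T @ X_i - reg * w)
```

With a sparse `X_i`, the same rewrite should carry over to a single `safe_sparse_dot((y_t_i - post_state).T, X_i)` call.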

chyikwei, Apr 22 '14 01:04