semi-memory
key-value updates
Hi,
Thanks for the wonderful work. I found it to be a great read and very easy to understand.
(1) What loss function did you use for the key-value updates? Based on Eqn. (3), it seems that a mean squared error loss is used, i.e.

$$\mathcal{L}_{k_j} = \sum_{i=1}^{n_j} (k_j - x_i)^2$$

Is there a particular reason the update is normalized by $1/(n_j + 1)$ rather than $1/n_j$?
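To make question (1) concrete, here is a small sketch of the arithmetic. It assumes the key update is a single gradient step of size `eta` on $\frac{1}{2}\sum_i \|k_j - x_i\|^2$ (the $\frac{1}{2}$ absorbs the factor of 2 in the gradient), with the gradient scaled by either $1/(n_j+1)$ or $1/n_j$. The helper `key_update` and the parameters `eta` and `plus_one` are hypothetical names for illustration, not taken from the paper or its code.

```python
from typing import List

Vector = List[float]


def key_update(k: Vector, xs: List[Vector], eta: float, plus_one: bool = True) -> Vector:
    """Hypothetical helper: one gradient step on 0.5 * sum_i ||k - x_i||^2.

    The gradient is scaled by 1/(n + 1) when plus_one is True,
    or by 1/n when plus_one is False (n = number of samples for this key).
    """
    n = len(xs)
    denom = n + 1 if plus_one else n
    if denom == 0:
        # 1/n is undefined when no samples hit this key; leave the key unchanged.
        return list(k)
    # dL/dk = sum_i (k - x_i), computed per dimension
    grad = [sum(k[d] - x[d] for x in xs) for d in range(len(k))]
    return [k[d] - eta * grad[d] / denom for d in range(len(k))]
```

With `eta=1.0`, `key_update([0.0, 0.0], [[1.0, 2.0], [3.0, 4.0]], eta=1.0)` gives approximately `[1.333, 2.0]`, which equals $(k_j + \sum_i x_i)/(n_j + 1)$: the old key is counted as one extra sample. With `plus_one=False` the same call gives `[2.0, 3.0]`, the plain batch mean, which discards the old key entirely and is undefined when $n_j = 0$.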
(2) If the MND and ME losses were defined only on the unlabeled portion of the data, how much would performance degrade?
(3) Lastly, what happens if, instead of updating the keys and values after every epoch, we simply average the intermediate representations and the softmax outputs of the labeled data to redefine the key-value pairs?
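A minimal sketch of the alternative proposed in question (3), assuming one memory slot per class where the key is the mean feature vector and the value is the mean softmax vector of the labeled examples in that class. The function name `class_averaged_memory` and its inputs are hypothetical; this illustrates the proposal, not the authors' method.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def class_averaged_memory(
    features: List[Vector], probs: List[Vector], labels: List[int]
) -> Dict[int, Tuple[Vector, Vector]]:
    """Hypothetical: rebuild memory from labeled data only.

    Returns {class: (key, value)} where key is the mean feature vector and
    value is the mean softmax vector of that class's labeled examples.
    """
    feat_sum: Dict[int, Vector] = {}
    prob_sum: Dict[int, Vector] = {}
    count: Dict[int, int] = {}
    for f, p, y in zip(features, probs, labels):
        if y not in count:
            feat_sum[y] = [0.0] * len(f)
            prob_sum[y] = [0.0] * len(p)
            count[y] = 0
        # Accumulate per-class sums of features and softmax outputs
        feat_sum[y] = [a + b for a, b in zip(feat_sum[y], f)]
        prob_sum[y] = [a + b for a, b in zip(prob_sum[y], p)]
        count[y] += 1
    # Divide by the per-class count to get the averages
    return {
        y: ([s / count[y] for s in feat_sum[y]], [s / count[y] for s in prob_sum[y]])
        for y in count
    }
```

One consequence of this design is that the memory is recomputed from scratch on each pass, so a class with no labeled examples in that pass gets no slot at all.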
(4) Are there any plans to release the code in PyTorch? I am not very familiar with TensorFlow, but I would like to understand your method better by studying the code.
Any clarification would be helpful.
Thanks,
Devraj