memit
                                
Distributing the update across multiple layers
Hey, thanks for sharing your work! I have a question about how you chose to spread the residual across the remaining layers at each update step (Eq. 20). You set the updated values to M' = M + residual / (L - l + 1), claiming that this spreads the residual equally across the updated layers. But if there are 4 updated layers, the first layer provides 1/4 of the residual, the second provides 1/12 (= 1/3 - 1/4), the third provides 1/6 (= 1/2 - 1/3), and the fourth provides 1/2 (= 1 - 1/2).
Shouldn't the correct update be: M' = M + residual * (l - first_edited_layer + 1) / (L - first_edited_layer + 1)?
Thanks
Hi @YoadTew, great question! I think it comes down to a notational clarification.
In Equation 20, we write $m^l_i = W_{out} k_i^l + r_i^l$ where $$r_i^l = \frac{z_i - h_i^L}{L-l+1}.$$
Critically, $r_i^l$ is re-evaluated at each layer $l$, since the value of $h_i^L$ is affected by every layer update. We perform these updates iteratively because MEMIT updates use error minimization, so the post-update residuals may not match their targets exactly. Also note that the outputs of later modules $m_i^j$, $j > l$, shift when $m_i^l$ is updated, which introduces additional error.
Modulo these errors, the scheme should give us even spreading. Let me know if you have any further questions!
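
To make the two readings concrete, here is a small arithmetic sketch. It is not MEMIT code: it normalizes the initial residual $z_i - h_i^L$ to 1 and assumes, purely for illustration, that each layer's update shifts $h_i^L$ by exactly $r_i^l$ (that is, it ignores the error-minimization and downstream-shift effects mentioned above). The function names are hypothetical.

```python
from fractions import Fraction


def reevaluated_contributions(num_layers):
    """Per-layer share when the residual is recomputed before each layer.

    Idealized assumption: updating a layer moves h^L by exactly r^l.
    """
    residual = Fraction(1)  # z - h^L before any update, normalized to 1
    contributions = []
    for step in range(num_layers):
        remaining_layers = num_layers - step  # plays the role of L - l + 1
        r = residual / remaining_layers
        contributions.append(r)
        residual -= r  # h^L has moved toward z, so the residual shrinks
    return contributions


def fixed_residual_contributions(num_layers):
    """Per-layer share under the reading in the question.

    Here the residual is computed once, and residual / (L - l + 1) is
    treated as the cumulative shift reached after editing layer l.
    """
    cumulative = [Fraction(1, num_layers - step) for step in range(num_layers)]
    previous = [Fraction(0)] + cumulative[:-1]
    return [c - p for c, p in zip(cumulative, previous)]


print([str(c) for c in reevaluated_contributions(4)])
# ['1/4', '1/4', '1/4', '1/4']
print([str(c) for c in fixed_residual_contributions(4)])
# ['1/4', '1/12', '1/6', '1/2']
```

Under the idealized assumption, recomputing the residual at every layer gives each of the 4 layers an equal 1/4 share, while holding the residual fixed reproduces the uneven 1/4, 1/12, 1/6, 1/2 split from the question.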