Clarification on the Loss Function Implementation
In the provided code, I can see that MSELoss is used, but I don't see the term $\log P(c_l = z_{t_l})$ included anywhere. The explanation of the loss function, however, does contain this term, and it seems to play an important role: it supervises the mixture component $c_l$ to match the token $z_{t_l}$.
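To make sure I'm reading the explanation correctly, here is a minimal sketch of how I would expect that term to enter the loss. It assumes that $P(c_l)$ is a softmax over per-position mixture logits, that the term is added as a negative log-likelihood with weight 1, and that the remaining part is the MSE the code already computes. All function and variable names here are my own and hypothetical, not from the repository.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a list of floats.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def mixture_nll(mixture_logits, target_token):
    # -log P(c_l = z_{t_l}) for a single position l, assuming P(c_l)
    # is a softmax over hypothetical per-component logits.
    return -log_softmax(mixture_logits)[target_token]

def mse(pred, target):
    # Plain mean-squared error, standing in for the MSELoss in the code.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def total_loss(pred, target, mixture_logits, target_token):
    # Loss as I read the explanation: reconstruction term plus the
    # component-supervision term. The code seems to keep only mse(...).
    return mse(pred, target) + mixture_nll(mixture_logits, target_token)
```

With uniform logits over 4 components, `mixture_nll` is $\log 4$ regardless of the target, and it approaches 0 as the logit of $z_{t_l}$ dominates; without this term nothing pushes $c_l$ toward $z_{t_l}$.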
Additionally, the loss function includes the term $f(h_l, c_l)$, but in the code the states appear to be computed by averaging over the mixture components $c_l$. I'm unsure how this averaged state relates to $f(h_l, c_l)$ in the loss, which seems to be evaluated at a single component $c_l$.
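To show what I mean, here is a minimal sketch of the two readings. It assumes $f(h_l, c)$ has already been evaluated for every component $c$; `component_outputs` and `probs` are hypothetical names of my own. The averaged state equals $f(h_l, c_l)$ only when the mixture probabilities put all their mass on that one component.

```python
def mixture_average(component_outputs, probs):
    # Averaged state: sum_c P(c_l = c) * f(h_l, c), computed per dimension.
    dim = len(component_outputs[0])
    return [sum(p * out[d] for p, out in zip(probs, component_outputs))
            for d in range(dim)]

def conditioned_output(component_outputs, chosen):
    # f(h_l, c_l) for one specific component, e.g. c_l = z_{t_l}.
    return component_outputs[chosen]

outputs = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical f(h_l, c) for c = 0, 1
print(mixture_average(outputs, [0.5, 0.5]))  # [0.5, 0.5]
print(conditioned_output(outputs, 0))        # [1.0, 0.0]
```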
Could you clarify how these components are handled in the code, and whether there is a specific reason why $\log P(c_l = z_{t_l})$ is not explicitly included?
Thanks in advance for your help!