st-moe-pytorch
Differentiable top-k
IIUC, the topk in colt5_attention uses coor_descent, and according to Eqs. 8-11 of the original paper, it seems to expect unnormalized scores as input.
However, the forward of TopNGating appears to pass the normalized scores into topk.
Have I misunderstood something? Should the normalized or the unnormalized scores be used here?
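To make the concern concrete, here is a minimal pure-Python sketch of an entropy-regularized soft top-k solved by coordinate descent, written from my own reading of Eqs. 8-11. It is not the colt5_attention code. The helper names soft_topk and softmax, eps=0.1, n_iters=50, and the zero initialization of b are all assumptions made for illustration. Running it on the same scores before and after a softmax shows how much the result depends on the scale of the input relative to eps:

```python
import math


def softmax(xs):
    # numerically stable softmax over a plain list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def soft_topk(s, k, eps=0.1, n_iters=50):
    """Hypothetical soft top-k via coordinate descent (my reading of Eqs. 8-11).

    Sketch only: returns weights lam_i in (0, 1] whose sum approaches k.
    """
    log_k = math.log(k)
    a = 0.0
    b = [0.0] * len(s)  # assumed initialization
    for _ in range(n_iters):
        # a-update: a = eps * (log k - logsumexp((s + b) / eps))
        z = [(si + bi) / eps for si, bi in zip(s, b)]
        m = max(z)
        lse = m + math.log(sum(math.exp(zi - m) for zi in z))
        a = eps * (log_k - lse)
        # b-update: b = min(-(s + a), 0), which caps every weight at 1
        b = [min(-(si + a), 0.0) for si in s]
    # lam_i = exp((s_i + a + b_i) / eps) = exp(min(s_i + a, 0) / eps)
    return [math.exp((si + a + bi) / eps) for si, bi in zip(s, b)]


logits = [2.0, 1.0, 0.0, -1.0]
probs = softmax(logits)

lam_logits = soft_topk(logits, k=2)  # unnormalized input
lam_probs = soft_topk(probs, k=2)    # normalized input

print([round(x, 2) for x in lam_logits])  # [1.0, 1.0, 0.0, 0.0]
print([round(x, 2) for x in lam_probs])   # [1.0, 0.74, 0.17, 0.1]
```

With these assumed settings, the raw logits give a nearly hard 0/1 mask, while the softmax probabilities (all squeezed into [0, 1], so their gaps are small relative to eps) give a much softer one. That is why the choice between normalized and unnormalized input seems to matter here.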