this sample func seems like a bug...
没有考虑topk topp? 而且至少应该算个softmax出来? https://github.com/Infini-AI-Lab/TriForce/blob/164c8c0131cf49951eefdea89a3fbcccb8ca326b/utils/sampling.py#L64
sample 之前有 apply 一个 norm_logits 函数。
https://github.com/Infini-AI-Lab/TriForce/blob/164c8c0131cf49951eefdea89a3fbcccb8ca326b/utils/sampling.py#L43
token = sample(norm_logits(x))
okk,不过这个norm_logits没有考虑temperature为0的情况,跑temperature=0会nan
temperature=0 的时候直接 argmax 就可以了,不用这么复杂的操作了,更 efficient 一些。
就是这一行会错 https://github.com/Infini-AI-Lab/TriForce/blob/main/utils/decoding.py#L23
或者你直接temperature=0.001,和 0 没什么区别其实。
还有这里 https://github.com/Infini-AI-Lab/TriForce/blob/164c8c0131cf49951eefdea89a3fbcccb8ca326b/utils/decoding.py#L32 和这里 https://github.com/Infini-AI-Lab/TriForce/blob/164c8c0131cf49951eefdea89a3fbcccb8ca326b/utils/decoding.py#L62