pytorch-llama
Mask while decoding in `_sample_top_p`
The mask at https://github.com/hkproj/pytorch-llama/blob/067f8a37fe36ac8b52dca9cc6f2a2e8d6aa372d6/inference.py#L121 should be `~mask`, since we want to select all the indices whose cumulative probability is less than `p`.
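For context, a top-p (nucleus) sampling routine is commonly written roughly as below. This is a sketch following the reference Llama structure, not the exact code at the linked line; the function and variable names (`sample_top_p`, `probs_sort`, `probs_idx`, `mask`) are assumptions and may differ from the repo.

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort probabilities in descending order, remembering original vocab indices.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # Cumulative probability mass up to and including each sorted token.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Here `mask` marks tokens to DROP: those for which the mass accumulated
    # before them already exceeds p. Equivalently, tokens are KEPT while
    # (probs_sum - probs_sort) <= p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving nucleus and sample from it.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```

Whether `mask` or `~mask` is correct depends on which side the boolean mask denotes: if it is used to zero out tokens outside the nucleus (as in the sketch above), `mask` is applied directly; if it is used to select the tokens to keep, it must be inverted.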