Update tensor_parallel.py
Resolves abnormal conversation output from the Baichuan large model.
Fixes a bug in the norm_head adaptation for Baichuan.
Fixes https://github.com/huggingface/text-generation-inference/issues/2780
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py#:~:text=self.weight.data%20%3D%20nn.functional.normalize(self.weight)
@OlivierDehaene OR @Narsil
We cannot really accept this. This is a bug in Baichuan weights, not in our code.
The issue with your proposed fix is that we support tensor parallelism (TP), which means the normalized weight values would depend on the TP degree you're using, leading to potentially even larger discrepancies.
The "true" fix in that sense would be to load the entire weight, normalize it, and then split it across GPUs, but that leads to other issues, the first of which is excess VRAM usage, which can cause unwanted OOMs.
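To illustrate the TP concern, here is a minimal sketch (not TGI code; shapes, the TP degree of 2, and the split axis are illustrative assumptions) showing that normalizing each shard independently gives different values than normalizing the full weight and then sharding, whenever the shard boundary crosses the normalization axis:

```python
import numpy as np

def l2_normalize(w, axis=1):
    # Row-wise L2 normalization, mirroring nn.functional.normalize(weight)
    return w / np.linalg.norm(w, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 8))  # toy [vocab, hidden] lm_head weight

# "True" fix: normalize the full weight, then split across 2 hypothetical TP ranks
full_then_split = np.split(l2_normalize(weight), 2, axis=1)

# Per-shard fix: split first, then normalize each shard independently;
# each shard only sees part of every row, so it divides by the wrong norm
split_then_norm = [l2_normalize(s) for s in np.split(weight, 2, axis=1)]

# The two approaches disagree because per-shard row norms differ from full row norms
assert not np.allclose(full_then_split[0], split_then_norm[0])
```

The discrepancy grows with the TP degree, since each shard's local norm drifts further from the full-row norm.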
Baichuan should fix their weights (unless there's a valid reason to keep the unnormalized weights, but I don't think there's one).