torchchat
[distributed][perf] ensure that all decoding ops are happening on gpu with no cpu sync
🐛 Describe the bug
Per @kwen2501, in the decoding step we currently do:

```python
next_token = torch.tensor([decode_results[0][0]], device=device)
```
> Nit: I am not sure if the use of `torch.tensor` here would cause a sync from GPU to CPU (to get the scalar) and then a move back to the GPU (to create the tensor). If `next_token` is not used in the CPU domain, it is better to just use an index op here.
>
> Or, is `decode_results` already on CPU? Hmm, then we'd need to think about how to arrange these CPU ops and GPU ops. Ideally, you would like to fire the send right after `step()`.
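The fix the reviewer suggests can be sketched as follows. This is a minimal illustration, assuming `decode_results` is a GPU tensor of token ids (the helper name `pick_next_token` is hypothetical, not from the codebase):

```python
import torch

def pick_next_token(decode_results: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the suggested fix.
    # The original pattern,
    #   torch.tensor([decode_results[0][0]], device=device)
    # reads the scalar back to the CPU (a GPU->CPU sync) and then
    # uploads it again to build a new GPU tensor.
    # A pure index op stays on decode_results.device with no sync:
    return decode_results[0, 0:1]  # shape [1], same device and dtype
```

Slicing with `0:1` (rather than plain `[0, 0]`) keeps a 1-element tensor, matching the shape the original `torch.tensor([...])` call produced, so the result can be fed to a collective send without any CPU round trip.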
Versions
n/a