aphrodite-engine
Reduce peak memory for prompt_logprobs requests
First order of business: make prompt_logprobs "compatible" with prefix caching. It can't take advantage of the caching, but at least it will run.
Second order of business: reduce the peak memory usage of the samplers. This PR slightly reduces the memory load, but not nearly enough. On a single GPU, sampling can still take dozens of gigabytes at peak (an 8B model at 16k was >10 GB). On multi-GPU, sampling is no cheaper, and there's also a colossal memory spike when gathering the logits.
Thoughts:
- In this PR, some operations are split into smaller batches. Can we split the entire sampling process the same way? Leaving it mostly unchanged, but only handling a fixed `k` of rows at a time?
- No idea what the fix is for the gather spikes, deferring to @AlpinDale on that. That might not even be the specific issue, just where it ran out of VRAM for me, but there's something about multi-GPU that's aggravating the memory peaks.
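
To make the first thought concrete, here is a rough sketch of the "fixed `k` rows at a time" idea. It is plain Python on lists, not the actual sampler code: `log_softmax_row` and `chunked_prompt_logprobs` are hypothetical names invented for illustration, and the real thing would operate on GPU tensors. The point is only that the full-vocabulary log-probabilities exist for at most `k` rows at once, while the per-row results that are kept are tiny.

```python
import math
from typing import List, Sequence

Row = Sequence[float]


def log_softmax_row(logits: Row) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def chunked_prompt_logprobs(
    logits_rows: Sequence[Row],
    target_ids: Sequence[int],
    k: int,
) -> List[float]:
    """Return log p(target token) for every row, normalising k rows at a time.

    Hypothetical helper: it stands in for "the whole sampling step", which is
    assumed here to be independent per row.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    if len(logits_rows) != len(target_ids):
        raise ValueError("need exactly one target id per row")

    out: List[float] = []
    for start in range(0, len(logits_rows), k):
        chunk = logits_rows[start:start + k]
        # Only this chunk's full-vocabulary log-probabilities are alive here;
        # they are dropped before the next chunk is processed.
        chunk_logprobs = [log_softmax_row(row) for row in chunk]
        for offset, row_logprobs in enumerate(chunk_logprobs):
            out.append(row_logprobs[target_ids[start + offset]])
    return out
```

Because each row is handled independently, the result does not depend on `k`; only the size of the temporary buffers does. Whether the rest of the sampler (top-k/top-p, penalties, and so on) is row-independent enough to be wrapped the same way is exactly the open question above.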
Will probably need some restructuring after #925