aphrodite-engine

Reduce peak memory for prompt_logprobs requests

Open 50h100a opened this issue 1 year ago • 1 comment

First order of business: make prompt_logprobs "compatible" with prefix caching. It can't take advantage of the cache, but at least it will run.

Second order of business: reduce the peak memory usage of the samplers. This PR slightly reduces the memory load, but not nearly enough. On a single GPU, sampling can still take dozens of gigabytes at peak (an 8B model at 16k context used more than 10 GB). On multi-GPU, sampling is no cheaper, and there's also a colossal memory spike when gathering the logits.

Thoughts:

  • In this PR, some operations are split into smaller batches. Can we split the entire sampling process the same way? That is, leave it mostly unchanged, but only handle a fixed number k of rows at a time?
  • No idea what the fix is for the gather spikes, deferring to @AlpinDale on that. That might not even be the specific issue, just where it ran out of VRAM for me, but there's something about multi-GPU that's aggravating the memory peaks.
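
To make the first thought concrete, here is a minimal sketch of the "fixed k rows at a time" idea, applied to prompt logprobs. It is not the code from this PR: `chunked_prompt_logprobs`, `log_softmax_row`, and the `logits_for_rows` callback are hypothetical names, and plain Python lists stand in for GPU tensors. The point is only that if logits are produced and reduced per chunk, peak memory should scale with k × vocab size rather than prompt length × vocab size.

```python
import math
from typing import Callable, List, Sequence


def log_softmax_row(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def chunked_prompt_logprobs(
    logits_for_rows: Callable[[int, int], List[List[float]]],
    target_ids: Sequence[int],
    k: int,
) -> List[float]:
    """Return the logprob of each target token, handling k rows at a time.

    logits_for_rows(start, end) is a hypothetical callback returning the
    logits for prompt positions [start, end). It stands in for whatever
    projects hidden states to vocabulary logits in the real engine.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    out: List[float] = []
    n = len(target_ids)
    for start in range(0, n, k):
        end = min(start + k, n)
        chunk = logits_for_rows(start, end)  # at most k rows of logits
        for row, tok in zip(chunk, target_ids[start:end]):
            # Keep only one scalar per row; the full row is discarded.
            out.append(log_softmax_row(row)[tok])
        del chunk  # drop this chunk before the next one is computed
    return out
```

With this shape, only the per-token scalars accumulate across the whole prompt. Whether the rest of the sampler (top-k logprobs, penalties, and so on) can be restructured the same way is the open question in the bullet above.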

50h100a, Dec 16 '24 20:12

Will probably need some restructuring after #925

AlpinDale, Dec 19 '24 17:12