
Improve the latency of `load_batched_adapter_weights`

Open · thincal opened this issue 1 year ago · 1 comment

Feature request

Currently, every LoRA layer's weights are moved from CPU to the base model's target device, which adds roughly 20 ms per layer and about 500 ms to over 1 s of latency overall (see the sketch after the two snippets below):

  1. First, the adapter weights are loaded into CPU memory:

     def load_module_map(
         ...
         for filename in adapter_filenames:
             adapter_weights.update(load_file(filename))
         ...

  2. Then, each layer's weights are moved to the GPU device inside `load_batched_adapter_weights`:

     lora_a = lora_a.to(base_device, self.dtype)
     lora_b = lora_b.to(base_device, self.dtype)
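
One way to avoid the per-layer copies is to let safetensors place the tensors on the target device at load time, so the later per-layer `.to(base_device, ...)` calls become no-ops. This is a minimal sketch, not LoRAX code; `load_module_map_on_device` and its parameters are illustrative names:

```python
from safetensors.torch import load_file


def load_module_map_on_device(adapter_filenames, base_device):
    """Load adapter shards directly onto the base model's device.

    Sketch only: relies on safetensors' `device` argument to skip the
    intermediate CPU copy that currently costs ~20 ms per layer move.
    """
    adapter_weights = {}
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename, device=str(base_device)))
    return adapter_weights
```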

Motivation

Improve the adapter loading performance.

Your contribution

Yes, I will prepare a PR for review.

thincal · Apr 22 '24, 17:04

Thanks for working on this @thincal! We could probably work around this by keeping weights in the safetensors file rather than loading to CPU as an intermediate step.

tgaddair · Apr 22 '24, 17:04
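
For reference, a minimal sketch of the lazy approach described above, assuming the safetensors `safe_open` API; `load_adapter_tensor` is a hypothetical helper, not part of LoRAX:

```python
import torch
from safetensors import safe_open


def load_adapter_tensor(filename, tensor_name, device, dtype=torch.float16):
    """Read a single tensor from a safetensors file onto `device`.

    Keeping weights in the file means only the tensors a layer actually
    needs are materialized, and no full CPU copy is ever built.
    """
    with safe_open(filename, framework="pt", device=str(device)) as f:
        return f.get_tensor(tensor_name).to(dtype)


# Example usage (illustrative file and tensor names):
# lora_a = load_adapter_tensor(
#     "adapter_model.safetensors",
#     "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
#     device="cuda:0",
# )
```

In practice the `safe_open` handle would be kept open and reused across layers rather than reopened per tensor; the function above is only meant to illustrate the API.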