Improve the latency of `load_batched_adapter_weights`
Feature request
Currently, the weights of every LoRA layer are moved from CPU to the target device of the base model, which adds an extra ~20ms per layer and roughly 500ms to 1+s of latency overall.
- first, loading into CPU memory:

```python
def load_module_map(
    ...
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename))
    ...
```
- then, moving to the GPU device inside `load_batched_adapter_weights` (see the sketch after this excerpt):

```python
lora_a = lora_a.to(base_device, self.dtype)
lora_b = lora_b.to(base_device, self.dtype)
```
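One way to cut the per-layer copies would be to deserialize the adapter shards directly onto the base model's device instead of staging them on CPU first. A minimal sketch of that idea, assuming the safetensors `load_file(..., device=...)` parameter; `adapter_filenames` and `base_device` are stand-ins for the values used in `load_module_map`, and a dtype cast may still be needed afterwards:

```python
# Illustrative sketch, not the actual lorax code path.
from safetensors.torch import load_file


def load_adapter_weights(adapter_filenames, base_device="cuda:0"):
    adapter_weights = {}
    for filename in adapter_filenames:
        # device= places the tensors on the target device during
        # deserialization, skipping the intermediate CPU copy.
        adapter_weights.update(load_file(filename, device=base_device))
    return adapter_weights
```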
Motivation
Improve the adapter loading performance.
Your contribution
Yes, I will prepare a PR for review.
Thanks for working on this @thincal! We could probably work around this by keeping weights in the safetensors file rather than loading to CPU as an intermediate step.
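A minimal sketch of that idea, assuming the safetensors `safe_open` API; the function and argument names below are illustrative, not lorax's actual variables:

```python
# Sketch: keep the weights in the safetensors file and lazily fetch each
# lora_a / lora_b tensor directly onto the base device when a layer needs it.
from safetensors import safe_open


def get_lora_tensor(filename, weight_name, base_device="cuda:0"):
    with safe_open(filename, framework="pt", device=base_device) as f:
        # Only the requested tensor is read and placed on the device;
        # the rest of the file is never materialized in CPU memory.
        return f.get_tensor(weight_name)
```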