How can I switch between LoRA models dynamically?
I have trained several LoRA models using the Flux model, and I want to switch between these LoRA models dynamically without reloading the base Flux model.
I saw in #1185 that the methods backup_weights, pre_calculation, and restore_weights can be used to achieve this, but there are no examples provided. Could you please share an example of how to use them in practice?
Hi, after reading through the LoRA code, I think there are two ways to change the LoRA without reloading the whole base Flux model:
- Use the LoRA by its definition: $$y = W x + B A x$$
- Update the weight matrix first and then feed x in, i.e. pre-compute $$W' = W + B A$$, then $$y = W' x$$ (see the sketch right after this list)
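A minimal sketch of the two formulations with toy tensors (illustrative shapes and names, not sd-scripts code):

```python
import torch

W = torch.randn(64, 64)   # original weight
A = torch.randn(8, 64)    # LoRA down projection
B = torch.randn(64, 8)    # LoRA up projection
x = torch.randn(64)

# Way 1: keep W untouched and add the LoRA branch at forward time
y1 = W @ x + B @ (A @ x)

# Way 2: pre-merge the LoRA into the weight, then run a plain forward
W_merged = W + B @ A
y2 = W_merged @ x

assert torch.allclose(y1, y2, atol=1e-4)
```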
The comment at https://github.com/kohya-ss/sd-scripts/issues/1185#issuecomment-2001941718 refers to the second way. Since it backs up the original W first, you can restore W quickly from RAM (instead of loading it from disk).
At https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L497-L504, args.merge_lora_weights controls whether the first way (args.merge_lora_weights = False) or the second way (args.merge_lora_weights = True) is used.
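For reference, the branch at those lines looks roughly like this (simplified; see the link for the exact code):

```python
if args.merge_lora_weights:
    # second way: fold B * A into the base weights once
    lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
    # first way: keep the base weights and hook the forward pass instead
    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
    lora_model.eval()
    lora_model.to(device)
```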
I used flux_minimal_inference.py as a base and rewrote the last part to implement the second way.
I changed
https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L476-L576
(under branch sd3, commit af14eab6d7f81493d23a7b961e01084f52eb5adf) to the following (I named the new script "flux_minimal_inference_switch_lora_a.py"):
```python
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights:
    if ";" in weights_file:
        weights_file, multiplier = weights_file.split(";")
        multiplier = float(multiplier)
    else:
        multiplier = 1.0

    weights_sd = load_file(weights_file)
    is_lora = is_oft = False
    for key in weights_sd.keys():
        if key.startswith("lora"):
            is_lora = True
        if key.startswith("oft"):
            is_oft = True
        if is_lora or is_oft:
            break
    module = lora_flux if is_lora else oft_flux

    lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

    # if args.merge_lora_weights:
    #     lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
    # else:
    #     lora_model.apply_to([clip_l, t5xxl], model)
    #     info = lora_model.load_state_dict(weights_sd, strict=True)
    #     logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
    #     lora_model.eval()
    #     lora_model.to(device)
    #
    # lora_models.append(lora_model)

    # 1st image: base model, before this LoRA is applied
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

    def backup_weights(self):
        # back up the original weights to CPU RAM
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone().cpu()
                org_module._lora_restored = True

    backup_weights(lora_model)
    print("lora backed up")

    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {weights_file}: {info}")

    def pre_calculation(self):
        # pre-merge the LoRA weights into the original weights (W' = W + B * A)
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            # >>> handle dtype conversion issues: fp8 weights do not support addition
            # on CUDA, so merge in fp32 and cast back
            with torch.cuda.amp.autocast(enabled=False):
                org_weight_fp32 = org_weight.float()
                lora_weight_fp32 = lora_weight.float()
                merged = (org_weight_fp32 + lora_weight_fp32).to(org_weight.dtype)
            sd["weight"] = merged
            # <<<
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False

    pre_calculation(lora_model)
    print("pre_calculation done")

    lora_model.eval()
    lora_model.to(device)

    # 2nd image: LoRA merged into the base weights
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

    def restore_weights(self):
        # restore the original weights from the CPU backup
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                device = next(org_module.parameters()).device
                sd["weight"] = org_module._lora_org_weight.to(device)
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    restore_weights(lora_model)

    # 3rd image: base model again, to confirm the restore worked
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

torch.cuda.memory._dump_snapshot("VRAM_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)

logger.info("Done!")
```
Then I call this new script with:
```bash
python flux_minimal_inference_switch_lora_a.py \
  --ckpt_path "/root/autodl-tmp/ComfyUI/models/unet/flux1-dev-fp8_unet.safetensors" \
  --clip_l "/root/autodl-tmp/ComfyUI/models/text_encoders/clip_l.safetensors" \
  --t5xxl "/root/autodl-tmp/ComfyUI/models/text_encoders/t5xxl_fp8_e4m3fn.safetensors" \
  --ae "/root/autodl-tmp/ComfyUI/models/vae/ae.safetensors" \
  --apply_t5_attn_mask \
  --flux_dtype fp8 \
  --t5xxl_dtype fp8 \
  --dtype bfloat16 \
  --width 1024 \
  --height 1024 \
  --steps 20 \
  --guidance 3.5 \
  --prompt "style of H. R. Giger, style of aziib_pixel, A photo of a cat" \
  --output_dir "/root/autodl-tmp/outputs" \
  --seed 42 \
  --lora_weights "/root/autodl-tmp/sd-scripts/lora/Aziib_Pixel_Style.safetensors;1.0" \
    "/root/autodl-tmp/sd-scripts/lora/style_of_H._R._Giger_FLUX_295-000006.safetensors;1.0"
```
The results are as follows. (To test whether the base model is recovered after restoring, I generate one image before applying each LoRA, one after merging it, and one after restoring, so 2 * (1 + 1 + 1) = 6 images.)
You can see that this works.
Finally, note that backing up and restoring may require dtype conversion: the model weights here are fp8_e4m3fn, which does not support many operations on CUDA (without the dtype conversion this caused an error in my experiment).
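For illustration, a minimal standalone sketch of the problem (whether the direct fp8 addition raises an error depends on your PyTorch build):

```python
import torch

w = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)      # fp8 base weight
delta = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)  # fp8 LoRA delta

# w + delta  # may raise "not implemented for 'Float8_e4m3fn'" on many PyTorch builds

# round-trip through fp32 for the addition, then cast back to fp8
merged = (w.float() + delta.float()).to(torch.float8_e4m3fn)
```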
Also, this backup/restore approach is not fast when switching LoRAs. pre_calculation is really designed for subsequent batch inference; for switching, the logically better approach is the first one, "use the LoRA by its definition $$y = W x + B A x$$", mentioned at the beginning. Since I had already started implementing this one, I finished it anyway, and then also implemented the other way.
I implemented method one by replacing
https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L475-L576
with the following:
```python
# Monkey patch LoRAModule so the original forward method can be put back later.
original_apply_to = lora_flux.LoRAModule.apply_to  # kept around in case the original behavior is needed again

def patched_apply_to(self):
    self.org_forward = self.org_module.forward
    self.org_module.forward = self.forward
    # keep a reference to the original module so the patch can be undone
    self.org_module_ref_for_patch = self.org_module
    del self.org_module

def patched_remove(self):
    if hasattr(self, "org_forward") and hasattr(self, "org_module_ref_for_patch"):
        self.org_module_ref_for_patch.forward = self.org_forward
        del self.org_forward
        del self.org_module_ref_for_patch

lora_flux.LoRAModule.apply_to = patched_apply_to
lora_flux.LoRAModule.remove = patched_remove

def network_remove(self):
    for lora in self.text_encoder_loras + self.unet_loras:
        lora.remove()

lora_flux.LoRANetwork.remove = network_remove

# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights:
    if ";" in weights_file:
        weights_file, multiplier = weights_file.split(";")
        multiplier = float(multiplier)
    else:
        multiplier = 1.0

    weights_sd = load_file(weights_file)
    is_lora = is_oft = False
    for key in weights_sd.keys():
        if key.startswith("lora"):
            is_lora = True
        if key.startswith("oft"):
            is_oft = True
        if is_lora or is_oft:
            break
    module = lora_flux if is_lora else oft_flux

    lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

    # 1st image: base model, before this LoRA is applied
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=False)
    logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
    lora_model.eval()
    lora_model.to(device)

    # 2nd image: with the LoRA forward hooks active
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

    # put the original forward methods back
    lora_model.remove()

    # 3rd image: base model again, to confirm the removal worked
    generate_image(
        model,
        clip_l,
        t5xxl,
        ae,
        args.prompt,
        args.seed,
        args.width,
        args.height,
        args.steps,
        args.guidance,
        args.negative_prompt,
        args.cfg_scale,
    )

logger.info("Done!")
```
The bash command is similar.
Results are as follows, again 2 * (1 + 1 + 1) = 6 images (for each of the 2 LoRAs: one image before applying, one with the LoRA applied, one after removing it). You can see that it works:
Finally, note that this is a workaround: because LoRAModule replaces the forward method of the original module, I monkey patch it so the original forward can be restored. As a side effect, I have to call lora_model.load_state_dict(weights_sd, strict=False): since the patched apply_to stores the original module under another attribute name, it gets registered as a submodule, and load_state_dict with strict=True then raises an error about keys that are not in the LoRA file (a sketch of this below).
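A minimal standalone sketch of that effect (illustrative class and attribute names, not the actual sd-scripts classes):

```python
import torch
import torch.nn as nn

class LoRAWrapper(nn.Module):
    def __init__(self, target: nn.Module):
        super().__init__()
        self.lora_down = nn.Linear(4, 2, bias=False)
        self.lora_up = nn.Linear(2, 4, bias=False)
        # assigning an nn.Module to an attribute registers it as a submodule,
        # so its parameters become part of this module's state_dict
        self.org_module_ref_for_patch = target

wrapper = LoRAWrapper(nn.Linear(4, 4))
print(list(wrapper.state_dict().keys()))
# ['lora_down.weight', 'lora_up.weight',
#  'org_module_ref_for_patch.weight', 'org_module_ref_for_patch.bias']

# a checkpoint that only contains the lora_* keys:
ckpt = {"lora_down.weight": torch.zeros(2, 4), "lora_up.weight": torch.zeros(4, 2)}
wrapper.load_state_dict(ckpt, strict=False)   # fine, just reports the missing keys
# wrapper.load_state_dict(ckpt, strict=True)  # RuntimeError: Missing key(s) in state_dict
```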
Not an elegant solution, but this implementation is much faster for switching than the previous one, since there is no dtype conversion and no moving of weights between CPU and CUDA.
However, if you want to run many inferences with the same LoRA, I still recommend the previous implementation with pre_calculation, since pre-merging is better when the model does not need to change across a large number of inference calls.