
How can I switch between LoRA models dynamically?

ljwdust opened this issue 5 months ago · 2 comments

I have trained several LoRA models using the Flux model, and I want to switch between these LoRA models dynamically without reloading the base Flux model.

I saw in #1185 that the methods backup_weights, pre_calculation, and restore_weights can be used to achieve this, but there are no examples provided. Could you please share an example of how to use them in practice?

ljwdust — Jun 18 '25 02:06

Hi, after reading through the LoRA code, I think there are two ways to switch LoRA without reloading the whole base Flux model:

  1. Apply LoRA as defined, $$y = W x + B A x$$
  2. Merge the weights first and then take x in, i.e. pre-compute $$W' = W + B A$$ and use $$y = W' x$$ (a small sketch of both is shown below).
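
To make the difference concrete, here is a toy sketch of both ways on a bare nn.Linear (not sd-scripts code; W, A, B and scale are just illustrative names for the base weight, lora_down, lora_up and alpha/rank):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    W = nn.Linear(16, 16, bias=False)  # base layer
    A = nn.Linear(16, 4, bias=False)   # lora_down
    B = nn.Linear(4, 16, bias=False)   # lora_up
    scale = 1.0
    x = torch.randn(2, 16)

    # way 1: keep W untouched and add the LoRA branch at forward time
    y1 = W(x) + scale * B(A(x))

    # way 2: pre-merge W' = W + scale * B @ A, then run a plain forward pass
    merged = nn.Linear(16, 16, bias=False)
    merged.weight.data = W.weight + scale * (B.weight @ A.weight)
    y2 = merged(x)

    print(torch.allclose(y1, y2, atol=1e-5))  # both give the same output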

The comment at https://github.com/kohya-ss/sd-scripts/issues/1185#issuecomment-2001941718 describes the second way. Because it backs up the original W first, you can restore W quickly from RAM instead of reloading it from disk.
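
The backup / pre-calculation / restore cycle works roughly as follows (again a toy sketch on a plain nn.Linear, not the actual sd-scripts methods):

    import torch
    import torch.nn as nn

    layer = nn.Linear(16, 16, bias=False)   # stands in for one Flux linear layer
    A = torch.randn(4, 16) * 0.01           # lora_down weight
    B = torch.randn(16, 4) * 0.01           # lora_up weight
    scale = 1.0

    # backup: keep a CPU copy of the original weight so it can be restored from RAM
    org_weight_cpu = layer.weight.detach().clone().cpu()

    # pre-calculation: merge W' = W + scale * B @ A into the layer in place
    with torch.no_grad():
        layer.weight += scale * (B @ A).to(layer.weight.dtype)

    # ... run inference with the merged weights ...

    # restore: copy the backed-up weight back, then merge the next LoRA the same way
    with torch.no_grad():
        layer.weight.copy_(org_weight_cpu.to(layer.weight.device))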

At https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L497-L504, args.merge_lora_weights controls whether the first way (args.merge_lora_weights = False) or the second way (args.merge_lora_weights = True) is used.

I used flux_minimal_inference.py as a base and rewrote the last part to implement the second way.

I changed

(under branch sd3, commit af14eab6d7f81493d23a7b961e01084f52eb5adf) https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L476-L576

to the following (I named the new script "flux_minimal_inference_switch_lora_a.py"):

    lora_models: List[lora_flux.LoRANetwork] = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        is_lora = is_oft = False
        for key in weights_sd.keys():
            if key.startswith("lora"):
                is_lora = True
            if key.startswith("oft"):
                is_oft = True
            if is_lora or is_oft:
                break

        module = lora_flux if is_lora else oft_flux
        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

        # if args.merge_lora_weights:
        #     lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
        # else:
        #     lora_model.apply_to([clip_l, t5xxl], model)
        #     info = lora_model.load_state_dict(weights_sd, strict=True)
        #     logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
        #     lora_model.eval()
        #     lora_model.to(device)
        #
        # lora_models.append(lora_model)



        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )


        def backup_weights(self):
            # back up the original weights
            loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
            for lora in loras:
                org_module = lora.org_module_ref[0]
                if not hasattr(org_module, "_lora_org_weight"):
                    sd = org_module.state_dict()
                    org_module._lora_org_weight = sd["weight"].detach().clone().cpu()
                    org_module._lora_restored = True

        backup_weights(lora_model)
        print("lora backuped")

        lora_model.apply_to([clip_l, t5xxl], model)
        info = lora_model.load_state_dict(weights_sd, strict=True)
        logger.info(f"Loaded LoRA weights from {weights_file}: {info}")

        def pre_calculation(self):
            # pre-compute the merged weights
            loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
            for lora in loras:
                org_module = lora.org_module_ref[0]
                sd = org_module.state_dict()

                org_weight = sd["weight"]
                lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)

                #>>> to handle dtype conversion issues
                with torch.cuda.amp.autocast(enabled=False):
                    org_weight_fp32 = org_weight.float()
                    lora_weight_fp32 = lora_weight.float()
                    merged = (org_weight_fp32 + lora_weight_fp32).to(org_weight.dtype)
                    sd["weight"] = merged
                #<<<

                assert sd["weight"].shape == org_weight.shape
                org_module.load_state_dict(sd)

                org_module._lora_restored = False
                lora.enabled = False

        pre_calculation(lora_model)
        print("pre_calculation done")

        lora_model.eval()
        lora_model.to(device)

        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )

        def restore_weights(self):
            # restore the original weights
            loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
            for lora in loras:
                org_module = lora.org_module_ref[0]
                if not org_module._lora_restored:
                    sd = org_module.state_dict()
                    device = next(org_module.parameters()).device
                    sd["weight"] = org_module._lora_org_weight.to(device)
                    org_module.load_state_dict(sd)
                    org_module._lora_restored = True
        restore_weights(lora_model)

        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )

    torch.cuda.memory._dump_snapshot("VRAM_snapshot.pickle")

    torch.cuda.memory._record_memory_history(enabled=None)

    logger.info("Done!")

Then I call this new script with:

python flux_minimal_inference_switch_lora_a.py \
--ckpt_path "/root/autodl-tmp/ComfyUI/models/unet/flux1-dev-fp8_unet.safetensors" \
--clip_l "/root/autodl-tmp/ComfyUI/models/text_encoders/clip_l.safetensors" \
--t5xxl "/root/autodl-tmp/ComfyUI/models/text_encoders/t5xxl_fp8_e4m3fn.safetensors" \
--ae "/root/autodl-tmp/ComfyUI/models/vae/ae.safetensors" \
--apply_t5_attn_mask \
--flux_dtype fp8 \
--t5xxl_dtype fp8 \
--dtype bfloat16 \
--width 1024 \
--height 1024 \
--steps 20 \
--guidance 3.5 \
--prompt "style of H. R. Giger, style of aziib_pixel, A photo of a cat" \
--output_dir "/root/autodl-tmp/outputs" \
--seed 42 \
--lora_weights "/root/autodl-tmp/sd-scripts/lora/Aziib_Pixel_Style.safetensors;1.0" \
"/root/autodl-tmp/sd-scripts/lora/style_of_H._R._Giger_FLUX_295-000006.safetensors;1.0" 

Result is as follows. To test whether the original behavior is recovered after restoring, I call generate_image before applying the LoRA, after pre_calculation, and after restore_weights, so with two LoRAs there are 2 * (1 + 1 + 1) = 6 images:

(six generated images)

You can see this works.

Finally, I should note that backup and restore may require dtype conversion, since the model weights are fp8_e4m3fn, which does not support many operations on CUDA (this caused an error in my experiment until I added the dtype conversion).
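
For example, adding two fp8 tensors directly usually fails, so the merge goes through fp32 first (a minimal sketch; exact op support depends on the PyTorch version and GPU):

    import torch

    w = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)      # base weight stored as fp8
    delta = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)  # merged LoRA delta as fp8

    # w + delta  # typically raises: the add op is not implemented for float8_e4m3fn

    # the workaround used above: upcast to fp32, merge, then cast back to the original dtype
    merged = (w.float() + delta.float()).to(w.dtype)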

Also, this backup-and-restore approach is not fast when switching LoRAs. pre_calculation is really designed for subsequent batch inference; for switching, the first way, applying LoRA by its definition $$y = W x + B A x$$, is logically the better choice. I am still going to try implementing that one, so I may follow up here.

CHR-ray — Jul 29 '25 09:07

I implemented method one by replacing

https://github.com/kohya-ss/sd-scripts/blob/498705770109e0823a465fc6872c691136b3202a/flux_minimal_inference.py#L475-L576

with the following:

    original_apply_to = lora_flux.LoRAModule.apply_to

    def patched_apply_to(self):
        # swap in the LoRA forward, but keep a reference to the original module
        # so that the hook can be undone later
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        self.org_module_ref_for_patch = self.org_module
        del self.org_module

    def patched_remove(self):
        # restore the original forward and drop the extra references
        if hasattr(self, 'org_forward') and hasattr(self, 'org_module_ref_for_patch'):
            self.org_module_ref_for_patch.forward = self.org_forward
            del self.org_forward
            del self.org_module_ref_for_patch

    lora_flux.LoRAModule.apply_to = patched_apply_to
    lora_flux.LoRAModule.remove = patched_remove

    def network_remove(self):
        # detach every LoRA module from the text encoders and the Flux model
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.remove()
    lora_flux.LoRANetwork.remove = network_remove

    # LoRA
    lora_models: List[lora_flux.LoRANetwork] = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        is_lora = is_oft = False
        for key in weights_sd.keys():
            if key.startswith("lora"):
                is_lora = True
            if key.startswith("oft"):
                is_oft = True
            if is_lora or is_oft:
                break

        module = lora_flux if is_lora else oft_flux
        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )

        lora_model.apply_to([clip_l, t5xxl], model)
        info = lora_model.load_state_dict(weights_sd, strict=False)
        logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
        lora_model.eval()
        lora_model.to(device)

        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )
        lora_model.remove()

        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )

    logger.info("Done!")

The bash command is similar.

The results are as follows, again 2 * (1 + 1 + 1) = 6 images (2 LoRAs; for each, one image before apply, one after apply, and one after remove). You can see that it works:

(six generated images)

Finally, note that this is a workaround: LoRAModule replaces the forward method of the original module, so I use a monkey patch to make that reversible. To do so, I had to call lora_model.load_state_dict(weights_sd, strict=False).

Because the patch keeps the original module under another attribute name, it stays registered as a submodule and its keys appear in the network's state_dict, so load_state_dict with strict=True raises an error about missing keys.
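
A toy illustration of why the extra attribute breaks strict loading (not sd-scripts code; Wrapper just stands in for a LoRA module):

    import torch
    import torch.nn as nn

    class Wrapper(nn.Module):
        def __init__(self, target: nn.Module):
            super().__init__()
            self.lora_down = nn.Linear(8, 2, bias=False)
            # assigning a module to an attribute registers it as a submodule,
            # so its parameters also show up in this wrapper's state_dict
            self.org_module_ref_for_patch = target

    w = Wrapper(nn.Linear(8, 8, bias=False))
    print(list(w.state_dict().keys()))
    # ['lora_down.weight', 'org_module_ref_for_patch.weight']

    ckpt = {"lora_down.weight": torch.zeros(2, 8)}  # a LoRA-only checkpoint
    # with strict=True this raises "Missing key(s) ... org_module_ref_for_patch.weight"
    w.load_state_dict(ckpt, strict=False)           # strict=False loads only the matching keys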

Not an elegant solution, but this implementation is much faster than the previous one, since there is no dtype conversion and no moving of weights between CPU and CUDA.

But if you want to do batched inference, I still recommend the previous implementation with pre_calculation, since it is better when the model does not need to change across many inference runs.

CHR-ray — Jul 29 '25 10:07