
Allow Lora models to be injected by monkey patching via script

Open · ExponentialML opened this issue 3 years ago · 3 comments

WHAT

Lora models should be able to be injected via a script instead of creating an entirely new model.

HOW

Unless I'm wrong here, the Lora weights are applied to the unet at inference time. The same should be doable via a custom script: patch the weights into the model before p goes to inference. Using the Lora weights this way would save a lot of time and add usability, because currently you have to compile the Lora weights into the model every time.

The functionality would be akin to hypernetworks.

monkeypatch_lora(p.unet, torch.load("./lora.pt"))
monkeypatch_lora(p.text_encoder, torch.load("./lora.pt"), target_replace_module=["CLIPAttention"])
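In a plain diffusers pipeline the whole thing looks roughly like this (a sketch based on the lora_diffusion README; the model id, file names, and the tune_lora_scale call are just illustrative, not this extension's API):

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

# Placeholder model id and LoRA file names.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Patch the unet and (optionally) the text encoder in place.
monkeypatch_lora(pipe.unet, torch.load("./lora.pt"))
monkeypatch_lora(
    pipe.text_encoder,
    torch.load("./lora.text_encoder.pt"),
    target_replace_module=["CLIPAttention"],
)

# Adjust the LoRA strength without reloading or re-merging anything.
tune_lora_scale(pipe.unet, 0.8)

image = pipe("a photo of sks person", num_inference_steps=30).images[0]

No conversion step, no new checkpoint on disk.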

I would rather not create a PR for this because you seem to be pretty fast at adding stuff :slightly_smiling_face:.

ExponentialML avatar Dec 12 '22 22:12 ExponentialML

What you're describing above is what the "weight_lora_apply" function does. ;)


d8ahazard avatar Dec 12 '22 23:12 d8ahazard

The problem is, regular inference via dreambooth with diffusers isn't possible, because dreambooth needs a ckpt. So, in order to use lora + model, we need to drop the lora weights into the unet/text encoder, save them as pretrained, then compile/convert.
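Roughly this, as a sketch of that flow (merge_lora_into is an illustrative helper, not the extension's actual weight_lora_apply, and the paths are placeholders):

import torch
from diffusers import StableDiffusionPipeline

def merge_lora_into(model, loras, target_replace_module, scale=1.0):
    # loras is the flat list of alternating up/down tensors that
    # lora_diffusion saves in lora.pt (same layout the monkeypatch uses).
    for module in model.modules():
        if module.__class__.__name__ in target_replace_module:
            for child in module.modules():
                if child.__class__.__name__ == "Linear":
                    up = loras.pop(0).to(child.weight)
                    down = loras.pop(0).to(child.weight)
                    # W <- W + scale * (up @ down), the standard LoRA merge.
                    child.weight.data += scale * (up @ down)

pipe = StableDiffusionPipeline.from_pretrained("path/to/dreambooth_model")

# 1. Drop the LoRA weights into the unet (and text encoder, if trained).
merge_lora_into(pipe.unet, list(torch.load("./lora.pt")),
                ["CrossAttention", "Attention"])

# 2. Save the merged pipeline back out as a diffusers model.
pipe.save_pretrained("path/to/merged_model")

# 3. Convert that folder to a .ckpt the webui can load, e.g. with
#    diffusers' convert_diffusers_to_original_stable_diffusion.py script.

That last conversion step is what the monkeypatch approach would avoid.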

d8ahazard avatar Dec 12 '22 23:12 d8ahazard

> The problem is, regular inference via dreambooth with diffusers isn't possible, because dreambooth needs a ckpt. So, in order to use lora + model, we need to drop the lora weights into the unet/text encoder, save them as pretrained, then compile/convert.

def monkeypatch_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"]
):
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "Linear":

                    weight = _child_module.weight
                    bias = _child_module.bias
                    _tmp = LoraInjectedLinear(
                        _child_module.in_features,
                        _child_module.out_features,
                        _child_module.bias is not None,
                    )
                    _tmp.linear.weight = weight

                    if bias is not None:
                        _tmp.linear.bias = bias

                    # switch the module
                    _module._modules[name] = _tmp

                    up_weight = loras.pop(0)
                    down_weight = loras.pop(0)

                    _module._modules[name].lora_up.weight = nn.Parameter(
                        up_weight.type(weight.dtype)
                    )
                    _module._modules[name].lora_down.weight = nn.Parameter(
                        down_weight.type(weight.dtype)
                    )

                    _module._modules[name].to(weight.device)

https://github.com/cloneofsimo/lora/blob/b3a9bd0cf4a1d68bf461e4ad2ad150eda6deeb29/lora_diffusion/lora.py#L131

Would it not be possible to run a checkpoint's unet through this function when running a text-to-image prompt?
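Purely as a sketch of what I mean on the webui side (the script hooks, attribute path, and class name below are my assumptions and untested, and lora.pt is saved in the diffusers unet's layer order, so it may not line up one-to-one with the ldm unet):

import torch
from modules import scripts
from lora_diffusion import monkeypatch_lora

class LoraInjectScript(scripts.Script):
    def title(self):
        return "Inject LoRA before inference"

    def show(self, is_img2img):
        # Run on every generation rather than as a selectable script.
        return scripts.AlwaysVisible

    def process(self, p):
        # Assumed attribute path to the loaded ldm UNetModel.
        unet = p.sd_model.model.diffusion_model
        monkeypatch_lora(unet, torch.load("./lora.pt"),
                         target_replace_module=["CrossAttention"])

A real version would also need to unpatch (or remember it already patched) so repeated generations don't stack the injected layers, but that's the general idea.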

ExponentialML avatar Dec 13 '22 01:12 ExponentialML

This issue is stale because it has been open 5 days with no activity. Remove the stale label or comment, or this will be closed in 5 days.

github-actions[bot] avatar Dec 19 '22 00:12 github-actions[bot]