
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'

R-N opened this issue 3 years ago · 5 comments

Error:

/content/merger
Using half precision

    ---------------------
         ITERATION 1
    ---------------------
    
new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'
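Reading the failing line: the comprehension filters on "key in theta_1", which is always true when iterating theta_1.items(), but it then indexes theta_0[key], so the KeyError means position_ids exists in theta_1 but not in theta_0. A minimal defensive sketch (untested) that guards on the dict actually being indexed would avoid the crash, at the cost of silently leaving any unmatched keys unmerged:

# Untested sketch: guard on theta_0 (the dict being indexed) instead of
# theta_1, so keys missing from theta_0 are skipped rather than crashing.
theta_0 = {
    key: (1 - new_alpha) * theta_0[key] + new_alpha * value
    for key, value in theta_1.items()
    if "model" in key and key in theta_0
}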

Code:

!pip install pytorch-lightning torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

!git clone https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion merger
%cd /content/merger

#download models from hf
!curl ...

model_a_path = "novelai.ckpt"
model_b_path = "sd_1.5.ckpt"
output_name = f"NAI_f222_0.45ws.ckpt"
alpha = 0.45
device = "cuda"

%cd /content/merger
!python SD_rebasin_merge.py --model_a {model_a_path} --model_b {model_b_path} --output {output_name} --alpha {alpha} --device {device}

RAM usage: [screenshot]

R-N · Feb 23 '23 04:02

I've tried adding it to the skipped layers as mentioned here. That's in weight_matching.py, right?

  return permutation_spec_from_axes_to_perm({
     #Skipped Layers
     **skip("betas", None, None),
     **skip("alphas_cumprod", None, None),
     **skip("alphas_cumprod_prev", None, None),
     **skip("sqrt_alphas_cumprod", None, None),
     **skip("sqrt_one_minus_alphas_cumprod", None, None),
     **skip("log_one_minus_alphas_cumprods", None, None),
     **skip("sqrt_recip_alphas_cumprod", None, None),
     **skip("sqrt_recipm1_alphas_cumprod", None, None),
     **skip("posterior_variance", None, None),
     **skip("posterior_log_variance_clipped", None, None),
     **skip("posterior_mean_coef1", None, None),
     **skip("posterior_mean_coef2", None, None),
     **skip("log_one_minus_alphas_cumprod", None, None),
     **skip("model_ema.decay", None, None),
     **skip("model_ema.num_updates", None, None),
     **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None),

     #initial 

But I'm still getting the same error:

/content/merger
Using half precision

    ---------------------
         ITERATION 1
    ---------------------
    
new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'

R-N · Feb 23 '23 04:02

I was worried that my changes hadn't applied, so I added a print statement, and it does get printed. So my change did apply.

  print("HELLLOOOO")
  return permutation_spec_from_axes_to_perm({
     #Skipped Layers 
     **skip("betas", None, None),
     **skip("alphas_cumprod", None, None),
     **skip("alphas_cumprod_prev", None, None),
     **skip("sqrt_alphas_cumprod", None, None),
     **skip("sqrt_one_minus_alphas_cumprod", None, None),
     **skip("log_one_minus_alphas_cumprods", None, None),
     **skip("sqrt_recip_alphas_cumprod", None, None),
     **skip("sqrt_recipm1_alphas_cumprod", None, None),
     **skip("posterior_variance", None, None),
     **skip("posterior_log_variance_clipped", None, None),
     **skip("posterior_mean_coef1", None, None),
     **skip("posterior_mean_coef2", None, None),
     **skip("log_one_minus_alphas_cumprod", None, None),
     **skip("model_ema.decay", None, None),
     **skip("model_ema.num_updates", None, None),
     **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None),

     #initial 

Error is still the same:

/content/merger
HELLLOOOO
Using half precision

    ---------------------
         ITERATION 1
    ---------------------
    
new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'

R-N · Feb 23 '23 04:02

Weird. I checked your original weight_matching.py, and that key is already skipped there.

R-N · Feb 23 '23 04:02
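Judging by the tracebacks, the skip list in weight_matching.py only controls which tensors take part in the permutation matching; the KeyError is raised by the plain interpolation in SD_rebasin_merge.py at line 55, which iterates the raw state dicts directly, so editing the skip list cannot prevent it. A hypothetical debugging aid (untested), placed just before that line, would make the mismatch visible:

# Untested: report keys present in theta_1 but absent from theta_0
# before the interpolation, to see exactly what is missing.
missing = [k for k in theta_1 if "model" in k and k not in theta_0]
if missing:
    print(f"{len(missing)} keys in theta_1 are missing from theta_0:")
    for k in missing:
        print(" ", k)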

One of your models might not have the CLIP model embedded. Some versions of Protogen suffer from this, and possibly other models do too.

In the main script near the start, where it loads the models:

model_a = torch.load(args.model_a, map_location=device)
model_b = torch.load(args.model_b, map_location=device)

You could try some code like this (untested) right after those two lines, to copy the CLIP weights across from the other model when one of them is missing them:

# Copy any CLIP (cond_stage_model) weights present in model_a but missing from model_b
for key in model_a['state_dict'].keys():
    if 'cond_stage_model.' in key and key not in model_b['state_dict']:
        model_b['state_dict'][key] = model_a['state_dict'][key].clone().detach()

# ...and the other way around, from model_b into model_a
for key in model_b['state_dict'].keys():
    if 'cond_stage_model.' in key and key not in model_a['state_dict']:
        model_a['state_dict'][key] = model_b['state_dict'][key].clone().detach()
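As a quick sanity check before merging, a sketch along these lines (untested) would show which cond_stage_model keys exist in only one of the two checkpoints, confirming which model is missing its text encoder:

# Untested: list cond_stage_model keys present in exactly one checkpoint.
a_keys = set(model_a['state_dict'])
b_keys = set(model_b['state_dict'])
for key in sorted(a_keys.symmetric_difference(b_keys)):
    if key.startswith('cond_stage_model.'):
        side = 'model_a only' if key in a_keys else 'model_b only'
        print(side, key)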

zwishenzug · Feb 25 '23 20:02

@R-N Did @zwishenzug's suggestion help?

ogkalu2 · Feb 26 '23 20:02