progrock-stable

Issues with negative weights

Open JPPhoto opened this issue 2 years ago • 2 comments

If the subprompt weights sum to less than zero, `weight_sum` becomes negative and every weight-adjusted subprompt is negated. I don't think the resulting images make much sense. I think a negative `weight_sum` should raise an error instead.

JPPhoto • Sep 27 '22
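A minimal sketch of the problem, assuming plain `(prompt, weight)` tuples in place of the regex matches that `prs.py` parses. `naive_normalize` follows the pre-patch normalization (the `-` lines of the diff in the next comment), and `checked_normalize` is a hypothetical variant that raises an error on a negative total, as suggested above. Both names are illustrative, not part of `prs.py`.

```python
from typing import List, Tuple

WeightedPrompt = Tuple[str, float]


def naive_normalize(parsed_prompts: List[WeightedPrompt]) -> List[WeightedPrompt]:
    """Divide every weight by the total, as the pre-patch prs.py does."""
    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        # Pre-patch fallback: discard the weights and use equal ones.
        equal_weight = 1 / (len(parsed_prompts) or 1)
        return [(prompt, equal_weight) for prompt, _ in parsed_prompts]
    # A negative weight_sum flips the sign of every normalized weight.
    return [(prompt, weight / weight_sum) for prompt, weight in parsed_prompts]


def checked_normalize(parsed_prompts: List[WeightedPrompt]) -> List[WeightedPrompt]:
    """Hypothetical variant: reject a negative total instead of negating."""
    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum < 0:
        raise ValueError(f"subprompt weights sum to {weight_sum}, which is negative")
    return naive_normalize(parsed_prompts)


prompts = [("a cat", 1.0), ("a dog", -2.0)]
print(naive_normalize(prompts))  # [('a cat', -1.0), ('a dog', 2.0)]
```

With weights `1` and `-2` the total is `-1`, so "a cat" ends up with weight `-1.0` and "a dog" with `2.0`: the subprompt meant to be emphasized becomes the one that is subtracted.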

I've also noticed that negative prompt weights can produce images that look like noise or show other strange artifacts. Since I'm useless at GitHub, I'm posting my patch here:

```diff
--- prs.py      2022-10-01 14:33:20.332094400 +0000
+++ prs-jsp.py  2022-10-01 12:55:38.413053000 +0000
@@ -150,12 +150,15 @@
     parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)]
     if not normalize:
         return parsed_prompts
-    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
-    if weight_sum == 0:
-        print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
-        equal_weight = 1 / (len(parsed_prompts) or 1)
-        return [(x[0], equal_weight) for x in parsed_prompts]
-    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+    positive_weight_sum = sum(map(lambda px: px[1] if (px[1] > 0) else 0, parsed_prompts))
+    negative_weight_sum = -sum(map(lambda nx: nx[1] if (nx[1] < 0) else 0, parsed_prompts))
+    if positive_weight_sum == 0:
+        print("Warning: Positive subprompt weights add up to zero. Discarding and using even weights instead.")
+        positive_weight_sum = 1
+    if negative_weight_sum == 0:
+        print("Warning: Negative subprompt weights add up to zero or there are no negative subprompts. Discarding and using even weights instead.")
+        negative_weight_sum = 1
+    return [(x[0], (x[1] / positive_weight_sum) if (x[1] > 0) else (x[1] / negative_weight_sum)) for x in parsed_prompts]
 
 prompt_parser = re.compile("""
     (?P<prompt>     # capture group for 'prompt'
@@ -287,11 +290,18 @@
                             if len(weighted_subprompts) > 1:
                                 c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
                                 for i in range(0, len(weighted_subprompts)):
-                                    # note if alpha negative, it functions same as torch.sub
-                                    c = torch.add(c, model.get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
+                                    if weighted_subprompts[i][1] < 0:
+                                        uc = torch.zeros_like(uc)
+                                        break
+                                for i in range(0, len(weighted_subprompts)):
+                                    tensor = model.get_learned_conditioning(weighted_subprompts[i][0])
+                                    if weighted_subprompts[i][1] > 0:
+                                        c = torch.add(c, tensor, alpha=weighted_subprompts[i][1])
+                                    else:
+                                        uc = torch.add(uc, tensor, alpha=-weighted_subprompts[i][1])
                             else: # just behave like usual
                                 c = model.get_learned_conditioning(prompts)
-                            
+
                             if opt.variance != 0.0:
                                 # add a little extra random noise to get varying output with same seed
                                 base_x = og_start_code # torch.randn(rand_size, device=device) * sigmas[0]
```
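The first hunk of the patch normalizes positive and negative weights separately instead of dividing everything by one shared sum. A standalone sketch of that logic, assuming plain `(prompt, weight)` tuples; the function name `split_normalize` is hypothetical, and the patch's warning messages are left out.

```python
from typing import List, Tuple

WeightedPrompt = Tuple[str, float]


def split_normalize(parsed_prompts: List[WeightedPrompt]) -> List[WeightedPrompt]:
    """Scale positive and negative weights separately, following the patched prs.py.

    Positive weights are divided by their own sum and negative weights by the
    magnitude of theirs, so each non-empty group totals +1 or -1 respectively.
    """
    positive_weight_sum = sum(w for _, w in parsed_prompts if w > 0)
    negative_weight_sum = -sum(w for _, w in parsed_prompts if w < 0)
    # An empty group falls back to a divisor of 1, as in the patch.
    if positive_weight_sum == 0:
        positive_weight_sum = 1
    if negative_weight_sum == 0:
        negative_weight_sum = 1
    return [
        (prompt, w / positive_weight_sum if w > 0 else w / negative_weight_sum)
        for prompt, w in parsed_prompts
    ]


print(split_normalize([("a cat", 1.0), ("a dog", -2.0)]))
# [('a cat', 1.0), ('a dog', -1.0)]
```

Unlike dividing by the overall sum, a large negative weight can no longer flip the signs of the positive subprompts. As written, the patch prints its negative-sum warning for every prompt without negative subprompts (the common case), and both warnings mention "using even weights" although the code only substitutes a divisor of 1.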

JPPhoto • Oct 01 '22
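The second hunk of the patch changes where each subprompt's embedding goes: positive weights accumulate into `c`, and if any weight is negative, `uc` is reset to zeros and the negative subprompts are accumulated into it with their sign flipped. Assuming `uc` plays its usual Stable Diffusion role as the unconditional embedding for classifier-free guidance, this makes negative subprompts something the sampler steers away from rather than a term subtracted from `c`. The framework-free sketch below uses plain Python lists in place of tensors; `add_scaled` mirrors `torch.add(input, other, alpha=...)`, and `embed` is a hypothetical stand-in for `model.get_learned_conditioning`.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def add_scaled(acc: Sequence[float], vec: Sequence[float], alpha: float) -> Vector:
    """Return acc + alpha * vec elementwise, like torch.add(acc, vec, alpha=alpha)."""
    return [a + alpha * v for a, v in zip(acc, vec)]


def route_conditioning(
    weighted_subprompts: List[Tuple[str, float]],
    embed: Callable[[str], Vector],
    uc: Vector,
) -> Tuple[Vector, Vector]:
    """Split weighted subprompts into conditional (c) and unconditional (uc) embeddings.

    `embed` is a hypothetical stand-in for model.get_learned_conditioning.
    """
    c = [0.0] * len(uc)
    # Any negative weight replaces the default unconditional embedding with zeros.
    if any(weight < 0 for _, weight in weighted_subprompts):
        uc = [0.0] * len(uc)
    for prompt, weight in weighted_subprompts:
        vec = embed(prompt)
        if weight > 0:
            c = add_scaled(c, vec, weight)
        else:
            # Negative (or zero) weights are added to uc with their sign flipped.
            uc = add_scaled(uc, vec, -weight)
    return c, uc


# Toy two-dimensional "embeddings" for illustration only.
toy_embeddings = {"a cat": [1.0, 0.0], "a dog": [0.0, 1.0], "blurry": [2.0, 2.0]}
c, uc = route_conditioning(
    [("a cat", 0.75), ("a dog", 0.25), ("blurry", -1.0)],
    toy_embeddings.__getitem__,
    uc=[0.5, 0.5],
)
print(c, uc)  # [0.75, 0.25] [2.0, 2.0]
```

When no weight is negative, `uc` is returned unchanged, so prompts without negative subprompts keep the default unconditional embedding.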

See #56

JPPhoto • Oct 13 '22