
Added so the embeddings are actually dropped

Open MarcusLoppe opened this issue 8 months ago • 14 comments

At the moment the TextEmbeddingReturner just returns a mask in which, with probability cond_drop_prob, some values are set to False.
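To illustrate the mask-based dropout described above, here is a minimal sketch of per-sample conditional dropping. The helper name `prob_mask_like` and the exact shapes are assumptions for illustration, not the repo's exact code:

```python
import torch

def prob_mask_like(shape, prob, device=None):
    # returns a boolean tensor that is True with probability `prob`
    return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

# hypothetical example: with cond_drop_prob = 0.25, roughly a quarter of the
# batch has its entire conditioning mask set to False ("dropped")
cond_drop_prob = 0.25
batch, seq_len = 4, 16

drop = prob_mask_like((batch,), cond_drop_prob)          # per-sample drop flag
text_mask = torch.ones(batch, seq_len, dtype=torch.bool) # start with all-keep mask
text_mask = text_mask & ~drop.unsqueeze(-1)              # dropped samples -> all False
```

Note that only the mask changes here; the embedding tensor itself is untouched, which is exactly the behavior the issue is about.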

This works well if the model you are working with respects the mask. However, x-transformers attention does not: it takes the raw context and feeds it directly into the k and v linear layers.

https://github.com/lucidrains/x-transformers/blob/0c6266ee44ea99a4449cd9201ba55924a6a7eae7/x_transformers/x_transformers.py#L944

```python
kv_input = default(context, x)

q_input = x
k_input = kv_input
v_input = kv_input
r_input = x

q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input) if exists(self.to_v) else k
r = self.to_r(r_input) if exists(self.to_r) else None
```
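One way to guarantee the drop takes effect regardless of whether downstream attention honours the mask is to zero the context embeddings with the mask before they reach the k/v projections. This is a hypothetical sketch of that idea (the function name and shapes are assumptions, not the repo's actual fix):

```python
import torch

def apply_cond_drop(context: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # context: (batch, seq, dim) text embeddings
    # mask:    (batch, seq) boolean, False where conditioning was dropped
    # Zeroing the masked positions means dropped embeddings contribute
    # nothing to to_k / to_v, even if the attention ignores the mask.
    return context * mask.unsqueeze(-1)

ctx = torch.randn(2, 4, 8)
mask = torch.tensor([[True] * 4, [False] * 4])  # drop the second sample
dropped = apply_cond_drop(ctx, mask)
```

With this, the dropped sample's keys and values are projections of zero vectors (plus any bias), so the text conditioning is actually removed rather than merely flagged.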

MarcusLoppe · Jun 07 '24 19:06