classifier-free-guidance-pytorch
Added code so the embeddings are actually dropped
At the moment, `TextEmbeddingReturner` just returns a mask in which, with a probability given by `cond_drop_prob`, some values are set to false.
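
For illustration, here is a minimal sketch of that behavior (the helper `prob_mask_like`, the shapes, and the drop probability are assumptions for this example, not necessarily the library's exact internals):

```python
import torch

def prob_mask_like(shape, prob, device=None):
    # True = keep the conditioning, False = drop it;
    # each position is dropped independently with probability `prob`
    return torch.rand(shape, device=device) > prob

batch, seq_len, dim = 2, 8, 512
text_embeds = torch.randn(batch, seq_len, dim)

# drop each sample's conditioning with cond_drop_prob = 0.25
keep_mask = prob_mask_like((batch, 1), prob=0.25).expand(batch, seq_len)

# the mask is returned alongside the embeddings, but the
# embeddings themselves are left untouched
```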
This works fine if the model you are working with respects the mask. However, x-transformers attention does not, since it takes the raw context and feeds it straight into the k and v linear layers:
https://github.com/lucidrains/x-transformers/blob/0c6266ee44ea99a4449cd9201ba55924a6a7eae7/x_transformers/x_transformers.py#L944
```python
kv_input = default(context, x)

q_input = x
k_input = kv_input
v_input = kv_input
r_input = x

q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input) if exists(self.to_v) else k
r = self.to_r(r_input) if exists(self.to_r) else None
```
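
A hedged sketch of the change suggested here: apply the mask to the embeddings themselves, so a consumer that ignores the mask (like the cross-attention above) still receives zeroed-out conditioning. The name `drop_embeddings` and its signature are illustrative, not the PR's actual code:

```python
import torch

def drop_embeddings(text_embeds, keep_mask):
    # text_embeds: (batch, seq_len, dim), keep_mask: (batch, seq_len)
    # zero out the embeddings at dropped positions so downstream layers
    # that feed the raw context into to_k / to_v still see null conditioning
    return text_embeds * keep_mask.unsqueeze(-1)
```

With this, the embeddings are actually dropped regardless of whether the attention layer honors the mask.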