vector-quantize-pytorch
masking compatible with fullgraph compile
This adds some slightly confusing masking code, but improves speed by 3x by making the shapes of intermediate tensors non-dynamic. The masked_mean code is equivalent, up to floating point precision, to the old code that used tensor indexing
Before, using LFQ with masking was not compatible with torch.compile with fullgraph=True or with dynamic=False. It was compatible with plain torch.compile, but the masked tensor indexing caused graph breaks
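To illustrate the shape issue (a toy sketch, not the actual LFQ code): boolean indexing produces a tensor whose length depends on the mask contents, so the compiled graph has to specialize or break, whereas a mask-weighted mean keeps every intermediate shape static.

```python
import torch

x = torch.randn(2, 16, 32)          # (batch, seq, dim) — placeholder tensor
mask = torch.rand(2, 16) > 0.5      # True at valid (non-padded) positions

# dynamic: x[mask] has shape (num_valid, 32), which depends on the data -> graph break under fullgraph=True
mean_dynamic = x[mask].mean()

# static: zero out padded positions and divide by the count of kept elements -> shapes never change
num_valid = mask.sum().clamp(min = 1)
mean_static = (x * mask[..., None]).sum() / (num_valid * x.shape[-1])

assert torch.allclose(mean_dynamic, mean_static, atol = 1e-6)
```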
I added an example that uses masked sequences, to make sure it works properly
I ran a benchmark on a 3090 GPU, using the example code that uses masking:
- the previous masked LFQ implementation, using torch.compile(model, fullgraph=False, mode='max-autotune'), had an average model.forward time of 1.18 milliseconds
- with this commit, using torch.compile(model, fullgraph=True, mode='max-autotune'), the average time is 0.40 milliseconds
The speedup might be worth the extra complexity in the code
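For reference, the timing loop was roughly of this shape (the LFQ settings, iteration counts, and the mask kwarg here are placeholders, not the exact benchmark script):

```python
import torch
from vector_quantize_pytorch import LFQ

model = LFQ(codebook_size = 2 ** 8, dim = 8).cuda()
x = torch.randn(4, 1024, 8, device = 'cuda')
mask = torch.rand(4, 1024, device = 'cuda') < 0.9   # simulate padded positions

compiled = torch.compile(model, fullgraph = True, mode = 'max-autotune')

# warmup so compilation / autotuning is excluded from the measurement
for _ in range(10):
    compiled(x, mask = mask)
torch.cuda.synchronize()

start, end = (torch.cuda.Event(enable_timing = True) for _ in range(2))
start.record()
for _ in range(100):
    compiled(x, mask = mask)
end.record()
torch.cuda.synchronize()

print('avg forward ms:', start.elapsed_time(end) / 100)
```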
ah yea, that does look a bit confusing, needs a tiny bit more work
do you think you can try fitting all the logic into one function, masked_mean, where if mask is None, it simply takes a regular .mean()?
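something roughly along these lines is what I'm picturing (just a sketch, details up to you):

```python
import torch

def masked_mean(t, mask = None):
    # no mask -> plain mean over everything
    if mask is None:
        return t.mean()

    # mask is (batch, seq); broadcast it over any trailing dims of t
    while mask.ndim < t.ndim:
        mask = mask.unsqueeze(-1)

    # zero out padded positions, then divide by the number of elements actually kept
    t = t.masked_fill(~mask, 0.)
    den = mask.expand_as(t).sum().clamp(min = 1)
    return t.sum() / den
```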
we can reassess after your refactor
@theAdamColton have you tried the updated LFQ? curious how you got good results on the previous broken one
With the previous LFQ I set the entropy loss and commit loss to very low weights, and it did actually work.
I've also been experimenting with the entropy loss from maskgit; it's computed slightly differently than the current lfq code here, and it seems to work pretty well
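From memory (so the details may differ from the actual maskgit implementation), the loss there roughly minimizes per-sample entropy so each input commits to a code, while maximizing the entropy of the averaged code distribution so the codebook gets used evenly, with the log probs coming from log_softmax:

```python
import torch
import torch.nn.functional as F

def maskgit_style_entropy_loss(distance, inv_temperature = 1., eps = 1e-5):
    # distance: (..., codebook_size) distances from each input to each code
    logits = -distance * inv_temperature
    prob = logits.softmax(dim = -1)
    log_prob = F.log_softmax(logits, dim = -1)

    # low per-sample entropy -> confident assignment of each input to one code
    per_sample_entropy = -(prob * log_prob).sum(dim = -1).mean()

    # high entropy of the averaged distribution -> codes are used evenly overall
    avg_prob = prob.reshape(-1, prob.shape[-1]).mean(dim = 0)
    codebook_entropy = -(avg_prob * torch.log(avg_prob + eps)).sum()

    return per_sample_entropy - codebook_entropy
```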
Also, this is a different issue, but I think here, where the entropy is computed, maybe it should use F.log_softmax to compute the log probs directly from the distances, instead of taking the log of the probs.
@theAdamColton how is that different? can you show me in code?
@lucidrains for example, instead of
```python
prob = (-distance * inv_temperature).softmax(dim = -1)
per_sample_entropy = (-prob * log(prob)).sum(dim = -1).mean()
```
this is what I mean:
```python
prob = (-distance * inv_temperature).softmax(dim = -1)
log_prob = F.log_softmax(-distance * inv_temperature, dim = -1)
per_sample_entropy = (-prob * log_prob).sum(dim = -1).mean()
```
I don't know if it would make a difference, but it's what the maskgit code does. Using log_softmax might fix precision issues
from the pytorch log_softmax doc "While mathematically equivalent to log(softmax(x)), doing these two operations separately is slower and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly."
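A toy illustration of what that docs quote is warning about (not from the actual code): with large distances, softmax underflows to exactly zero for the unlikely codes, so log(prob) becomes -inf and the entropy turns into nan, while the log_softmax version stays finite.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[200., 0., -200.]])  # exaggerated logits to force underflow

prob = logits.softmax(dim = -1)             # tail probabilities underflow to exactly 0
naive = (-prob * prob.log()).sum(dim = -1)                        # nan: 0 * log(0)
stable = (-prob * F.log_softmax(logits, dim = -1)).sum(dim = -1)  # finite (≈ 0)

print(naive, stable)
```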
I think the numerical stability is accounted for by the epsilon in the log I have in the file, but do let me know otherwise
anyways, I've put in my hours today, happy Saturday! See if you can get that mask to go into the masked mean fn and I'll review it again