vector-quantize-pytorch

masking compatible with fullgraph compile

theAdamColton opened this issue 1 year ago • 10 comments

This adds some slightly confusing masking code, but it improves speed by 3x by keeping the shapes of intermediate tensors static. The masked_mean code is equivalent, up to floating-point precision, to the old code that used tensor indexing.
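For a rough idea of the approach, here is a minimal sketch of a masked mean that keeps static shapes (illustrative only; the signature and eps are assumptions, not the exact code in this PR):

import torch

def masked_mean(t, mask = None, eps = 1e-5):
    # no mask: just a plain mean
    if mask is None:
        return t.mean()

    # mask: (batch, seq) bool, t: (batch, seq, ...) values
    # unsqueeze the mask so it broadcasts over any trailing dims
    while mask.ndim < t.ndim:
        mask = mask.unsqueeze(-1)

    mask = mask.to(t.dtype)
    # zero out padded positions and divide by the count of kept elements;
    # every intermediate tensor keeps a static shape, so there is no
    # data-dependent indexing to break torch.compile(fullgraph = True)
    num = (t * mask).sum()
    den = mask.expand_as(t).sum().clamp(min = eps)
    return num / den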

Before this change, using LFQ with masking was not compatible with torch.compile with fullgraph=True or dynamic=False. It worked with plain torch.compile, but the masked tensor indexing caused graph breaks.

I added an example that uses masked sequences to make sure it works properly.

I benchmarked by running the example code that uses masking on a 3090 GPU (a rough sketch of the timing loop follows the numbers):

  • the previous masked LFQ implementation, compiled with torch.compile(model, fullgraph=False, mode='max-autotune'), averaged 1.18 milliseconds per model.forward
  • with this commit, compiled with torch.compile(model, fullgraph=True, mode='max-autotune'), the average is 0.40 milliseconds
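The timing was done roughly along these lines (a sketch under assumptions: model is the LFQ autoencoder from the masking example, and x / mask are a padded batch plus its boolean mask; the call signature is illustrative):

import torch

compiled = torch.compile(model, fullgraph = True, mode = 'max-autotune')

# warmup so compilation and autotuning are excluded from the measurement
for _ in range(10):
    compiled(x, mask = mask)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing = True)
end = torch.cuda.Event(enable_timing = True)

start.record()
for _ in range(100):
    compiled(x, mask = mask)
end.record()
torch.cuda.synchronize()

print(f'avg forward: {start.elapsed_time(end) / 100:.2f} ms')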

The speedup might be worth the added complexity in the code.

theAdamColton avatar Dec 08 '23 18:12 theAdamColton

ah yea, that does look a bit confusing, needs a tiny bit more work

do you think you can try fitting all the logic into one function, masked_mean, where if mask is None, it simply takes a regular .mean()?

lucidrains avatar Dec 09 '23 15:12 lucidrains

we can reassess after your refactor

lucidrains avatar Dec 09 '23 15:12 lucidrains

@theAdamColton have you tried the updated LFQ? curious how you got good results on the previous broken one

lucidrains avatar Dec 09 '23 15:12 lucidrains

With the previous LFQ I set the entropy loss and commit loss to very low weights, and it did actually work.

theAdamColton avatar Dec 09 '23 17:12 theAdamColton

I've also been experimenting with the entropy loss from MaskGIT; it does it slightly differently than the current LFQ code here. The MaskGIT version seems to work pretty well.

theAdamColton avatar Dec 09 '23 17:12 theAdamColton

Also, this is a different issue, but I think here, where the entropy is computed, maybe it should use F.log_softmax to compute the log probs directly from the distances, instead of taking the log of the probs.

theAdamColton avatar Dec 09 '23 17:12 theAdamColton

> Also, this is a different issue, but I think here, where the entropy is computed, maybe it should use F.log_softmax to compute the log probs directly from the distances, instead of taking the log of the probs.

@theAdamColton how is that different? can you show me in code?

lucidrains avatar Dec 09 '23 18:12 lucidrains

@lucidrains for example, instead of

prob = (-distance * inv_temperature).softmax(dim = -1)
per_sample_entropy = (-prob * log(prob)).sum(dim=-1).mean()

this is what I mean:

prob = (-distance * inv_temperature).softmax(dim = -1)
log_prob = F.log_softmax(-distance * inv_temperature, dim = -1)
per_sample_entropy = (-prob * log_prob).sum(dim=-1).mean()

I don't know if it would make a difference, but it's what the MaskGIT code does. Using log_softmax might fix precision issues.

from the PyTorch log_softmax docs: "While mathematically equivalent to log(softmax(x)), doing these two operations separately is slower and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly."
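As a quick illustration of the underflow (made-up logits, just to show the failure mode):

import torch
import torch.nn.functional as F

logits = torch.tensor([0., -1000.])

# exp(-1000) underflows to 0, so the naive log gives -inf
print(torch.log(logits.softmax(dim = -1)))  # tensor([0., -inf])

# log_softmax computes x - logsumexp(x) internally and stays finite
print(F.log_softmax(logits, dim = -1))      # tensor([0., -1000.])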

theAdamColton avatar Dec 09 '23 19:12 theAdamColton

I think the numerical stability is accounted for by the epsilon in the log I have in the file, but do let me know otherwise
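(the epsilon'd log being referred to is presumably something along these lines; this is a guess at the shape of the helper, not necessarily the exact definition or eps value in the file)

import torch

def log(t, eps = 1e-20):
    # clamp before the log so prob == 0 yields log(eps) instead of -inf
    return torch.log(t.clamp(min = eps))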

lucidrains avatar Dec 09 '23 19:12 lucidrains

anyways, I've put in my hours today, happy Saturday! See if you can get that mask to go into the masked mean fn and I'll review it again

lucidrains avatar Dec 09 '23 19:12 lucidrains