Tri Dao

Results: 640 comments of Tri Dao

Clarifying my understanding: if q_stage == 1, is there overlap between softmax and mma?

Sure, would love to see contributions there

As mentioned, in general it's not a good idea to use equality to compare floating-point values.

```
In [11]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')

In [12]: torch.equal(a + 0.3 -...
```
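The truncated snippet above illustrates the general pitfall. A minimal pure-Python sketch of the same issue (the tensor analogue would be `torch.allclose` instead of `torch.equal`):

```python
import math

# 0.1 + 0.2 is not exactly 0.3 in binary floating point,
# so exact equality fails even though the values are "the same"
print(0.1 + 0.2 == 0.3)                 # False
print(math.isclose(0.1 + 0.2, 0.3))    # True: compare with a tolerance
```

The same reasoning applies to tensors: rounding in bfloat16 makes `a + 0.3 - 0.3` differ bitwise from `a`, so a tolerance-based comparison is the right tool.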

There are 2 code paths, one for local and one for causal. There's no guarantee that they produce identical outputs.

There's some code to detect that some local window size is equivalent to causal and run the causal path instead (since causal is faster). That's probably causing what you're observing....
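A minimal sketch of why such a detection is possible, using hypothetical mask-building helpers (not the library's actual code): a sliding window whose left extent covers the whole sequence, with no right extent, produces exactly the causal mask.

```python
def causal_mask(n):
    # query i may attend to key j iff j <= i
    return [[j <= i for j in range(n)] for i in range(n)]

def local_mask(n, left, right):
    # sliding window: query i may attend to key j iff i - left <= j <= i + right
    return [[i - left <= j <= i + right for j in range(n)] for i in range(n)]

n = 8
# a window of (n - 1, 0) admits every earlier key and no later key,
# i.e. it is the causal mask, so the faster causal path can be taken
assert local_mask(n, left=n - 1, right=0) == causal_mask(n)
```

Because the two kernels accumulate in different orders, the detected-causal path and the generic local path need not be bitwise identical even when the masks coincide.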

We've just updated them

Yes, for varlen the kernel will read V beyond NTOKENS (it reads blocks of e.g. 128 tokens at a time). Typically this is ok because we then multiply...
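A small sketch of why the out-of-range reads are harmless under the usual masking scheme (assuming, as is standard, that scores past the valid length are set to -inf before the softmax): those positions get probability exactly 0, so whatever V values the kernel read there are multiplied by zero.

```python
import math

def softmax(scores):
    # numerically stable softmax: subtract the running max first
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [x / z for x in e]

# two valid tokens, two positions past NTOKENS masked to -inf
scores = [0.5, 1.2, float('-inf'), float('-inf')]
p = softmax(scores)
# math.exp(-inf) == 0.0, so the garbage V rows contribute nothing
assert p[2] == 0.0 and p[3] == 0.0
assert abs(sum(p) - 1.0) < 1e-12
```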

Please don't use time.time() to measure time: CUDA operations are asynchronous, so the call returns before the kernel finishes. You can use torch benchmark instead. https://pytorch.org/tutorials/recipes/recipes/benchmark.html
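A minimal sketch of the linked recipe (shown here on a CPU matmul; the same code works for CUDA ops, where the Timer inserts the necessary synchronization that a bare time.time() pair would miss):

```python
import torch
from torch.utils.benchmark import Timer

# Timer handles warmup, repeats, and (on GPU) CUDA synchronization
t = Timer(
    stmt="a @ a",
    setup="a = torch.randn(256, 256)",
    globals={"torch": torch},
)
m = t.timeit(50)
print(m.median)  # median seconds per invocation
```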