Tri Dao

429 comments by Tri Dao

You can try FA3 too, which runs on A100 now. Btw, the Triton bwd does not support causal=False: when you call it with causal=False, it still runs with causal=True. You can...

Most likely register spilling, if I had to guess. You can try smaller block sizes to see if that helps.

@KimmiShi Can you post a short script to reproduce the error? Something like:
```
# Construct DropoutAddRMSNorm module
# Generate q
# Pass q to the module, get error
```
...
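For reference, a minimal repro sketch along those lines. The import path and constructor arguments for DropoutAddRMSNorm are assumptions here; check the actual flash-attn API, and match the shapes/dtype/strides of your failing run:

```python
import torch
# Assumed import path and constructor arguments -- verify against the
# flash-attn source before running.
from flash_attn.ops.rms_norm import DropoutAddRMSNorm

# Construct DropoutAddRMSNorm module (hidden size matching q's last dim)
norm = DropoutAddRMSNorm(128, p=0.0, eps=1e-6).to(device="cuda", dtype=torch.float16)

# Generate q with the same shape / dtype / layout as in the failing run
q = torch.randn(2, 8, 512, 128, device="cuda", dtype=torch.float16)

# Pass q to the module; ideally this call reproduces the error
out = norm(q)
print(out.shape)
```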

It requires the last dimension to be a multiple of 8, as mentioned in the README. We do call `.contiguous()` and check that the dimension is divisible by 8. Maybe there's some...

If you can print out more info (shape, stride, dtype) about the input to DropoutAddRMSNorm, that would also help me reproduce the error. E.g., before self.q_norm:
```
input = q.transpose(1,
```
...
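Something along these lines, right before the norm call. The helper name and the transpose below are only illustrative; print whatever tensor is actually passed to self.q_norm:

```python
import torch

def debug_tensor(name, t):
    # Print the layout details needed to reproduce an alignment / shape issue.
    print(f"{name}: shape={tuple(t.shape)} stride={t.stride()} dtype={t.dtype} "
          f"contiguous={t.is_contiguous()} data_ptr%16={t.data_ptr() % 16}")

# Example with a stand-in q; in the real code, call debug_tensor on the exact
# tensor that gets fed to self.q_norm (e.g. q.transpose(1, 2)).
q = torch.randn(2, 512, 8, 128, dtype=torch.float16)
debug_tensor("q_norm input", q.transpose(1, 2))
```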

Thanks for the repro script; I've narrowed it down to a memory alignment problem. We expect all input tensors to be aligned to 16 bytes (in order to use vectorized...
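A quick way to check the alignment of a given tensor: `data_ptr()` is the byte address of the first element, so a remainder of 0 modulo 16 means the tensor starts on a 16-byte boundary.

```python
import torch

def is_16b_aligned(t: torch.Tensor) -> bool:
    # data_ptr() is the byte address of the first element.
    return t.data_ptr() % 16 == 0

x = torch.randn(4, 128, dtype=torch.float16)
print(is_16b_aligned(x))         # freshly allocated tensors are normally aligned
print(is_16b_aligned(x[:, 1:]))  # a sliced view can start 2 bytes into the storage
```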

I don't know a reliable way to get 16-byte alignment, but I've posted a [question](https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440) to the PyTorch forum.

I pushed a commit to (hopefully) make sure that memory addresses are aligned to 16 bytes by cloning the inputs.
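Not the actual commit, but the general pattern looks roughly like this: clone into freshly allocated (and therefore aligned) memory only when the input is misaligned.

```python
import torch

def ensure_16b_aligned(t: torch.Tensor) -> torch.Tensor:
    # Return t unchanged if its first element sits on a 16-byte boundary;
    # otherwise copy it into freshly allocated memory, which PyTorch's
    # allocators align well past 16 bytes.
    if t.data_ptr() % 16 == 0:
        return t
    return t.clone()

x = torch.randn(4, 128, dtype=torch.float16)[:, 1:]  # misaligned view
y = ensure_16b_aligned(x)
print(x.data_ptr() % 16, y.data_ptr() % 16)  # e.g. 2 and 0
```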

Yes, I can reproduce it. I don't have the bandwidth right now to debug it. I'm not familiar with DeepSpeed; I suspect it puts all parameters in a buffer and...