pytorch-transformer
Refactor Causal Mask Generation for Simplicity
I updated the causal_mask function to build the lower triangular matrix directly with torch.tril instead of inverting a mask generated with torch.triu. This is more concise and easier to read, and the function's behavior is unchanged. I hope this helps :)
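
For reviewers, here is a minimal sketch of the equivalence this change relies on. It is not the exact diff: the function names causal_mask_triu and causal_mask_tril, the (1, size, size) shape, and the boolean return type are assumptions made for illustration.

```python
import torch


def causal_mask_triu(size: int) -> torch.Tensor:
    # Assumed "before" style: mark the strictly upper triangle (future
    # positions) with ones, then invert so allowed positions are True.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


def causal_mask_tril(size: int) -> torch.Tensor:
    # Assumed "after" style: build the lower triangle (diagonal included)
    # directly, so no inversion step is needed.
    return torch.tril(torch.ones((1, size, size))).bool()


if __name__ == "__main__":
    # Both variants should agree for any size; 4 is an arbitrary example.
    print(torch.equal(causal_mask_triu(4), causal_mask_tril(4)))
```

In both variants, position (i, j) is True only when j <= i, so each token can attend to itself and to earlier tokens but not to later ones.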