Ahmed Elnaggar

41 comments by Ahmed Elnaggar

I have tested the above script on GPUs and it works without any issues. The long compilation process only occurs with TPUs.

Thanks a lot @skye for your explanation and support. I will use the GPUs for now and I hope the TPU issue will be solved in the near future.

+1 Any update on this request?

> > `jax.experimental` has an [implementation](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) of FlashAttention, written in Pallas kernels and therefore usable on both GPU and TPU.
> >
> > We can probably upstream this to Flax attention...

> `jax.experimental` has an [implementation](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) of FlashAttention, written in Pallas kernels and therefore usable on both GPU and TPU.
>
> We can probably upstream this to Flax attention if...

One more point: JAX has two implementations:

1. Fused attention: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py
2. Flash attention: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
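For context, a minimal sketch of calling the Pallas TPU flash-attention kernel might look like the following. The module path matches the link above, but it has moved between JAX releases, and the keyword arguments (`causal`, `sm_scale`) are an assumption to verify against your installed version:

```python
import jax
import jax.numpy as jnp
# Module path as linked above; it has moved across JAX versions, so verify.
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, seq_len, head_dim = 2, 8, 1024, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)

# The Pallas kernel expects (batch, num_heads, seq_len, head_dim) inputs.
q = jax.random.normal(kq, (batch, heads, seq_len, head_dim), jnp.float32)
k = jax.random.normal(kk, (batch, heads, seq_len, head_dim), jnp.float32)
v = jax.random.normal(kv, (batch, heads, seq_len, head_dim), jnp.float32)

# `causal` and `sm_scale` are keyword-only in the versions I have seen.
out = flash_attention(q, k, v, causal=True, sm_scale=1.0 / head_dim**0.5)
print(out.shape)  # (2, 8, 1024, 64)
```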

Thanks a lot @IvyZX for integrating flash attention. I am just afraid that it is still missing some parameters, like dropout, mask, and bias, if I am not mistaken.
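For comparison, the existing `flax.linen.dot_product_attention` already exposes all three of those parameters. A minimal sketch, using the Flax shape convention `(batch, seq_len, heads, head_dim)`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, seq_len, heads, head_dim = 2, 128, 4, 32
kq, kk, kv, kdrop = jax.random.split(jax.random.PRNGKey(0), 4)

# Flax convention: (batch, seq_len, num_heads, head_dim).
q = jax.random.normal(kq, (batch, seq_len, heads, head_dim))
k = jax.random.normal(kk, (batch, seq_len, heads, head_dim))
v = jax.random.normal(kv, (batch, seq_len, heads, head_dim))

# Additive bias (e.g. ALiBi or a T5 relative bias) and a boolean mask,
# both broadcastable to (batch, heads, seq_len, seq_len).
bias = jnp.zeros((1, heads, seq_len, seq_len))
mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=bool))

out = nn.dot_product_attention(
    q, k, v,
    bias=bias,
    mask=mask,
    dropout_rng=kdrop,
    dropout_rate=0.1,
    deterministic=False,
)
print(out.shape)  # (2, 128, 4, 32)
```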

@IvyZX could you please share the link to the current PR?

If I am not mistaken, many state-of-the-art embedding methods require an attention bias, e.g. ALiBi and the relative position encoding used in T5 models.
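To illustrate, here is a minimal sketch of building an ALiBi-style additive bias that could be passed through a `bias` argument like the one above. This is illustrative, not any library's API; the slopes follow the geometric schedule from the ALiBi paper (power-of-two head counts), and the symmetric `-|i - j|` penalty is a common bidirectional variant of the original causal formulation:

```python
import jax.numpy as jnp

def alibi_bias(num_heads: int, seq_len: int) -> jnp.ndarray:
    """Additive ALiBi bias of shape (1, num_heads, seq_len, seq_len).

    Each head penalizes attention to distant positions with its own slope;
    the slopes form a geometric sequence as in the ALiBi paper.
    """
    # Slopes 2^(-8/n), 2^(-16/n), ... for n heads (power-of-two case).
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    pos = jnp.arange(seq_len)
    # Symmetric distance penalty; the original paper uses it causally.
    distance = jnp.abs(pos[None, :] - pos[:, None])   # (seq, seq)
    bias = -slopes[:, None, None] * distance          # (heads, seq, seq)
    return bias[None, ...]                            # (1, heads, seq, seq)

bias = alibi_bias(num_heads=4, seq_len=8)
print(bias.shape)  # (1, 4, 8, 8)
```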

Thanks a lot, @milot-mirdita, for your help. It did work out. However, when I applied `result2profile` and `profile2pssm` to extract the PSSM, it showed protein sequences that were not part...