François Chollet
> Hey, sorry for not finishing this PR, I have a quick question, where should I add the tests? In `keras/src/ops/nn_test.py`. Ops are tested through the op class in e.g....
@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?
Very cool -- @hazemessamm can we do that, e.g. by adding a `flash_attention` argument in `dot_product_attention`? This makes it quite easy to also add support for JAX (in addition...
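A minimal sketch of what the proposed API could look like. This is an illustrative NumPy reference implementation, not the Keras op: the `flash_attention` flag is the hypothetical argument discussed above, and in a real backend it would dispatch to a fused kernel (e.g. torch SDPA or a cuDNN attention call in JAX); here both paths run the same reference math.

```python
import numpy as np

def dot_product_attention(query, key, value, flash_attention=False):
    # Hypothetical sketch: in a real backend, `flash_attention=True`
    # would route to a fused flash-attention kernel. In this reference
    # version the flag changes nothing; both paths compute standard
    # scaled dot-product attention.
    scale = 1.0 / np.sqrt(query.shape[-1])
    scores = np.einsum("...qd,...kd->...qk", query, key) * scale
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("...qk,...kd->...qd", weights, value)

q = np.random.rand(2, 4, 8)
k = np.random.rand(2, 4, 8)
v = np.random.rand(2, 4, 8)
out = dot_product_attention(q, k, v, flash_attention=True)
print(out.shape)  # (2, 4, 8)
```

The appeal of the argument-based design is that both code paths must agree numerically, so the flag can be tested by comparing outputs with it on and off.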
The test fails on torch + GPU: > FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution. Do you know if this is an issue with the torch version?...
> If you think that this is a good idea then I will use this snippet in the flash attention function in PyTorch backend. That sounds great! Then, we can...
Unfortunately this code snippet is not reproducible since it refers to `'old_keras_model.h5'`. Do you have a reproducible code snippet? You can attach files to GitHub comments.
@mattdangerw I think you're right -- Keras 2 was applying input dropout randomly across timesteps (while using a temporally constant mask for recurrent dropout), while Keras 3 is using a...
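The distinction above can be sketched with plain NumPy masks. This is not Keras internals, just an illustration of the two sampling strategies: a fresh dropout mask per timestep versus a single mask held constant across timesteps.

```python
import numpy as np

rng = np.random.default_rng(0)
timesteps, features, rate = 5, 4, 0.5

# Per-timestep dropout (the Keras 2 input-dropout behavior described
# above): a fresh binary mask is sampled at every timestep.
per_step_mask = rng.random((timesteps, features)) >= rate

# Temporally constant dropout (the recurrent-dropout behavior): one
# mask is sampled once and reused at every timestep.
constant_mask = np.broadcast_to(
    rng.random((1, features)) >= rate, (timesteps, features)
)

# Every row of the constant mask is identical; the per-step rows
# typically differ from one another.
print(np.all(constant_mask == constant_mask[0]))  # True
```

Which strategy is used matters for training dynamics: a temporally constant mask drops the same units at every step, which is the variant usually recommended for recurrent connections.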
When you are calling `__call__` or `predict_step`, you are using eager execution by default. When you are calling `predict` or `predict_on_batch` you are using a compiled function. So it sounds...
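The eager-versus-compiled split can be sketched with a toy class. This is a hypothetical stand-in, not the Keras implementation: `__call__` runs plain Python on every invocation, while `predict` builds a cached "compiled" function once and reuses it, analogous to how `predict` wraps the step in a traced function.

```python
class TinyModel:
    """Illustrative sketch (not Keras internals) of the dispatch
    described above: __call__ is eager, predict goes through a
    cached compiled function."""

    def __init__(self):
        self._compiled_fn = None
        self.trace_count = 0

    def call(self, x):
        return [v * 2 for v in x]

    def __call__(self, x):
        # Eager path: plain Python executes on every call.
        return self.call(x)

    def _compile(self):
        # Stands in for tracing/compilation; happens only once.
        self.trace_count += 1
        def fn(x):
            return self.call(x)
        return fn

    def predict(self, x):
        # Compiled path: build the function once, then reuse it.
        if self._compiled_fn is None:
            self._compiled_fn = self._compile()
        return self._compiled_fn(x)

m = TinyModel()
m.predict([1, 2])
m.predict([3, 4])
print(m.trace_count)  # 1 -- compiled once, reused afterwards
```

This is why a bug that appears only under `predict` but not under `__call__` (or vice versa) usually points at the compilation/tracing step rather than the model's math.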
My recommendation here would be to try with another backend, e.g. torch or JAX. It is likely to be a TF-specific issue.
That sounds good to me, but a caveat is that we cannot test such a change on CI. Did you try it out on your end and does it work?