trax
SelfAttention - problem with tensorflow 2.11.0
Description
When I set the TensorFlow backend (TF 2.11.0) for computation and try to use SelfAttention from the research module, I get the exception below:
File "/usr/local/lib/python3.10/dist-packages/trax/layers/research/efficient_attention.py", line 1536, in
Probably the problem is that the TF data type is not accepted as an appropriate dtype by `issubdtype`. In short, this call fails:
jax.np.issubdtype(tf.float64, np.floating)
Environment information
OS: Windows and WSL (Ubuntu 20.04.5 LTS)
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.1
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0
$ pip freeze | grep jax
jax==0.4.1
jaxlib==0.4.1+cuda11.cudnn86
$ python -V
Python 3.10.9
For bugs: reproduction and error logs
# Steps to reproduce:
jax.np.issubdtype(tf.float64, np.floating)
Gives:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3442, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-19-b62b80ea3d79>", line 1, in <module>
jax.np.issubdtype(tf.float64, np.floating)
File "/usr/local/lib/python3.10/dist-packages/numpy/core/numerictypes.py", line 416, in issubdtype
arg1 = dtype(arg1).type
TypeError: Cannot interpret 'tf.float64' as a data type
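NumPy cannot interpret a `tf.DType` object directly, but `tf.DType` exposes an `as_numpy_dtype` property that returns the matching NumPy type. A minimal sketch of a possible workaround, assuming it is acceptable to normalize the dtype before calling `issubdtype`; the helper names `to_numpy_dtype` and `safe_issubdtype` are hypothetical and not part of trax, JAX, or NumPy:

```python
import numpy as np


def to_numpy_dtype(dtype):
    """Normalize a dtype-like object to a numpy.dtype (hypothetical helper).

    tf.DType objects (e.g. tf.float64) expose an `as_numpy_dtype` property;
    anything without it (NumPy dtypes, scalar types, strings) is passed
    straight to np.dtype.
    """
    return np.dtype(getattr(dtype, "as_numpy_dtype", dtype))


def safe_issubdtype(dtype, kind):
    """np.issubdtype that also accepts TF dtypes (hypothetical helper)."""
    return np.issubdtype(to_numpy_dtype(dtype), kind)


if __name__ == "__main__":
    print(safe_issubdtype(np.float64, np.floating))
    print(safe_issubdtype("int32", np.floating))
```

With TensorFlow installed, `safe_issubdtype(tf.float64, np.floating)` should return `True`, because `tf.float64.as_numpy_dtype` is `numpy.float64`.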
This happens in the efficient_attention.py module during computation:
inputs_is_differentiable = fastmath.nested_map(
lambda x: np.issubdtype(x.dtype, np.inexact), inputs)
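A sketch of how the check above could be made backend-agnostic, assuming the dtype is normalized through `as_numpy_dtype` when present. This is not the trax code: `nested_map` below is a hypothetical minimal stand-in for `fastmath.nested_map` (tuples, lists, and dicts only), and `inputs_differentiability` is an illustrative name:

```python
import numpy as np


def _to_numpy_dtype(dtype):
    # tf.DType exposes `as_numpy_dtype`; plain NumPy dtypes do not.
    return np.dtype(getattr(dtype, "as_numpy_dtype", dtype))


def nested_map(fn, obj):
    """Hypothetical stand-in for fastmath.nested_map (tuples, lists, dicts)."""
    if isinstance(obj, (tuple, list)):
        return type(obj)(nested_map(fn, o) for o in obj)
    if isinstance(obj, dict):
        return {k: nested_map(fn, v) for k, v in obj.items()}
    return fn(obj)


def inputs_differentiability(inputs):
    # Same test as in efficient_attention.py, but the dtype is normalized
    # first so TF dtypes no longer raise "Cannot interpret ... as a data type".
    return nested_map(
        lambda x: bool(np.issubdtype(_to_numpy_dtype(x.dtype), np.inexact)),
        inputs)
```

For example, `inputs_differentiability((np.zeros(2, np.float32), [np.zeros(2, np.int32)]))` marks the float input as differentiable and the integer input as not.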
Error logs:
See the traceback above.