jax-cfd
Can't run spectral_forced_turbulence due to jnp.linalg.norm error
I copied the code from spectral_forced_turbulence.ipynb, but running it raises the following error:
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<timed exec> in <module>
/usr/local/lib/python3.8/dist-packages/jax_cfd/base/initial_conditions.py in filtered_velocity_field(rng_key, grid, maximum_velocity, peak_wavenumber, iterations)
110 # specified maximum velocity.
--> 111 return funcutils.repeated(project_and_normalize, iterations)(velocity)
112
28 frames
UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<timed exec> in <module>
/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
343 if not _arraylike(arg))
344 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 345 raise TypeError(msg.format(fun_name, type(arg), pos))
346
347
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
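For context, here is a minimal sketch of this class of error outside of jax-cfd. It assumes a recent JAX release, where `jnp.linalg.norm` rejects a Python list at position 0. The `velocity` list and the `norm_of_components` helper are hypothetical stand-ins made up for illustration; this is not the code in `initial_conditions.py`, and stacking the components is only one possible workaround, not necessarily how the library handles it.

```python
import jax.numpy as jnp

# Hypothetical stand-in: a "velocity" stored as a Python list of component arrays.
velocity = [jnp.ones((4, 4)), 2.0 * jnp.ones((4, 4))]


def norm_of_components(components):
    """Stack the list into a single array before taking the norm."""
    return jnp.linalg.norm(jnp.stack(components))


# Passing the list directly is what triggers the TypeError on JAX versions
# that require array-like arguments. Older versions may accept it, so the
# outcome is recorded rather than assumed.
try:
    jnp.linalg.norm(velocity)
    list_rejected = False
except TypeError as err:
    list_rejected = True
    print(err)

# Converting to one array first works regardless of the JAX version.
total = norm_of_components(velocity)
print(list_rejected, float(total))
```

If `list_rejected` comes out `True`, the installed JAX is strict about list inputs, which would be consistent with the traceback above.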
You can reproduce this with this Google Colab. It is also confusing that the Google Colab notebook linked in the README does not install jax-cfd. What is the intended way of running these notebooks?
Edit: It seems that the jax-cfd version on PyPI is outdated. Downloading the source from the repository works.
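For anyone hitting the same thing, a quick way to see which versions are actually installed in the runtime is sketched below. It uses only the standard library. The distribution names `jax-cfd` and `jax` are assumptions about how the packages are registered with pip, and `installed_version` is a helper name made up for this example.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are assumptions; adjust them if pip lists the packages differently.
for name in ("jax-cfd", "jax"):
    print(name, installed_version(name))
```

If the reported jax-cfd version matches the latest PyPI release and the error still occurs, installing from the repository source instead (for example with pip pointed at the repository URL) is the workaround described in the edit above.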