BCD-Nets
Non-hashable type error
Dear all,
I've read the BCD-Nets paper, which I found very interesting. I am now trying to reproduce your results, but unfortunately I have run into this error:
line 949, in <module>
) = parallel_gradient_step(
ValueError: Non-hashable static arguments are not supported.
An error occured during a call to 'parallel_gradient_step' while trying to hash
an object of type <class 'numpy.ndarray'>,
[[ 4.82755829e-01 2.30017473e+00 1.29824051e+00 1.94172572e+00 .....
which refers to this line of your code.
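A minimal sketch of how this message can arise, assuming the cause is a NumPy array passed in a position that `jax.jit` treats as static. The functions `scale_static` and `scale_traced` are hypothetical stand-ins, not the BCD-Nets code, and the real `parallel_gradient_step` may be set up differently:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for a jitted step function; not the BCD-Nets code.
@partial(jax.jit, static_argnums=(1,))
def scale_static(x, factors):
    # `factors` is static, so JAX must hash it to build a compilation cache key.
    return x * jnp.asarray(factors)


@jax.jit
def scale_traced(x, factors):
    # `factors` is an ordinary traced argument, so it is never hashed.
    return x * factors


x = jnp.ones(3)
factors = np.array([1.0, 2.0, 3.0])

try:
    scale_static(x, factors)  # a numpy.ndarray is not hashable
    static_error = None
except ValueError as err:
    static_error = err

# Two possible ways around it (assumptions, depending on what the code needs):
ok_static = scale_static(x, tuple(factors.tolist()))  # pass a hashable tuple
ok_traced = scale_traced(x, factors)                  # or do not mark it static
```

If the argument really is array data rather than a configuration value, leaving it out of `static_argnums` is usually the simpler option, since every distinct static value triggers a recompilation.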
I should say that I have no experience with JAX. For context, I could not install all the required packages from your environment.yml, so I installed them manually. My JAX version is 0.3.1.
P.S.: The code did not run out of the box with this version. To make it runnable I did the following:
- Replaced jax.partial (which is no longer available) with functools.partial (I read that jax.partial was an accidental leak).
- Copy-pasted _conv_transpose_padding into nux.util.convolution from the JAX version you used, since I couldn't find _conv_transpose_padding in 0.3.1.
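The first change above can be sketched as an import fallback. This is only an illustration, assuming the old `jax.partial` was a re-export of `functools.partial`; the `add` function is a hypothetical example:

```python
import functools

try:
    # Older JAX releases exposed this name; newer ones do not.
    from jax import partial
except ImportError:
    # Fall back to the standard-library version.
    from functools import partial


def add(a, b):
    """Hypothetical helper used only to show that partial application works."""
    return a + b


add_three = partial(add, 3)  # binds a=3, leaving b free
result = add_three(4)
```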
Any help is much appreciated.
Cheers, Luca