
Non-hashable type error

Open · lucfra opened this issue 3 years ago · 6 comments

Dear all,

I've read the BCD-Nets paper, which I found very interesting. I am now trying to reproduce your results, but unfortunately I have run into this error:

line 949, in <module>
    ) = parallel_gradient_step(
ValueError: Non-hashable static arguments are not supported. 
An error occured during a call to 'parallel_gradient_step' while trying to hash
 an object of type <class 'numpy.ndarray'>,
 [[ 4.82755829e-01  2.30017473e+00  1.29824051e+00  1.94172572e+00 .....

which refers to this line of your code.
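For readers unfamiliar with JAX, a minimal sketch of what triggers this error. It uses a hypothetical `step` function, not the real `parallel_gradient_step`, whose signature is not shown here. `jax.jit` hashes every argument listed in `static_argnums`, and a `numpy.ndarray` is not hashable, so passing one in a static position fails. Two possible workarounds are shown: pass the array as an ordinary traced argument, or convert it to a hashable nested tuple. Which one suits BCD-Nets is an assumption that depends on why the argument was marked static in the first place.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for a jitted training step; argument 1 is static.
@functools.partial(jax.jit, static_argnums=(1,))
def step(x, scale):
    return x * jnp.sum(jnp.asarray(scale))


x = jnp.ones(3)
w = np.array([[0.5, 1.5], [2.0, 0.0]])

# A NumPy array in a static position fails: JAX hashes static arguments
# to key its compilation cache, and np.ndarray is unhashable.
try:
    step(x, w)
    raised = False
except (ValueError, TypeError):
    raised = True


# Workaround 1: make the array a regular (traced) argument instead.
@jax.jit
def step_traced(x, scale):
    return x * jnp.sum(scale)


out_traced = step_traced(x, w)

# Workaround 2: if the value really must stay static, pass a hashable
# nested tuple; note that each distinct value triggers a recompilation.
out_static = step(x, tuple(map(tuple, w.tolist())))
```

Both workarounds return `[4., 4., 4.]` here, since the entries of `w` sum to 4.0.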

I should say that I have no experience with JAX. For context, I did not manage to install all the required packages from your environment.yml, so I installed them manually instead. My JAX version is 0.3.1.

P.S.: the code was not compatible with this JAX version out of the box. To make it run, I did the following:

  • Replaced jax.partial (which is no longer available) with functools.partial (I read that jax.partial was an accidental leak).
  • Copied _conv_transpose_padding into nux.util.convolution from the JAX version you used, since I couldn't find _conv_transpose_padding in 0.3.1.
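A minimal sketch of the first substitution, assuming (as the note above suggests) that jax.partial was simply functools.partial re-exported under the jax namespace, so the change is mechanical. The functions `power` and `scaled_sq` are hypothetical illustrations, not code from the repository; they cover the two common usage forms, as a decorator factory for `jax.jit` and for binding a leading argument.

```python
import functools

import jax
import jax.numpy as jnp


# Decorator form (formerly written as @jax.partial(jax.jit, static_argnums=1)).
@functools.partial(jax.jit, static_argnums=1)
def power(x, n):
    # `n` is static, so this Python loop is unrolled at trace time.
    out = jnp.ones_like(x)
    for _ in range(n):
        out = out * x
    return out


# Argument-binding form (formerly jax.partial(scaled_sq, 3.0)), here fed to vmap.
def scaled_sq(scale, x):
    return scale * x ** 2


batched = jax.vmap(functools.partial(scaled_sq, 3.0))
```

For example, `power(jnp.array([2.0, 3.0]), 3)` gives `[8., 27.]` and `batched(jnp.array([1.0, 2.0]))` gives `[3., 12.]`.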

Any help would be much appreciated,

Cheers, Luca

lucfra, Feb 22 '22 17:02