GeometricFlux.jl
Fixed code for GPU
It runs on the GPU now. The reshape() function didn't work with the ODE output, so I just take the first element of the array. However, training is stuck at an accuracy of 0.24, and I'm not sure whether I broke something.
Sorry, I didn't notice the gde_gpu.jl file. Still, the reshape function might cause problems in gde_gpu.jl as well.
Would you please move your modification to the gde_gpu.jl file?