GeometricFlux.jl

fixed code for GPU

Open alperyilmaz opened this issue 3 years ago • 2 comments

It runs on the GPU now. The reshape() function didn't work with the ODE output, so I just take the first element of the array. Training is stuck at an accuracy of 0.24; I'm not sure if I broke something.

alperyilmaz, Oct 26 '20 20:10

Sorry, I didn't notice the gde_gpu.jl file. Still, the reshape function might cause problems in gde_gpu.jl as well.

alperyilmaz, Oct 26 '20 22:10

Would you please apply your modification to the gde_gpu.jl file instead?

yuehhua, Nov 02 '20 06:11