Add torch.float32 support for Apple M1
Hello,
Thanks for your amazing package!
I wanted to ask if it would be possible to provide a mechanism to use torch.float32 as the dtype for some of the adaptive solvers.
On Apple M1 (the mps backend), torch.float64 is not supported.
For example, here https://github.com/rtqichen/torchdiffeq/blob/d6ee52b349ddb6f7ba4a114a65fd8783db243ed6/torchdiffeq/_impl/dopri5.py#L6
the dtype is fixed to torch.float64, independently of the input.
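For reference, a minimal sketch of why this matters on Apple Silicon (assuming an MPS-enabled PyTorch build): allocating a float64 tensor on the "mps" device is rejected by the backend, so any solver internals hard-coded to double precision fail there.

```python
import torch

# Assumes PyTorch >= 1.12 built with MPS support on Apple Silicon.
if torch.backends.mps.is_available():
    try:
        # MPS has no double-precision support, so this raises an error.
        torch.zeros(3, dtype=torch.float64, device="mps")
    except (TypeError, RuntimeError) as err:
        print(f"float64 on MPS rejected: {err}")
```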
I am new to your package (and also somewhat new to PyTorch, since I usually use Julia), so sorry if I am overlooking something obvious.
I encountered this issue as well, but I think all you need to do is specify options={'dtype': torch.float32} in your odeint call (assuming you're using odeint with an adaptive solver, as I do); see the sketch below.
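A minimal sketch of that workaround, using a toy ODE dy/dt = -y purely for illustration; the relevant part is passing options={'dtype': torch.float32} so the adaptive solver keeps its internal tensors in single precision on the mps device:

```python
import torch
from torchdiffeq import odeint

# Fall back to CPU if MPS is not available.
device = "mps" if torch.backends.mps.is_available() else "cpu"

def f(t, y):
    # Toy right-hand side: exponential decay.
    return -y

y0 = torch.ones(2, device=device)            # float32 by default
t = torch.linspace(0., 1., 10, device=device)

# Pass the dtype through the solver options so the adaptive solver
# (here dopri5) does not default to torch.float64.
sol = odeint(f, y0, t, method="dopri5", options={"dtype": torch.float32})
print(sol.shape)  # (10, 2)
```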