torchcubicspline
_validate_input is inefficient and contains a mistake
The function _validate_input seems to be incorrect. The code tries to reject non-monotonic data, but since prev_t_i is never updated inside the loop, the check doesn't actually work. Also, in a CUDA environment this part of the code is very inefficient.
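To illustrate the reported problem, here is a hypothetical sketch. It is not the package's actual _validate_input. The function names are made up, and it assumes the loop seeds prev_t_i with -inf and never reassigns it. Both functions return a bool instead of raising, which makes them easy to test. In the buggy version, no element is ever compared with its predecessor, and each comparison on a CUDA tensor forces its own host sync. The vectorised version compares all consecutive pairs in one kernel and syncs once.

```python
import math

import torch


def check_increasing_as_reported(t: torch.Tensor) -> bool:
    # Hypothetical reproduction of the bug: prev_t_i is seeded once and never
    # reassigned, so every element is compared against -inf and always passes.
    prev_t_i = -math.inf
    for t_i in t:
        # `if` on a 0-d tensor calls bool(), which on CUDA triggers a separate
        # device-to-host synchronisation for every element.
        if t_i <= prev_t_i:
            return False
        # Missing here: prev_t_i = t_i
    return True


def is_strictly_increasing(t: torch.Tensor) -> bool:
    # Vectorised fix: compare all consecutive pairs at once; bool() performs a
    # single device-to-host transfer at the end.
    return bool((t[1:] > t[:-1]).all())
```

For example, with t = torch.tensor([0.0, 2.0, 1.0, 3.0]), the first function accepts the non-monotonic input and the second rejects it.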
Good catch. I'd be happy to accept a PR fixing this.
Yep, none of this code is very efficient on GPU. My own use case for this was interpolating data as a preprocessing step prior to training, so it didn't matter.
If you're agnostic to the choice of autodiff framework, then you may like to have a look at Diffrax, which includes code for backward Hermite cubic splines (rather than the natural cubic splines in this package). Unlike the code here, that implementation should be relatively GPU-efficient.
Thanks for the recommendation. I actually designed an interpolation module in a PyTorch model that needs to run online, so I am looking for a more efficient open-source PyTorch-based implementation of cubic spline interpolation.
I am also currently considering whether to switch to a more efficient interpolation method or to write a more efficient implementation myself.
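Right. If you need it to run online then the implementation in this repository won't be suitable, I'm afraid, as natural cubic splines don't satisfy that property: the "future" affects the "past". Adding a new observation changes the interpolant over all earlier intervals, not just the most recent one.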
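The "future affects the past" point can be made concrete with a small natural cubic spline in PyTorch. This is a textbook construction written for illustration and is deliberately unoptimised. It is not this package's implementation, and the function names are hypothetical. Every knot's second derivative comes from one global linear solve. Because of that, changing only the last observation also moves the curve on the very first interval.

```python
import torch


def natural_cubic_second_derivs(t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Solve for the second derivatives M at the knots (dense system for brevity).
    n = t.shape[0]
    h = t[1:] - t[:-1]
    A = torch.zeros(n, n, dtype=t.dtype)
    rhs = torch.zeros(n, dtype=t.dtype)
    A[0, 0] = 1.0    # natural boundary condition: M_0 = 0
    A[-1, -1] = 1.0  # natural boundary condition: M_{n-1} = 0
    for i in range(1, n - 1):
        A[i, i - 1] = h[i - 1]
        A[i, i] = 2 * (h[i - 1] + h[i])
        A[i, i + 1] = h[i]
        rhs[i] = 6 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
    # Global solve: every M_i depends on every observation y_j.
    return torch.linalg.solve(A, rhs)


def natural_cubic_eval(t: torch.Tensor, y: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    M = natural_cubic_second_derivs(t, y)
    # Interval index for each query point, clamped to the valid range.
    i = (torch.bucketize(q, t, right=True) - 1).clamp(0, t.shape[0] - 2)
    h = t[i + 1] - t[i]
    a, b = t[i + 1] - q, q - t[i]
    return ((M[i] * a**3 + M[i + 1] * b**3) / (6 * h)
            + (y[i] / h - M[i] * h / 6) * a
            + (y[i + 1] / h - M[i + 1] * h / 6) * b)
```

For example, with y = [0, 1, 0, 1, 0] on five unit-spaced knots, raising only y[4] changes the interpolated value at t = 0.5.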
The backward Hermite cubic splines I mentioned above are probably the appropriate algorithmic tool here, but of course if you're constrained to PyTorch then you'll have to reimplement them yourself. I think torchcde does have an implementation you can use as a starting point, but IIRC this implementation is mistakenly noncausal in the presence of missing data (represented as NaNs). More broadly, if you want a paper reference for backward Hermite cubic splines, then see here.
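The following is a minimal PyTorch sketch of cubic Hermite interpolation with backward differences. It assumes one convention: on [t_i, t_{i+1}], the derivative at the start is the previous interval's slope, and the derivative at the end is the current interval's slope. The first interval reuses its own slope. It is not torchcde's or Diffrax's code, it does not handle missing data (NaNs), and the function names are hypothetical. Under this convention, the curve on [t_i, t_{i+1}] depends only on x_{i-1}, x_i and x_{i+1}. Once x_{i+1} has arrived, later observations never change it.

```python
import torch


def backward_hermite_pieces(t: torch.Tensor, x: torch.Tensor):
    # t: (n,) strictly increasing knots; x: (n, channels) observations.
    dt = (t[1:] - t[:-1]).unsqueeze(-1)        # (n-1, 1)
    slopes = (x[1:] - x[:-1]) / dt             # slope of each interval
    # Start derivative: previous interval's slope (first interval reuses its own).
    m0 = torch.cat([slopes[:1], slopes[:-1]], dim=0)
    # End derivative: the interval's own (backward-difference) slope.
    m1 = slopes
    return x[:-1], x[1:], m0, m1


def backward_hermite_eval(t: torch.Tensor, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    x0, x1, m0, m1 = backward_hermite_pieces(t, x)
    # Vectorised interval lookup: index i with t_i <= q < t_{i+1}, clamped.
    idx = (torch.bucketize(q, t, right=True) - 1).clamp(0, t.shape[0] - 2)
    h = (t[idx + 1] - t[idx]).unsqueeze(-1)
    s = (q - t[idx]).unsqueeze(-1) / h         # local coordinate in [0, 1]
    s2, s3 = s * s, s * s * s
    # Standard cubic Hermite basis functions.
    h00 = 2 * s3 - 3 * s2 + 1
    h10 = s3 - 2 * s2 + s
    h01 = -2 * s3 + 3 * s2
    h11 = s3 - s2
    return h00 * x0[idx] + h10 * h * m0[idx] + h01 * x1[idx] + h11 * h * m1[idx]
```

Evaluation is fully vectorised, with no Python loop over knots or query points. For online use, only the newest interval's pieces would need to be computed when a new observation arrives.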
Really appreciate your help.
I think I need to spend some time researching this problem.