
replace_nans_at_start overwrites existing values

Open • JadM133 opened this issue 1 year ago • 1 comment

Hello, I am not sure if the behavior of replace_nans_at_start in backward_hermite_coefficients is expected when we have initial values that aren't all NaNs.

The documentation suggests that replace_nans_at_start replaces the NaN values at the start with the ones given.

I believe this parameter only works as intended if ALL the values at the start are NaN. I give an example below (again based on get_data from the NCDE example) where the dataset has no NaN values at all. It actually raises an assertion error, because the initial values are replaced with the "means" regardless of whether they are NaN or not.

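A minimal sketch of what the broadcast on its own does when the first observation is not NaN. The arrays ys and means are hypothetical stand-ins (the original example data is not preserved here), and only the jnp.broadcast_to(replace_nans_at_start, ys[0].shape) call from _backward_hermite_coefficients is reproduced, not the rest of diffrax. It falls back to NumPy if JAX is not installed, since broadcast_to behaves the same in both for this case.

```python
try:
    import jax.numpy as jnp  # diffrax is built on JAX
except ImportError:
    import numpy as jnp  # fallback: broadcast_to has the same semantics here

nan = float("nan")

# Hypothetical data: 3 time points, 2 channels.
# Channel 0 is observed at the start (1.0); only channel 1 starts with NaN.
ys = jnp.array([[1.0, nan],
                [2.0, 5.0],
                [3.0, 6.0]])
means = jnp.array([0.5, 4.0])  # hypothetical per-channel fill values

# The fill value is broadcast to ys[0].shape and used as the initial value
# unconditionally, so the observed 1.0 in channel 0 is discarded.
y0 = jnp.broadcast_to(means, ys[0].shape)
print(y0)  # channel 0 is now 0.5, not the observed 1.0
```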

I think the problem is that in _backward_hermite_coefficients the following line:

y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)

does not check whether the values are NaN and replaces them unconditionally. In my opinion, the intent of the parameter is to replace only the NaN values and not to overwrite values that already exist. I would propose either stating in the docstring that this is the expected behavior, or adjusting that line so it only replaces values that are actually NaN.
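One possible adjustment, sketched under the assumption that only the computation of the initial value needs to change: keep ys[0] wherever it is observed and use the fill value only where it is NaN, by selecting elementwise with jnp.where and jnp.isnan. The helper name initial_value is hypothetical; this is an illustration, not diffrax's actual code.

```python
try:
    import jax.numpy as jnp  # diffrax is built on JAX
except ImportError:
    import numpy as jnp  # fallback: where/isnan/broadcast_to match jax.numpy here


def initial_value(ys, replace_nans_at_start=None):
    """Hypothetical helper: starting value y0 for the interpolation.

    Only entries of ys[0] that are NaN are replaced; observed values are kept.
    """
    y0 = ys[0]
    if replace_nans_at_start is None:
        return y0
    # Allow a scalar or an array of fill values, as the original line does.
    fill = jnp.broadcast_to(replace_nans_at_start, y0.shape)
    # Elementwise: take the fill value where y0 is NaN, otherwise keep y0.
    return jnp.where(jnp.isnan(y0), fill, y0)


nan = float("nan")
ys = jnp.array([[1.0, nan], [2.0, 5.0], [3.0, 6.0]])
means = jnp.array([0.5, 4.0])  # hypothetical per-channel fill values
print(initial_value(ys, means))  # observed 1.0 kept, NaN replaced by 4.0
```

Because the selection is elementwise, a first row that is entirely NaN is still fully replaced, so the case the parameter already handles keeps working.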

Thank you in advance! The library is great!

JadM133 • Oct 16 '23 16:10

Yup, agreed. This at least should be an easy fix -- if you get the chance I'd be happy to take a pull request on this!

patrick-kidger • Oct 16 '23 18:10