Differentiating w.r.t. a LinearInterpolation
Hi there
I'm working with dynamical systems and have run into a small issue regarding differentiation with respect to control inputs.
From my study of the examples provided, I assumed the main way to pass control inputs to a dynamical system is to pass in a LinearInterpolation object. I implemented this throughout my code, but got stuck when trying to calculate the sensitivities of my system with respect to this control input.
I have not found any examples of how to handle this kind of problem, so I will present my attempts to work around it.
Let's take this toy dynamical system as an example:
import diffrax as dfx
import jax.numpy as jnp
import jax
# Create linear interpolation
ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.block([[jnp.linspace(0.0, 1.0, 10)], [jnp.linspace(0.0, 1.0, 10)]]).T
u_interp = dfx.LinearInterpolation(ts=ts, ys=ys)
# Toy dynamical system
def f(t, y, args):
    """Dynamical system function"""
    u = args
    return jnp.array([[1, 0], [0, 1]]) @ y + jnp.array([[1, 0], [0, 1]]) @ u.evaluate(t)
To calculate the sensitivities w.r.t. args:
dfdu = jax.jacfwd(f, argnums=2)(0, jnp.array([2.0,1.0]), u_interp)
print(dfdu)
Then we get a LinearInterpolation object. This surprised me, as I expected it to just raise an error:
LinearInterpolation(ts=f32[2,10], ys=f32[2,10,2])
But now if we try to evaluate it at a given instant (maybe it should take another kind of input?):
print(dfdu.evaluate(0))
We get:
ValueError: a should be 1-dimensional
It is not clear to me which value the error refers to, but I believe it means that the LinearInterpolation object has a 2-dimensional f32 array as ts.
Inspecting the ts and ys values, we get some pretty nonsensical values (to me at least).
print(dfdu.ts)
[[-1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[-1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
print(dfdu.ys)
[[[1. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]
[[0. 1.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]]
Is this expected behavior?
Should I pursue other solutions?
Is there a more direct/standard way to calculate the sensitivity of my dynamical system?
Thank you very much for the attention and support. All assistance is welcomed.
Hi,
I think you're just missing one thing here: you're taking a Jacobian, and the output of $f$ is multivariate. Since jacfwd computes the derivative of every element of the output of $f$ w.r.t. $u$, each array inside the result gains an extra leading axis of size 2, which is why ts comes back as f32[2,10] rather than f32[10].
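As a rough illustration of that extra axis, here is a sketch in plain JAX. It assumes a hand-written `interp` helper and a dict PyTree standing in for the diffrax object; both are illustrative stand-ins, not diffrax API.

```python
import jax
import jax.numpy as jnp


def interp(t, ts, ys):
    """Hand-written linear interpolation (stand-in for an interpolation object)."""
    # Index of the knot to the right of t, clipped so that i - 1 and i are valid
    i = jnp.clip(jnp.searchsorted(ts, t, side="right"), 1, ts.shape[0] - 1)
    w = (t - ts[i - 1]) / (ts[i] - ts[i - 1])
    return (1 - w) * ys[i - 1] + w * ys[i]


def f(t, y, u):
    # u is a dict PyTree with the same two leaves as the interpolation
    return y + interp(t, u["ts"], u["ys"])


ts = jnp.linspace(0.0, 1.0, 10)
ys = jnp.stack([ts, ts], axis=1)  # shape (10, 2)
u = {"ts": ts, "ys": ys}

# The Jacobian has the same PyTree structure as u; every leaf gains a
# leading axis with one entry per output component of f.
jac = jax.jacfwd(f, argnums=2)(0.05, jnp.array([2.0, 1.0]), u)
print(jac["ts"].shape)  # (2, 10)
print(jac["ys"].shape)  # (2, 10, 2)
```

The leaf shapes match the `f32[2,10]` and `f32[2,10,2]` reported in the question, so the returned object is best read as a container of Jacobian blocks rather than as something to call `evaluate` on.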
This should work out of the box if you are taking the derivative of a scalar loss with respect to your control input $u$, e.g. a quadratic objective that penalises deviations from a given point or some other loss that fits your application.
And yes, your derivative with respect to a linear interpolation is another linear interpolation, since a LinearInterpolation is just a PyTree to JAX, and hence a supported type. This comes in quite handy in optimal control applications :)
Hope this helps!
I recommend differentiating with respect to raw JAX arrays (and internally wrapping this into a LinearInterpolation object).
It should then be fairly clear what the meaning of the gradients is, as they'll just be the raw gradients on those numbers.
Thanks for the suggestions!
> This should work out of the box if you are taking the derivative of a scalar loss with respect to your control input u, e.g. a quadratic objective that penalises deviations from a given point or some other loss that fits your application.
Good to know. Unfortunately I need this to work for the multivariable case too, since switching to a scalar loss would require an extensive rewrite and would probably also harm the "objectiveness" of the code.
> I recommend differentiating with respect to raw JAX arrays (and internally wrapping this into a LinearInterpolation object).
Understood, but as I said above it would require a rewrite of the code. That would be complicated, because I programmed the system as multiple nested functions (in retrospect a bad idea), where the interpolator is passed as an argument to deeper levels.
I took some time to tinker with the results. I think I got a result that has worked out great.
jax.config.update("jax_enable_x64", True)  # double precision as a precaution
# Create linear interpolation
ts = jnp.linspace(0.0, 1.0, 100)
ys = jnp.block([[jnp.exp(jnp.linspace(0.0, 1.0, 100))], [jnp.linspace(0.0, 4, 100)]]).T
u_interp = dfx.LinearInterpolation(ts=ts, ys=ys)
B = jnp.array([[1.0, 0], [0, 2.0]])
def f(t, x, args):
    """Dynamical system function"""
    u = args
    return jnp.array([[1, 0], [0, 1]]) @ x + B @ u.evaluate(t)
My goal is to extract the B matrix. In this case it is trivial, but for nonlinear systems it is not readily available.
Inspecting the results from the Jacobian:
# Evaluation time for derivative
t = 0.0
jacf_u = jax.jacfwd(f, argnums=2)(t, jnp.array([0.0, 0.0]), u_interp)
print(jacf_u.ts)
print(jnp.transpose(jacf_u.ys, axes=(1, 0, 2)))  # Transposed to align results better
[[-1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[-8. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[[1. 0.]
[0. 2.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]]]
We seem to get results related to the values of interest. They allow us to calculate $\partial ys/\partial t$ and $\partial f/\partial u$ as follows:
dfdu = jnp.sum(jacf_u.ys, axis=1)
dysdt = jnp.linalg.inv(dfdu) @ jnp.sum(-jacf_u.ts, axis=1)
print(dfdu)
print(jnp.isclose(dfdu, B))
print(dysdt)
print(jnp.isclose(dysdt, jax.jacfwd(u_interp.evaluate)(t)))
[[1. 0.]
[0. 2.]]
[[ True True]
[ True True]]
[1. 4.]
[ True True]
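A short sketch of why these sums can be expected to work, assuming $t$ lies away from the knots and writing the interpolation as $u(t) = \sum_k w_k(t)\, ys_k$ with weights that sum to one (the symbols $w_k$, $ys_k$, and $ts_k$ are notation introduced here for the interpolation weights and the individual knots):

$$
\sum_k \frac{\partial f}{\partial ys_k}
= \frac{\partial f}{\partial u} \sum_k w_k(t)
= \frac{\partial f}{\partial u},
\qquad
\sum_k \frac{\partial f}{\partial ts_k}
= \frac{\partial f}{\partial u} \sum_k \frac{\partial u}{\partial ts_k}
= -\frac{\partial f}{\partial u}\,\frac{du}{dt}.
$$

The second identity uses the fact that shifting every knot by $\delta$ is the same as evaluating the interpolation at $t - \delta$. Solving it for $du/dt$ (the quantity called $\partial ys/\partial t$ above) gives the expression with the inverse, so that step additionally assumes $\partial f/\partial u$ is invertible.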
Validating for the case where the system is nonlinear:
# Nonlinear with interpolator
def f_non_linear(t, x, args):
    """Nonlinear dynamical system function"""
    u = args
    return jnp.sin(x**-1) + jnp.cos(u.evaluate(t))
# Nonlinear without interpolator
def f_non_linear_wo_interpolator(t, x, args):
    """Nonlinear dynamical system function"""
    u = args
    return jnp.sin(x**-1) + jnp.cos(u)
t = 0.1
# Classical definition of non-linear control sensitivity matrix
B = jax.jacfwd(f_non_linear_wo_interpolator, argnums=2)(t, jnp.array([0.0, 1.0]), u_interp.evaluate(t)) # Interpolator evaluated before passing
# Compute the sensitivity matrix with the proposed method for comparison with the classical one
jacf_non_linear_u = jax.jacfwd(f_non_linear, argnums=2)(t, jnp.array([0.0, 1.0]), u_interp)
dfdu = jnp.sum(jacf_non_linear_u.ys, axis=1)
print(dfdu)
print(jnp.isclose(dfdu, B))
dysdt = jnp.linalg.inv(dfdu) @ jnp.sum(-jacf_non_linear_u.ts, axis=1)
print(dysdt)
print(jnp.isclose(dysdt, jax.jacfwd(u_interp.evaluate)(t)))
[[-0.09983342 0. ]
[ 0. -0.38941834]]
[[ True True]
[ True True]]
[1. 4.]
[ True True]
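For nested code that expects an object with an `evaluate` method, one further option is to freeze the control at a single instant. This is only a sketch: `ConstantControl` is a hypothetical helper, not part of diffrax, and the input values are made up for the example.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ConstantControl(NamedTuple):
    """Hypothetical stand-in exposing `evaluate` like an interpolation does."""

    value: jax.Array

    def evaluate(self, t):
        return self.value  # ignores t: the control is frozen at one instant


def f_non_linear(t, x, args):
    u = args
    return jnp.sin(x**-1) + jnp.cos(u.evaluate(t))


u_t = jnp.array([0.1, 0.4])  # e.g. the result of u_interp.evaluate(t)
x = jnp.array([0.5, 1.0])

# A NamedTuple is a PyTree, so the Jacobian comes back as a ConstantControl
# whose single leaf is the control sensitivity matrix itself.
jac = jax.jacfwd(f_non_linear, argnums=2)(0.1, x, ConstantControl(u_t))
dfdu = jac.value
print(dfdu.shape)  # (2, 2)
```

This gives $\partial f/\partial u$ without summing over knots or inverting a matrix, at the cost of losing the sensitivity with respect to `ts`.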
So I think I can say that this problem is solved. If you happen to have any other suggestions, I'd be happy to hear them.
Once again, thanks for the support and dedication, and congrats on the awesome project!