Physics-informed-neural-network-in-JAX
[Question] Sum within gradients
Hi Mahmoud!
Thank you very much for this repo! I have been working on PINNs, but only with third-party packages (DeepXDE), and now I want to implement my own with JAX.
I want to do something similar to the system of ODEs you already have, but I have been wondering why you sum the network output over the tensor t inside the gradient in the ODE residual loss:
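```python
import jax
import jax.numpy as jnp

def ODE_loss(t, y1, y2):
    # jax.grad needs a scalar output, so each network output is summed before differentiating
    y1_t = lambda t: jax.grad(lambda t: jnp.sum(y1(t)))(t)
    y2_t = lambda t: jax.grad(lambda t: jnp.sum(y2(t)))(t)
    return y1_t(t) - y1(t), y2_t(t) - y1(t) + y2(t)
```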
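A minimal sketch of why the sum is there, assuming a hypothetical stand-in for `y1`: a tiny MLP (weights `W1`, `b1`, `W2` are illustrative, not the repo's network) applied row-wise to an `(N, 1)` array of collocation points. `jax.grad` refuses non-scalar outputs, and because output row i depends only on input row i, the gradient of the sum equals the pointwise derivative dy1/dt at every point.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for y1: a one-hidden-layer MLP applied row-wise,
# so output row i depends only on input row t[i].
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(key1, (1, 16))
b1 = jnp.zeros(16)
W2 = jax.random.normal(key2, (16, 1))

def y1(t):
    # t: (N, 1) collocation points -> (N, 1) outputs
    return jnp.tanh(t @ W1 + b1) @ W2

t = jnp.linspace(0.0, 1.0, 8).reshape(-1, 1)

# jax.grad only accepts scalar-output functions, so differentiating y1 directly fails.
try:
    jax.grad(y1)(t)
    raised = False
except TypeError:
    raised = True

# Summing first gives a scalar. For a row-wise model, d(sum_j y1_j)/dt_i = dy1_i/dt_i,
# so the result is the pointwise derivative, with the same shape as t.
y1_t = jax.grad(lambda t: jnp.sum(y1(t)))(t)

# Hand-derived derivative of tanh(t W1 + b1) W2, for comparison.
h = jnp.tanh(t @ W1 + b1)
y1_t_exact = ((1.0 - h**2) * W1) @ W2
```

The sum is just a cheap way to get a scalar whose gradient, under the row-wise assumption, carries every point's derivative in one backward pass.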
I have seen other implementations that don't do that, but I think they use the Jacobian instead of grad.
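For comparison, a hedged sketch of two alternatives such implementations might use: taking the diagonal of the full `jax.jacobian`, or mapping `jax.grad` over single points with `jax.vmap`. The toy model `y(t) = t * sin(t)` is a hypothetical stand-in; for any model that treats points independently, all three give the same derivative.

```python
import jax
import jax.numpy as jnp

def y(t):
    # Hypothetical pointwise model: y_i = t_i * sin(t_i), shape (N, 1)
    return jnp.sin(t) * t

t = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)
n = t.shape[0]

# 1) Sum trick: a single reverse-mode pass.
d_sum = jax.grad(lambda t: jnp.sum(y(t)))(t)

# 2) Full Jacobian, shape (N, 1, N, 1); only the diagonal is nonzero for a pointwise model.
J = jax.jacobian(y)(t)
d_jac = jnp.diagonal(J.reshape(n, n)).reshape(-1, 1)

# 3) Per-point gradient: differentiate the model at one scalar time, then vmap over points.
y_scalar = lambda s: y(s.reshape(1, 1))[0, 0]
d_vmap = jax.vmap(jax.grad(y_scalar))(t[:, 0]).reshape(-1, 1)

# Analytic derivative of t * sin(t).
d_exact = jnp.cos(t) * t + jnp.sin(t)
```

The trade-off: the Jacobian route materializes an N×N array of which only the diagonal is used, so memory grows quadratically with the number of collocation points, whereas the sum trick costs one backward pass. The sum trick is only valid if no output mixes information across points (batch normalization over the collocation batch, for example, would break it); `vmap(grad)` makes that per-point assumption explicit.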
Thank you very much!