Physics-informed-neural-network-in-JAX

[Question] Sum within gradients

Open aoguedao opened this issue 1 year ago • 0 comments

Hi Mahmoud!

Thank you very much for this repo. I have been working on PINNs, but only with third-party packages (deepxde), and now I want to implement my own in JAX.

I want to do something similar to the system of ODEs you already have, but I have been wondering why you sum the tensor inside the gradient in the ODE residual loss.

def ODE_loss(t, y1, y2):
    y1_t = lambda t: jax.grad(lambda t: jnp.sum(y1(t)))(t)
    y2_t = lambda t: jax.grad(lambda t: jnp.sum(y2(t)))(t)

    return y1_t(t) - y1(t), y2_t(t) - y1(t) + y2(t)

I have seen other implementations where they don't do that, but I think they use jacobian instead of grad.
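For context, here is a minimal sketch of the equivalence I'm asking about, using `jnp.sin` as a hypothetical stand-in for the network (not the repo's actual model). `jax.grad` only accepts scalar-output functions, so summing over a batch of collocation points makes the output scalar; because each output element depends only on its own input element, the gradient of the sum recovers the per-point derivative, the same thing a jacobian-based (or `vmap`-of-`grad`) implementation computes:

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise "network": y(t) = sin(t), so y'(t) = cos(t).
y = jnp.sin

t = jnp.linspace(0.0, 1.0, 5)

# Sum trick: grad needs a scalar output. Since output element i depends
# only on input element i, d/dt_i [sum_j y(t_j)] = y'(t_i).
dy_sum = jax.grad(lambda t: jnp.sum(y(t)))(t)

# Alternative: a per-point grad mapped over the batch, which is what
# jacobian-style implementations effectively compute (the diagonal).
dy_vmap = jax.vmap(jax.grad(y))(t)

print(jnp.allclose(dy_sum, dy_vmap))        # both equal cos(t)
print(jnp.allclose(dy_sum, jnp.cos(t)))
```

If I understand correctly, the two only agree because the network is applied pointwise; a full `jax.jacobian` would also compute the off-diagonal (cross-point) entries, which are zero here.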

Thank you very much!

aoguedao avatar May 23 '24 16:05 aoguedao