FBPINNs
Higher-Order Composite Gradient Problem
Thank you for sharing your work; it's very interesting! The new JAX version is indeed much faster, but I'm not very familiar with JAX (I mostly use PyTorch). Recently, while solving a PDE, I ran into the following problem:
$\mathrm{Loss}_1=\frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}$
$\mathrm{Loss}_2=\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] $
$\sigma _k$ is given; the inputs of the neural network are $x$ and $y$, and its outputs are $u$, $v$, and $k$.
When constructing the physics loss $\mathrm{Loss}_2$ above, $\frac{\partial k}{\partial y}$ is needed. The current FBPINN framework uses required_ujs_phys to request the gradients that are passed back through a callback, as in the following code skeleton:
```python
def sample_constraints(all_params, domain, key, sampler, batch_shapes):
    # physics loss
    y_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
    required_ujs_phys = (
        (0, ()),    # u
        (1, ()),    # v
        (2, ()),    # k
        (2, (1,)),  # k_y
    )
    return [[y_batch_phys, required_ujs_phys]]
```
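For context, my understanding from the example problems in the repository is that the requested quantities are then handed back to the problem's loss_fn as plain arrays evaluated at the sampled points. The following is only a sketch under that assumption, with a placeholder loss:

```python
import jax.numpy as jnp

def loss_fn(all_params, constraints):
    # (In the real Problem class this would presumably be a @staticmethod.)
    # My understanding: constraints[0] unpacks into the sampled points followed
    # by the requested outputs/derivatives, in the same order as
    # required_ujs_phys above.
    y_batch, u, v, k, k_y = constraints[0]
    # u, v, k, k_y are just arrays of values at y_batch -- any further
    # differentiation has to be requested up front via required_ujs_phys.
    return jnp.mean(k_y**2)  # placeholder, not my real physics loss
```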
This causes a problem: I cannot compute $\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right]$, because it is a second-order derivative of a composite quantity: the first-order term $\frac{\partial k}{\partial y}$ has to be combined with $v+\frac{v_t}{\sigma _k}$ before the outer derivative is taken, and that outer derivative cannot be requested through required_ujs_phys.
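To make this concrete: if the outer derivative is expanded by hand with the product rule (treating $\sigma_k$ as a constant),

$\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] =\left( \frac{\partial v}{\partial y}+\frac{1}{\sigma _k}\frac{\partial v_t}{\partial y} \right) \frac{\partial k}{\partial y}+\left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial ^2k}{\partial y^2},$

then most terms could be requested directly, assuming higher-order requests such as (2,(1,1)) for $\frac{\partial ^2k}{\partial y^2}$ are allowed (I am inferring this from the tuple format). A sketch of what that would look like:

```python
def sample_constraints(all_params, domain, key, sampler, batch_shapes):
    # physics loss, with the product rule applied by hand (sketch only)
    y_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
    required_ujs_phys = (
        (0, ()),      # u
        (1, ()),      # v
        (1, (1,)),    # v_y
        (2, ()),      # k
        (2, (1,)),    # k_y
        (2, (1, 1)),  # k_yy (assuming second-order requests are supported)
    )
    return [[y_batch_phys, required_ujs_phys]]
```

But even then the $\frac{\partial v_t}{\partial y}$ term remains, and since $v_t$ is not a direct network output in my setup, its derivative still cannot be requested this way either.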
This kind of composite gradient is quite common in practice. Do you have any suggestions for how to handle it?
Thank you for reading!