jax-fem
How to run the solver in parallel (batched)?
Hi, I'm trying to create a hybrid model that uses your FEM solver together with a neural network. To do that, I need to solve the same equations with different parameters for each sample (e.g., heat diffusion, where the diffusion coefficient and the initial condition differ for each data sample). I can't use vmap because the solver relies on SciPy and NumPy, which aren't compatible with it. Do you think the solver could be adapted so that it can handle batches or be passed to vmap? Thanks in advance for any ideas!
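For illustration, here is a minimal sketch of what "vmap-compatible" means in this setting. It is not jax-fem's API. It uses a hand-written 1D finite-difference (not finite-element) backward-Euler heat solver as a stand-in, and the names `heat_implicit_1d`, `solver`, and `batched_solve` are hypothetical. The assumption is that the whole solve is written with `jax.numpy` / `jax.lax` and contains no NumPy or SciPy calls. Under that assumption, `jax.vmap` can batch over both the diffusion coefficient and the initial condition.

```python
from functools import partial

import jax
import jax.numpy as jnp


def heat_implicit_1d(kappa, u0, dx, dt, n_steps):
    """Hypothetical stand-in solver (finite differences, not FEM).

    Backward-Euler time stepping of u_t = kappa * u_xx on interior grid
    nodes, with homogeneous Dirichlet boundaries (u = 0 at both ends).
    Everything is pure JAX, so jit and vmap can trace it.
    """
    n = u0.shape[0]
    # Second-difference Laplacian matrix for Dirichlet boundaries.
    lap = (jnp.diag(-2.0 * jnp.ones(n))
           + jnp.diag(jnp.ones(n - 1), 1)
           + jnp.diag(jnp.ones(n - 1), -1)) / dx**2
    # Each step solves (I - dt * kappa * L) u^{k+1} = u^k.
    a = jnp.eye(n) - dt * kappa * lap

    def step(u, _):
        return jnp.linalg.solve(a, u), None

    # n_steps must be a static Python int for lax.scan.
    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final


n_grid = 32
dx = 1.0 / (n_grid + 1)
x = dx * jnp.arange(1, n_grid + 1)  # interior nodes only

# Fix the discretization parameters so that only (kappa, u0) are batched.
solver = partial(heat_implicit_1d, dx=dx, dt=1e-3, n_steps=100)
batched_solve = jax.jit(jax.vmap(solver))  # maps axis 0 of kappa and u0

# One diffusion coefficient and one initial condition per sample.
kappas = jnp.array([0.1, 0.5, 1.0])
u0s = jnp.stack([s * jnp.sin(jnp.pi * x) for s in (1.0, 2.0, 3.0)])

u_batch = batched_solve(kappas, u0s)
print(u_batch.shape)  # (3, 32)
```

The design point is that `vmap` only works if every operation inside the solve is traceable by JAX. A single `np.asarray` or `scipy.sparse.linalg` call breaks this. If rewriting the linear solve in JAX is not an option, the fallback is a plain Python loop over samples, which gives up the batching speedup.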