fix external variable initialization
I ran into two bugs when trying to implement this example with the JAX backend: https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8
- In `model.py`: the external variables were not initialized on the second compile, because the parameters of the net were already initialized.
- In `pde.py`: if I don't compile with external variables, I still want the code to work with the default values of the unknowns. I think this code can safely be modified for JAX only, because the line after it was already JAX-specific.
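A minimal sketch of the first bug, assuming a hypothetical `ToyModel` that stores its parameters as `[net_params, external_values]` (this is an illustration, not DeepXDE's actual `model.py`): the network weights are initialized only once, but the external-variable slot is rebuilt on every compile, so variables passed only to a later compile are not dropped.

```python
class ExternalVariable:
    """Hypothetical stand-in for an external trainable scalar with a `.value`."""

    def __init__(self, value):
        self.value = value


class ToyModel:
    """Hypothetical model showing how parameters could be set up across compiles."""

    def __init__(self):
        self.net_params = None
        self.params = None
        self.external_trainable_variables = []

    def compile(self, external_trainable_variables=None):
        self.external_trainable_variables = list(external_trainable_variables or [])
        # Initialize the network weights only once, so a second compile keeps
        # whatever was learned during the first training run.
        if self.net_params is None:
            self.net_params = {"w": 0.0}  # placeholder for a real initializer
        # Rebuild the external-variable slot on every compile. Putting this under
        # the same "already initialized" guard would silently drop variables that
        # are only passed to a later compile.
        external_values = [var.value for var in self.external_trainable_variables]
        self.params = [self.net_params, external_values]


var_list_ = [ExternalVariable(1.0), ExternalVariable(2.0)]
model = ToyModel()
model.compile()  # first compile: no external variables
first_net_params = model.net_params
model.compile(external_trainable_variables=var_list_)  # second compile
print(model.params[1])  # [1.0, 2.0]
print(model.net_params is first_net_params)  # True
```

The design choice here is simply to separate the two guards: "have the weights been created?" and "which external variables belong to this compile?" are independent questions.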
Could you point to an example that uses this code?
Here: https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py
This is an implementation of SBINN using JAX. We first train the model without the external variables:
```python
def ODE(t, y, unknowns=[var.value for var in var_list_]):
    ...

model.compile("adam", lr=1e-3, loss_weights=[0, 0, 0, 0, 0, 0, 1e-2])
model.train(epochs=firsttrain, display_every=1000)

model.compile(
    "adam",
    lr=1e-3,
    loss_weights=[1, 1, 1e-2, 1, 1, 1, 1e-2],
    external_trainable_variables=var_list_,
)
variablefilename = "variables.csv"
variable = dde.callbacks.VariableValue(
    var_list_, period=callbackperiod, filename=variablefilename
)
losshistory, train_state = model.train(
    epochs=maxepochs, display_every=1000, callbacks=[variable]
)
```
For this first training run, we want the ODE to use its default `unknowns` argument.
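The pattern `unknowns=[var.value for var in var_list_]` relies on Python evaluating a default argument once, when the `def` statement runs. A small sketch with a hypothetical `Variable` class and a toy right-hand side `dy/dt = -k * y` (not the SBINN equations) shows that the default captures the initial guesses and does not follow later changes:

```python
class Variable:
    """Hypothetical stand-in for a trainable variable exposing `.value`."""

    def __init__(self, value):
        self.value = value


var_list_ = [Variable(0.5), Variable(2.0)]


def ODE(t, y, unknowns=[var.value for var in var_list_]):
    # Toy right-hand side for dy/dt = -k * y, with k = unknowns[0] (hypothetical).
    # The default list is built once, when `def` runs, from the initial values.
    k = unknowns[0]
    return -k * y


print(ODE(0.0, 1.0))         # -0.5: default unknowns captured at definition time
print(ODE(0.0, 1.0, [3.0]))  # -3.0: unknowns passed explicitly
var_list_[0].value = 10.0
print(ODE(0.0, 1.0))         # -0.5: the default does not track later updates
```

Because the default is a snapshot, calling the ODE without `unknowns` behaves like training with the parameters frozen at their initial values.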
The code modification seems necessary. However, there is another example, https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py, which works well (or at least worked earlier).
That one was already working because it has no pretraining without the external variables. The model is only ever compiled with the external trainable variables:
```python
model.compile(
    "adam", lr=0.001, external_trainable_variables=external_trainable_variables
)
losshistory, train_state = model.train(iterations=20000, callbacks=[variable])
```
The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default unknowns argument.
The code seems OK, but the underlying logic has become quite complicated.
In fact, you can simply pass `external_trainable_variables` in the first compile. Since the ODE loss weights are 0, those variables won't be updated anyway.
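As a toy illustration of why a zero loss weight leaves the unknowns untouched, assuming plain gradient descent and a single hypothetical unknown `c` that appears only in the weighted term (an optimizer with weight decay, or other loss terms that depend on `c`, could behave differently):

```python
def loss_grad(c, weight, target=2.0):
    # Weighted loss L(c) = weight * (c - target) ** 2, so dL/dc = 2 * weight * (c - target).
    return 2.0 * weight * (c - target)


def gradient_descent(c, weight, lr=0.1, steps=100):
    # Plain gradient descent on a single unknown c.
    for _ in range(steps):
        c -= lr * loss_grad(c, weight)
    return c


print(gradient_descent(0.5, weight=0.0))  # 0.5: zero weight, zero gradient, no update
print(gradient_descent(0.5, weight=1.0))  # converges close to 2.0
```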
That's true in this case, but it can be useful to start training the model with frozen parameters (see, for example, https://doi.org/10.1126/sciadv.abk0644).
Moreover, it would make the default `unknowns` values of the PDE useless and misleading, since they could never be used.
Please resolve the conflicts.
I've resolved the conflicts based on your improved logic. In the JAX backend conditional, I added handling for the case where there are no external trainable variables but default values are available.
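The conditional described here can be sketched in plain Python, assuming a hypothetical `evaluate_pde` helper (not DeepXDE's actual code) that uses `inspect.signature` to check whether the PDE's third argument has a default before falling back on it:

```python
import inspect


def evaluate_pde(pde, x, y, external_trainable_variables=None):
    """Hypothetical dispatch: pass `unknowns` only when they are being trained."""
    if external_trainable_variables:
        # Trainable case: feed the current values of the external variables.
        unknowns = [var.value for var in external_trainable_variables]
        return pde(x, y, unknowns)
    # No external variables were compiled: fall back on the PDE's own default.
    params = list(inspect.signature(pde).parameters.values())
    if len(params) > 2 and params[2].default is inspect.Parameter.empty:
        raise ValueError(
            "pde expects unknowns, but no external trainable variables "
            "were compiled and no default value is given"
        )
    return pde(x, y)


class Variable:
    """Hypothetical stand-in for a trainable variable exposing `.value`."""

    def __init__(self, value):
        self.value = value


def pde_with_default(x, y, unknowns=[0.5]):
    return y - unknowns[0] * x


def pde_without_default(x, y, unknowns):
    return y - unknowns[0] * x


print(evaluate_pde(pde_with_default, 2.0, 3.0))                   # 2.0: uses default 0.5
print(evaluate_pde(pde_with_default, 2.0, 3.0, [Variable(1.5)]))  # 0.0: uses trained value
```

Raising an explicit error when neither source of unknowns exists makes the misconfiguration visible instead of failing later with a less obvious `TypeError`.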