
fix external variable initialization

Open bonneted opened this issue 1 year ago • 8 comments

I've faced two bugs when trying to implement: https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8

  • in model.py: the external variables were not initialized on the second compile, because the net's parameters had already been initialized

  • in pde.py: if I don't compile with external variables, I still want the code to work with the default values of the unknowns. I think this code can safely be modified for JAX only, because the line after it was already JAX-only
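A minimal plain-Python sketch (hypothetical names, not DeepXDE internals) of the pattern the second bullet asks for: an ODE residual whose `unknowns` parameter has a default, so the same function works both when the trainer passes current variable values and when it passes nothing.

```python
# Hypothetical residual; the default unknowns are used when no external
# variables are supplied at compile time.
def ode_residual(t, y, unknowns=(0.1, 0.5)):
    a, b = unknowns          # falls back to the defaults above
    return a * y - b * t

# Without external variables: defaults apply (0.1*2.0 - 0.5*1.0 ≈ -0.3).
r_default = ode_residual(1.0, 2.0)

# With external trainable variables: the caller passes current values.
r_supplied = ode_residual(1.0, 2.0, unknowns=(0.2, 0.1))
```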

bonneted avatar Jun 17 '24 16:06 bonneted

Could you point out an example for using this code?

lululxvi avatar Jun 19 '24 15:06 lululxvi

Here : https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py

This is the implementation of sbinn using JAX. We first train the model without the external variables:

    def ODE(t, y, unknowns=[var.value for var in var_list_]):
    ...

    model.compile("adam", lr=1e-3, loss_weights=[0, 0, 0, 0, 0, 0, 1e-2])
    model.train(epochs=firsttrain, display_every=1000)
    model.compile(
        "adam",
        lr=1e-3,
        loss_weights=[1, 1, 1e-2, 1, 1, 1, 1e-2],
        external_trainable_variables=var_list_,
    )
    variablefilename = "variables.csv"
    variable = dde.callbacks.VariableValue(
        var_list_, period=callbackperiod, filename=variablefilename
    )
    losshistory, train_state = model.train(
        epochs=maxepochs, display_every=1000, callbacks=[variable]
    )

For this first training run, we want the ODE to use the default unknowns argument.
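As a side note on why `unknowns=[var.value for var in var_list_]` holds the *initial* values during that first run: Python evaluates default arguments once, at function definition. A self-contained sketch (with a stand-in `Var` class, not DeepXDE's variable type):

```python
# Stand-in for a trainable-variable wrapper with a .value attribute.
class Var:
    def __init__(self, value):
        self.value = value

var_list_ = [Var(1.0), Var(2.0)]

# The list comprehension runs once, at definition time, capturing 1.0 and 2.0.
def ODE(t, y, unknowns=[var.value for var in var_list_]):
    return unknowns

var_list_[0].value = 99.0   # later updates do NOT change the captured default
result = ODE(0.0, 0.0)      # still [1.0, 2.0]
```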

bonneted avatar Jun 20 '24 08:06 bonneted

The code modification seems necessary. But there is another example https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py , which works well (at least worked earlier).

lululxvi avatar Jun 20 '24 16:06 lululxvi

That one was already working because there is no pretraining without the external variables; the model is only ever compiled with the external trainable variables:

model.compile(
    "adam", lr=0.001, external_trainable_variables=external_trainable_variables
)
losshistory, train_state = model.train(iterations=20000, callbacks=[variable])

The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default unknowns argument.

bonneted avatar Jun 21 '24 08:06 bonneted

The code seems OK. But the underlying logic becomes extremely complicated now.

In fact, you can simply add external_trainable_variables in the first compile. As the PDE loss weight is 0, those variables won't get updated anyway.
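A tiny numeric sketch of why that workaround holds (hypothetical scalar example, not DeepXDE code): with a loss weight of 0, the gradient reaching the external variable is exactly 0, so a gradient step leaves it frozen.

```python
# Some loss that depends on the external variable.
def pde_loss(var):
    return (var - 3.0) ** 2

# Finite-difference gradient of weight * pde_loss at `var`.
def grad(var, weight, eps=1e-6):
    return weight * (pde_loss(var + eps) - pde_loss(var - eps)) / (2 * eps)

lr, var = 0.1, 1.0
var -= lr * grad(var, weight=0.0)   # weighted gradient is exactly 0
# var is unchanged (1.0): listed as trainable, but effectively frozen.
```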

lululxvi avatar Jun 21 '24 15:06 lululxvi

That's true in that case, but it can be interesting to start training the model with frozen parameters (for example https://doi.org/10.1126/sciadv.abk0644). Moreover, it would mean that giving the unknowns default values in the PDE is useless and misleading, as those defaults could never be used.

bonneted avatar Jul 02 '24 09:07 bonneted

Please resolve the conflicts.

lululxvi avatar Jul 03 '24 01:07 lululxvi

I've resolved the conflict based on your improved logic. In the JAX backend conditional, I added handling for the case where there are no external trainable variables but a default value is available.
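A hedged sketch of the branch described above (hypothetical names, not the actual pde.py code): prefer externally supplied trainable values when present, otherwise call the PDE function without them so its own default `unknowns` argument applies.

```python
# Hypothetical dispatcher: use external values if given, else the defaults.
def evaluate_pde(pde_fn, t, y, external_values=None):
    if external_values is not None:
        return pde_fn(t, y, unknowns=external_values)
    return pde_fn(t, y)   # rely on the default unknowns argument

def pde_fn(t, y, unknowns=(0.5,)):
    (a,) = unknowns
    return a * y - t

no_external = evaluate_pde(pde_fn, 1.0, 2.0)                          # 0.5*2-1 = 0.0
with_external = evaluate_pde(pde_fn, 1.0, 2.0, external_values=(2.0,))  # 2*2-1 = 3.0
```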

bonneted avatar Jul 08 '24 11:07 bonneted