
Towards SoftAdapt loss balancing for tf.compat.v1

Open pescap opened this issue 1 year ago • 9 comments

Work in progress!

pescap avatar Dec 06 '23 12:12 pescap

Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.
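
For illustration, here is a minimal plain-PyTorch sketch (not DeepXDE code) of what directly reweighting the loss terms at every step could look like, with SoftAdapt-style weights computed from the recent change of each loss; the toy model and losses are made up:

```python
import torch

# Toy two-term loss; the point is only that the loss weights are ordinary
# Python/torch values that can be changed at every step.
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.linspace(0, 1, 32).reshape(-1, 1)
y = 2 * x

weights = [0.5, 0.5]   # current loss weights
prev_losses = None     # loss values from the previous step
beta = 0.1             # SoftAdapt temperature

for step in range(1000):
    optimizer.zero_grad()
    pred = model(x)
    losses = [torch.mean((pred - y) ** 2), torch.mean(pred ** 2)]

    # SoftAdapt-style update: softmax over the recent change of each loss.
    if prev_losses is not None:
        diffs = torch.tensor([l.item() - p for l, p in zip(losses, prev_losses)])
        weights = torch.softmax(beta * diffs, dim=0).tolist()
    prev_losses = [l.item() for l in losses]

    total = sum(w * l for w, l in zip(weights, losses))
    total.backward()
    optimizer.step()
```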

lululxvi avatar Dec 06 '23 14:12 lululxvi

Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.

Thank you for your feedback. I would really prefer to implement this adaptive loss callback in tensorflow.compat.v1.

I think I'll start with a simple two-term loss (and a single weighting parameter).
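
As a starting point, a hedged sketch of that single weighting parameter for a two-term loss (the function name and loss histories below are made up for illustration):

```python
import numpy as np

def softadapt_alpha(loss1_history, loss2_history, beta=0.1):
    """Single weighting parameter alpha for a two-term loss.

    total_loss = alpha * L1 + (1 - alpha) * L2, where alpha is the SoftAdapt
    softmax weight of the first term, computed from the rate of change of
    each loss over the last two recorded values.
    """
    s1 = loss1_history[-1] - loss1_history[-2]
    s2 = loss2_history[-1] - loss2_history[-2]
    e1, e2 = np.exp(beta * s1), np.exp(beta * s2)
    return e1 / (e1 + e2)

# Example: the first loss has stalled while the second is still dropping,
# so more weight goes to the stalled term.
print(softadapt_alpha([1.0, 0.99], [1.0, 0.5], beta=1.0))  # ~0.62
```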

pescap avatar Dec 06 '23 15:12 pescap

Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.

  • It can be done if loss_weights is an argument of train_step() (see the placeholder sketch after this list).
  • If not, every time we change the loss weights iteratively, TensorFlow has to rebuild the graph. In other words, the model will .compile() again and training may become slow.
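
A minimal tf.compat.v1 sketch of the first option, where the weights enter the graph through a placeholder fed at every train step, so the graph is built only once (toy model and losses, not DeepXDE code):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build the graph once; the loss weights enter through a placeholder, so
# changing them at every train step needs no recompilation.
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.placeholder(tf.float32, [2], name="loss_weights")

W = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
pred = tf.matmul(x, W) + b

loss_terms = [tf.reduce_mean((pred - y) ** 2), tf.reduce_mean(pred ** 2)]
total_loss = w[0] * loss_terms[0] + w[1] * loss_terms[1]
train_step = tf.train.AdamOptimizer(1e-3).minimize(total_loss)

x_np = np.linspace(0, 1, 32, dtype="float32").reshape(-1, 1)
y_np = 2 * x_np
weights = np.array([0.5, 0.5], dtype="float32")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        _, losses = sess.run([train_step, loss_terms],
                             feed_dict={x: x_np, y: y_np, w: weights})
        # ... update `weights` from the loss history (e.g. SoftAdapt) ...
```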

haison19952013 avatar Feb 01 '24 05:02 haison19952013

Hi, if we define loss_weights as a tf.Variable, there is no need to compile several times, right?

Next, we have to define the total_loss appropriately.
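
For illustration, a minimal tf.compat.v1 sketch of that idea, with stand-in loss terms rather than DeepXDE's actual ones:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# The loss weights are a non-trainable Variable inside the graph, so a
# callback can update them through an assign op without recompiling.
loss_weights = tf.Variable([0.5, 0.5], trainable=False, dtype=tf.float32)
new_weights = tf.placeholder(tf.float32, shape=[2])
update_weights = tf.assign(loss_weights, new_weights)

# Stand-ins for the individual loss terms (e.g. PDE residual and BC loss).
loss_terms = tf.stack([tf.constant(1.0), tf.constant(2.0)])
total_loss = tf.reduce_sum(loss_weights * loss_terms)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(total_loss))  # 0.5*1 + 0.5*2 = 1.5
    sess.run(update_weights, feed_dict={new_weights: [0.9, 0.1]})
    print(sess.run(total_loss))  # 0.9*1 + 0.1*2 = 1.1
```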

pescap avatar Feb 01 '24 15:02 pescap

Hi, if we define loss_weights as a tf.Variable, there is no need to compile several times, right?

Next, we have to define the [total_loss](https://github.com/lululxvi/deepxde/blob/85920299331bd7c0bad01f3d2abba442a77c89c6/deepxde/model.py#L244) appropriately.

  • My [last response](https://github.com/lululxvi/deepxde/pull/1586#issuecomment-1920529102) might have caused some confusion, so I have corrected it.
  • You are right, we can update it like that without recompiling. Still, it's better to verify this on a toy problem.
  • For the implementation in deepxde, we might need to pass loss_weights as an argument of outputs_losses() and then of total_loss() (a hypothetical sketch follows below).
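
A hypothetical sketch of what passing the weights through those functions could look like; the real outputs_losses()/total_loss() in deepxde/model.py differ:

```python
import tensorflow.compat.v1 as tf

# Hypothetical signatures sketching the proposal above, not the current
# deepxde code.
def total_loss(losses, loss_weights):
    # Weighted sum with a weight *tensor* (Variable or placeholder), so
    # updating the weights does not require rebuilding the graph.
    return tf.reduce_sum(loss_weights * tf.stack(losses))

def outputs_losses(losses, loss_weights):
    # Here outputs_losses simply forwards the weights down to total_loss
    # instead of baking fixed Python floats into the graph.
    return total_loss(losses, loss_weights)
```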

haison19952013 avatar Feb 02 '24 00:02 haison19952013