Towards SoftAdapt loss balancing for tf.compat.v1
Work in progress!
Implementing in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the `loss_weights` value.
> Implementing in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the `loss_weights` value.
Thank you for your feedback. I would really prefer to implement this adaptive loss callback in `tensorflow.compat.v1`.

I think I'll start with a simple two-term loss (and one weighting parameter); see the sketch below.
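As a starting point, here is a minimal numpy sketch of the SoftAdapt weighting rule for a two-term loss, assuming the usual softmax-over-loss-differences formulation; the function name and the `beta` value are illustrative, not deepxde API:

```python
import numpy as np

def softadapt_weights(prev_losses, curr_losses, beta=0.1):
    """SoftAdapt: weight each loss term by a softmax over its recent
    rate of change, so terms that stagnate or grow get more weight."""
    s = np.asarray(curr_losses) - np.asarray(prev_losses)
    e = np.exp(beta * (s - s.max()))  # subtract max for numerical stability
    return e / e.sum()

# Two-term example: the first loss is still decreasing, the second
# stagnates, so the stagnating term receives the larger weight.
print(softadapt_weights([1.0, 1.0], [0.8, 1.0], beta=5.0))
# -> [0.269 0.731] (approximately)
```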
> Implementing in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the `loss_weights` value.
- It can be done if `loss_weights` is an argument of `train_step()`; see the placeholder sketch after this list.
- If not, when we iteratively change the loss weights, TensorFlow will need to rebuild the graph from scratch. In other words, the model will `.compile()` again and training might be slow.
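To illustrate the first option, here is a toy `tf.compat.v1` sketch (not deepxde's actual training loop) where the weights are fed through a placeholder, so the graph is built once and only the fed values change between iterations:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy two-term loss; loss_1 and loss_2 stand in for e.g. the PDE
# residual and boundary-condition losses of a real model.
x = tf.placeholder(tf.float32, shape=(None, 1))
y = tf.layers.dense(x, 1)
loss_1 = tf.reduce_mean(tf.square(y))
loss_2 = tf.reduce_mean(tf.square(y - 1.0))

# Feeding the weights through a placeholder keeps the graph static:
# train_step is built once; only the fed values change per iteration.
w = tf.placeholder(tf.float32, shape=(2,))
total_loss = w[0] * loss_1 + w[1] * loss_2
train_step = tf.train.AdamOptimizer(1e-3).minimize(total_loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data = np.random.rand(16, 1).astype(np.float32)
    for step in range(100):
        weights = [1.0, 1.0]  # would come from SoftAdapt in practice
        sess.run(train_step, feed_dict={x: data, w: weights})
```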
Hi, if we define `loss_weights` as a `Variable`, there is no need to compile several times, right? Next, we have to define the [`total_loss`](https://github.com/lululxvi/deepxde/blob/85920299331bd7c0bad01f3d2abba442a77c89c6/deepxde/model.py#L244) appropriately.
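A minimal sketch of that `Variable` idea, with toy loss terms standing in for the model's real losses: the graph references the `Variable`, so assigning new values does not trigger a rebuild.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Loss weights as a non-trainable Variable: the graph references the
# Variable, so updating it does not require recompiling.
loss_weights = tf.Variable([1.0, 1.0], trainable=False, dtype=tf.float32)

# Toy loss terms in place of the model's real losses.
v = tf.Variable(2.0)
losses = tf.stack([tf.square(v), tf.square(v - 1.0)])
total_loss = tf.reduce_sum(loss_weights * losses)
train_step = tf.train.AdamOptimizer(1e-2).minimize(total_loss)

# One assign op, built once, through which new weights are pushed.
new_weights = tf.placeholder(tf.float32, shape=(2,))
update_weights = tf.assign(loss_weights, new_weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(10):
        sess.run(train_step)
        # In practice the new values would come from the SoftAdapt rule.
        sess.run(update_weights, feed_dict={new_weights: [0.3, 0.7]})
```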
> Hi, if we define `loss_weights` as a `Variable`, there is no need to compile several times, right? Next, we have to define the [`total_loss`](https://github.com/lululxvi/deepxde/blob/85920299331bd7c0bad01f3d2abba442a77c89c6/deepxde/model.py#L244) appropriately.
- My [last response](https://github.com/lululxvi/deepxde/pull/1586#issuecomment-1920529102) might raise some confusion, so I have corrected it.
- You are right, we can update it like that without recompiling. But it's better to verify this again on a toy problem.
- For more details on the implementation in `deepxde`, we might need to pass `loss_weights` as an argument of `outputs_losses()` and then `total_loss()`; a hypothetical callback sketch follows.