Support `add_loss` (currently works for torch and tf, but NOT yet for jax)
This PR seeks to address https://github.com/bayesflow-org/bayesflow/issues/541.
It looks to me like we need to tweak `JAXApproximator.stateless_compute_metrics` for this to work in jax as well.
The other backends are already covered by the changes in the initial commit alone.
EDIT: You can find an example in https://github.com/bayesflow-org/bayesflow/blob/add-loss/examples/Custom_losses_with_add_loss.ipynb
Codecov Report
Attention: Patch coverage is 75.00000% with 3 lines in your changes missing coverage. Please review.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...low/approximators/model_comparison_approximator.py | 50.00% | 3 Missing :warning: |
| Files with missing lines | Coverage Δ | |
|---|---|---|
| bayesflow/approximators/continuous_approximator.py | 91.45% <100.00%> (+0.22%) | :arrow_up: |
| ...low/approximators/model_comparison_approximator.py | 83.90% <50.00%> (-1.30%) | :arrow_down: |
I added tests and a minimal example notebook.
Tests are passing on torch and tensorflow, but fail on jax.
@LarsKue, since you are the architect of `stateless_compute_metrics`, could you look into how we can make this work for jax?
The final section of the Keras guide on writing a custom training loop in JAX suggests that this can be rather straightforward, but I am unsure how to apply it in our case: https://keras.io/guides/writing_a_custom_training_loop_in_jax/