DESC
Add `ShareParameters` Linear Objective
If someone wants to help with any of the items below, feel free; I just got motivated for the initial implementation.
- Adds the `ShareParameters` linear objective, which takes in a list of objects (of arbitrary length) and a dict of `params` to fix. The objects
  - [x] must all be the same type (implemented check)
  - [x] must have the same resolution (i.e. equal `thing.dimensions`). This might be too strict a check, as technically only the shared params need to have the same resolution, but I am OK with this since there might be unintended behavior if the other resolutions are unequal. If there is a good use case for the latter, I can think harder about it.
The objective then fixes the selected indices of each parameter of every object to the same values as on every other object.
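The constraint described above is linear in the stacked parameters: it drives the difference of each selected entry to zero. A minimal sketch with plain NumPy arrays, assuming each object is represented by a dict of 1D parameter arrays and that every object is compared against the first one (comparing consecutive pairs would constrain the same set). `share_parameters_residual` and the toy dicts are hypothetical, not the PR's implementation; `True` means "share every index" and a list means "share only these indices", mirroring the `params` dict usage later in this thread.

```python
import numpy as np


def share_parameters_residual(param_dicts, params):
    """Stack the differences that a 'share parameters' constraint drives to zero.

    param_dicts: one dict per object, mapping parameter name -> 1D array
        (assumed to have identical sizes across objects).
    params: parameter name -> True (share every index) or a list of indices.

    Every object k > 0 is compared against object 0, so N objects give
    (N - 1) * (number of selected entries) rows. Each row is linear in the
    stacked parameters, which is what makes this a linear objective.
    """
    ref = param_dicts[0]
    rows = []
    for other in param_dicts[1:]:
        for name, idx in params.items():
            a = np.atleast_1d(np.asarray(ref[name], dtype=float))
            b = np.atleast_1d(np.asarray(other[name], dtype=float))
            sel = np.arange(a.size) if idx is True else np.asarray(idx, dtype=int)
            rows.append(b[sel] - a[sel])
    return np.concatenate(rows) if rows else np.zeros(0)


# Two hypothetical equilibria: R_lmn differs only at index 1, p_l is identical.
eq1 = {"R_lmn": np.array([1.0, 0.1, 0.02]), "p_l": np.array([1e3, -1e3])}
eq2 = {"R_lmn": np.array([1.0, 0.3, 0.02]), "p_l": np.array([1e3, -1e3])}

res = share_parameters_residual([eq1, eq2], {"R_lmn": [0, 2], "p_l": True})
print(res)  # four zeros: the selected entries already agree
```

Sharing index 1 of `R_lmn` as well would add a nonzero row (about 0.2), which an optimizer would then drive to zero.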
TODO:
- [x] make sure the implementation works for pytree objects such as `CoilSet` (tests added)
- [x] add a check that `params` lengths are the same across objects (technically, we should also be checking that, say, if `p_l` is being fixed between two equilibria, they also have the same TYPE of pressure profile, but I'm not sure how to do that without very individualized logic like `_check_type` does in `coils.py`)
- [ ] add a test for fixing only specific indices of a param
- [ ] add more tests for coils (to test pytree objects; I am not sure of all the use cases right now, so I can't think of any good tests)
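The checks listed above (same type, same resolution via `thing.dimensions`, matching `params` lengths) can be sketched as a single validation pass. `ToyThing` and `check_shareable` are hypothetical names rather than the PR's code, and `params_dict` is assumed here to map parameter names to arrays. As the TODO notes, this sketch does not check that, e.g., two pressure profiles are the same profile class.

```python
import numpy as np


class ToyThing:
    """Hypothetical stand-in: a resolution tuple plus a dict of parameters."""

    def __init__(self, dimensions, params):
        self.dimensions = tuple(dimensions)
        self.params_dict = {
            k: np.atleast_1d(np.asarray(v, dtype=float)) for k, v in params.items()
        }


def check_shareable(things, params):
    """Raise ValueError if `things` cannot share the parameters in `params`."""
    first = things[0]
    for thing in things[1:]:
        if type(thing) is not type(first):
            raise ValueError("all objects must be the same type")
        if thing.dimensions != first.dimensions:
            raise ValueError("all objects must have the same resolution")
        for name in params:
            if thing.params_dict[name].size != first.params_dict[name].size:
                raise ValueError(f"parameter {name!r} has different lengths")


a = ToyThing((2, 2, 0), {"p_l": [1.0, -1.0]})
b = ToyThing((2, 2, 0), {"p_l": [2.0, -2.0]})
check_shareable([a, b], {"p_l": True})  # passes: same type, resolution, size

c = ToyThing((4, 4, 0), {"p_l": [1.0, -1.0]})
try:
    check_shareable([a, c], {"p_l": True})
except ValueError as err:
    print(err)  # all objects must have the same resolution
```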
Resolves #1250
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | -0.54 +/- 6.69 | -3.15e-03 +/- 3.92e-02 | 5.84e-01 +/- 9.4e-03 | 5.87e-01 +/- 3.8e-02 |
| test_equilibrium_init_medres | +3.00 +/- 3.79 | +1.36e-01 +/- 1.72e-01 | 4.67e+00 +/- 1.2e-01 | 4.53e+00 +/- 1.2e-01 |
| test_equilibrium_init_highres | +1.81 +/- 2.77 | +9.32e-02 +/- 1.43e-01 | 5.25e+00 +/- 1.3e-01 | 5.16e+00 +/- 6.8e-02 |
| test_objective_compile_dshape_current | +0.89 +/- 2.17 | +3.17e-02 +/- 7.74e-02 | 3.60e+00 +/- 4.9e-02 | 3.57e+00 +/- 6.0e-02 |
| test_objective_compute_dshape_current | -0.38 +/- 2.55 | -1.31e-05 +/- 8.75e-05 | 3.41e-03 +/- 5.7e-05 | 3.43e-03 +/- 6.6e-05 |
| test_objective_jac_dshape_current | +0.54 +/- 11.63 | +1.66e-04 +/- 3.57e-03 | 3.09e-02 +/- 2.4e-03 | 3.07e-02 +/- 2.7e-03 |
| test_perturb_2 | -0.26 +/- 1.74 | -4.75e-02 +/- 3.12e-01 | 1.79e+01 +/- 2.0e-01 | 1.80e+01 +/- 2.4e-01 |
| test_proximal_jac_atf_with_eq_update | +0.67 +/- 0.70 | +1.02e-01 +/- 1.06e-01 | 1.53e+01 +/- 5.8e-02 | 1.52e+01 +/- 8.9e-02 |
| test_proximal_freeb_jac | -1.60 +/- 10.26 | -8.10e-02 +/- 5.19e-01 | 4.98e+00 +/- 3.6e-01 | 5.06e+00 +/- 3.7e-01 |
| test_solve_fixed_iter_compiled | -1.18 +/- 2.42 | -2.13e-01 +/- 4.38e-01 | 1.79e+01 +/- 3.1e-01 | 1.81e+01 +/- 3.1e-01 |
| test_LinearConstraintProjection_build | -1.20 +/- 1.85 | -1.05e-01 +/- 1.61e-01 | 8.61e+00 +/- 1.1e-01 | 8.71e+00 +/- 1.2e-01 |
| test_objective_compute_ripple_spline | +1.24 +/- 7.81 | +3.93e-03 +/- 2.47e-02 | 3.20e-01 +/- 2.1e-02 | 3.16e-01 +/- 1.3e-02 |
| test_objective_grad_ripple_spline | +0.13 +/- 3.60 | +1.63e-03 +/- 4.48e-02 | 1.24e+00 +/- 2.2e-02 | 1.24e+00 +/- 3.9e-02 |
| test_build_transform_fft_midres | +3.17 +/- 3.57 | +2.30e-02 +/- 2.59e-02 | 7.48e-01 +/- 2.3e-02 | 7.25e-01 +/- 1.3e-02 |
| test_build_transform_fft_highres | +1.56 +/- 2.82 | +1.57e-02 +/- 2.82e-02 | 1.02e+00 +/- 2.3e-02 | 1.00e+00 +/- 1.6e-02 |
| test_equilibrium_init_lowres | -0.77 +/- 3.87 | -3.40e-02 +/- 1.70e-01 | 4.36e+00 +/- 1.4e-01 | 4.40e+00 +/- 9.7e-02 |
| test_objective_compile_atf | -1.23 +/- 1.88 | -7.98e-02 +/- 1.22e-01 | 6.42e+00 +/- 9.8e-02 | 6.50e+00 +/- 7.3e-02 |
| test_objective_compute_atf | -3.03 +/- 13.55 | -2.63e-04 +/- 1.18e-03 | 8.43e-03 +/- 5.2e-04 | 8.69e-03 +/- 1.1e-03 |
| test_objective_jac_atf | -0.07 +/- 2.93 | -1.05e-03 +/- 4.61e-02 | 1.57e+00 +/- 3.4e-02 | 1.57e+00 +/- 3.1e-02 |
| test_perturb_1 | +0.12 +/- 1.72 | +1.72e-02 +/- 2.54e-01 | 1.48e+01 +/- 2.1e-01 | 1.48e+01 +/- 1.4e-01 |
| test_proximal_jac_atf | +1.12 +/- 1.34 | +8.56e-02 +/- 1.02e-01 | 7.72e+00 +/- 7.1e-02 | 7.64e+00 +/- 7.3e-02 |
| test_proximal_freeb_compute | +1.39 +/- 4.15 | +2.49e-03 +/- 7.42e-03 | 1.81e-01 +/- 5.5e-03 | 1.79e-01 +/- 5.0e-03 |
| test_solve_fixed_iter | -1.08 +/- 1.68 | -3.32e-01 +/- 5.18e-01 | 3.05e+01 +/- 3.9e-01 | 3.09e+01 +/- 3.4e-01 |
| test_objective_compute_ripple | -0.30 +/- 0.80 | -8.16e-03 +/- 2.19e-02 | 2.75e+00 +/- 1.4e-02 | 2.76e+00 +/- 1.7e-02 |
| test_objective_grad_ripple | +1.74 +/- 1.42 | +8.68e-02 +/- 7.09e-02 | 5.08e+00 +/- 5.6e-02 | 4.99e+00 +/- 4.3e-02 |
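The timing columns in the table above appear to satisfy dt(s) = t_new(s) - t_old(s) and dt(%) = 100 * dt(s) / t_old(s). This relation is inferred from the numbers themselves, not from documentation of the benchmark script; a quick check on three rows copied from the table:

```python
# (dt_s, t_old_s, dt_pct) copied from three rows of the table above.
rows = {
    "test_build_transform_fft_lowres": (-3.15e-03, 5.87e-01, -0.54),
    "test_equilibrium_init_medres": (1.36e-01, 4.53e00, 3.00),
    "test_objective_grad_ripple": (8.68e-02, 4.99e00, 1.74),
}
# Inferred relation (not documented): dt(%) = 100 * dt(s) / t_old(s).
recomputed = {name: 100 * dt_s / t_old for name, (dt_s, t_old, _) in rows.items()}
for name, (_, _, dt_pct) in rows.items():
    print(f"{name}: recomputed {recomputed[name]:+.2f}%, table {dt_pct:+.2f}%")
```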
Codecov Report
:white_check_mark: All modified and coverable lines are covered by tests.
:white_check_mark: Project coverage is 95.78%. Comparing base (faba7ec) to head (2ab0a87).
:warning: Report is 1 commit behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #1320 +/- ##
==========================================
+ Coverage 95.77% 95.78% +0.01%
==========================================
Files 101 101
Lines 27757 27794 +37
==========================================
+ Hits 26583 26622 +39
+ Misses 1174 1172 -2
| Files with missing lines | Coverage Δ | |
|---|---|---|
| desc/backend.py | 89.69% <ø> (ø) | |
| desc/objectives/__init__.py | 100.00% <ø> (ø) | |
| desc/objectives/linear_objectives.py | 97.27% <100.00%> (+0.14%) | :arrow_up: |
| desc/objectives/objective_funs.py | 94.67% <ø> (ø) | |
This would help a lot for multi-GPU work. When I put different `ForceBalance` objectives on different GPUs, the optimizer thinks there are multiple different equilibria, because `jax.device_put()` creates a copy of the equilibrium: although the copies are identical at that moment, their memory addresses differ. I tried to prevent this by overriding `obj.things[0] = eq_replicated_to_all_devices`, but JAX complains that the inputs of the jitted function should live on the same device (since we close over `self`, different portions of the objective live on different devices).
eq = jax.device_put(eq, desc_config["sharding_replicated"])
obj = ForceBalance(eq, grid=grid)
obj = jax.device_put(obj, jax.devices("gpu")[i])
# if the eq is also distributed across GPUs, then some internal logic that
# checks if the things are different will fail, so we need to set the eq
# to be the same manually
obj._things[0] = eq
obj.build(use_jit=use_jit)
objs += (obj,)
With this PR, we could instead use separate equilibria and just link their parameters.
@dpanici
Memory benchmark result
| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 4.08 % | 3.909e+03 | 4.068e+03 | 159.60 | 37.96 | 36.15 |
| test_proximal_jac_w7x_with_eq_update | -3.30 % | 6.576e+03 | 6.359e+03 | -217.15 | 163.10 | 162.09 |
| test_proximal_freeb_jac | -0.22 % | 1.320e+04 | 1.317e+04 | -29.46 | 84.33 | 83.73 |
| test_proximal_freeb_jac_blocked | -0.29 % | 7.485e+03 | 7.463e+03 | -22.08 | 72.78 | 74.31 |
| test_proximal_freeb_jac_batched | 0.80 % | 7.419e+03 | 7.479e+03 | 59.20 | 73.60 | 74.72 |
| test_proximal_jac_ripple | -0.64 % | 3.465e+03 | 3.443e+03 | -22.04 | 65.09 | 66.61 |
| test_proximal_jac_ripple_bounce1d | 0.33 % | 3.576e+03 | 3.588e+03 | 11.95 | 75.44 | 78.27 |
| test_eq_solve | -4.75 % | 2.071e+03 | 1.972e+03 | -98.39 | 93.87 | 93.98 |
For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.
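In the memory table, Δ (MB) appears to be PR minus Master and %Δ appears to be 100 * Δ / Master. This is inferred from the values rather than documented; recomputing three rows from the Δ and Master columns:

```python
# (delta_mb, master_mb, pct_delta) copied from three rows of the memory table.
rows = {
    "test_objective_jac_w7x": (159.60, 3.909e03, 4.08),
    "test_proximal_jac_w7x_with_eq_update": (-217.15, 6.576e03, -3.30),
    "test_eq_solve": (-98.39, 2.071e03, -4.75),
}
# Inferred relation (not documented): %delta = 100 * delta / master.
recomputed = {name: 100 * d / m for name, (d, m, _) in rows.items()}
for name, (_, _, pct) in rows.items():
    print(f"{name}: recomputed {recomputed[name]:+.2f}%, table {pct:+.2f}%")
```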
Getting a weird static-attr-related error from one of the tests that uses this with an optimizable collection:
import numpy as np
from desc.coils import CoilSet, FourierXYZCoil
from desc.objectives import (
ObjectiveFunction,
ShareParameters,
)
coils1 = CoilSet.linspaced_angular(FourierXYZCoil(), n=2)
coils2 = coils1.copy()
subobj = ShareParameters([coils1, coils2], {"X_n": True, "Y_n": True, "Z_n": [2]})
subobj.build()
obj = ObjectiveFunction(subobj)
obj.build()
# check dimensions
# dim_f should be 2 (for the 2 subcoils in each coilset) x 2 (for X_n, Y_n)
# x params["X_n"].size + 2 x 1 (Z_n) x 1 (bc only fixed idx=2)
assert subobj.dim_f == 2 * 2 * coils1.params_dict[0]["X_n"].size + 2
np.testing.assert_allclose(subobj.target, 0)
# check compute
np.testing.assert_allclose(obj.compute_unscaled(obj.x(coils1, coils2)), 0) # errors with message below
I am running this in Jupyter, and if I run the cell above a second time, it works with no error... I'm not sure at all what is happening. This objective has the same static attrs as `FixParameters`, and the problem seems to be within the static attr list itself, which makes no sense to me.
error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 16
13 np.testing.assert_allclose(subobj.target, 0)
15 # check compute
---> 16 np.testing.assert_allclose(obj.compute_unscaled(obj.x(coils1, coils2)), 0)
18 # check the jacobian
19 J = obj.jac_unscaled(obj.x(coils1, coils2))
[... skipping hidden 5 frame]
File ~/miniconda3/envs/descenv/lib/python3.13/site-packages/jax/_src/pjit.py:731, in _infer_input_type(fun, dbg, explicit_args)
729 except TypeError:
730 arg_description = f"path {dbg.arg_names[i]}" # pytype: disable=name-error
--> 731 raise TypeError(
732 f"Error interpreting argument to {fun} as an abstract array."
733 f" The problematic value is of type {type(x)} and was passed to" # pytype: disable=name-error
734 f" the function at {arg_description}.\n"
735 "This typically means that a jit-wrapped function was called with a non-array"
736 " argument, and this argument was not marked as static using the"
737 " static_argnums or static_argnames parameters of jax.jit."
738 ) from None
739 if config.mutable_array_checks.value:
740 _check_no_aliased_ref_args(dbg, avals, explicit_args)
TypeError: Error interpreting argument to <function ObjectiveFunction.compute_unscaled at 0x12d9637e0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path self[0]['_objectives'][0][0]['_static_attrs'][0].
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
Interestingly, if I don't build `subobj` before wrapping it in the `ObjectiveFunction`, it works without error. I'm really unsure why this would be the case; is setting `_built` on the sub-objective somehow causing issues?
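The `dim_f` assertion in the reproducer above counts one row per shared entry per subcoil, with one block of rows tying the second coil set to the first. A small counting helper under that reading; `share_parameters_dim_f` is hypothetical, the per-array size of 3 is an illustrative assumption rather than the actual default `FourierXYZCoil` size, and the `(n_objects - 1)` factor for more than two objects is an assumption about how rows are counted.

```python
def share_parameters_dim_f(n_objects, sizes, selection, n_subobjects=1):
    """Count constraint rows for sharing params across n_objects objects.

    Hypothetical helper. sizes maps a parameter name to its length in one
    (sub)object; selection maps it to True (all indices) or a list of indices.
    Objects 2..N are each tied to the first, giving (n_objects - 1) blocks.
    """
    per_subobject = sum(
        sizes[name] if sel is True else len(sel) for name, sel in selection.items()
    )
    return (n_objects - 1) * n_subobjects * per_subobject


# The reproducer: 2 coil sets of 2 coils each, sharing all of X_n and Y_n and
# index 2 of Z_n. Assume (for illustration only) each array has 3 entries.
sizes = {"X_n": 3, "Y_n": 3, "Z_n": 3}
selection = {"X_n": True, "Y_n": True, "Z_n": [2]}
dim_f = share_parameters_dim_f(2, sizes, selection, n_subobjects=2)
print(dim_f)  # 14, i.e. 2 * 2 * 3 + 2
```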
@dpanici I had a similar error in the multi-device PR, and I ended up adding `"_static_attrs"` to the `_static_attrs` list for both `ObjectiveFunction` and `_Objective`. I also couldn't find the exact reason, but that seemed to solve the problem.
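A toy model of the error path `self[0]['_objectives'][0][0]['_static_attrs'][0]`: if an object is flattened by putting every instance attribute not named in `_static_attrs` into the dynamic leaves, then a `_static_attrs` list stored on the instance turns into string leaves, which `jit` cannot interpret as abstract arrays. Listing `"_static_attrs"` in itself removes those leaves. `ToyObjective`, `flatten`, and `non_array_leaves` are hypothetical and written in plain Python rather than with DESC's or JAX's flattening; the idea that building stores `_static_attrs` on the instance is an assumption, offered as one possible reason why skipping `build()` avoided the error.

```python
import numpy as np


class ToyObjective:
    """Hypothetical stand-in for an objective that is flattened by attribute."""

    _static_attrs = ["_name"]  # class attribute: not part of vars(instance)

    def __init__(self):
        self._name = "share"
        self.target = np.zeros(3)


def flatten(obj):
    """Dynamic leaves: every instance attribute not listed as static."""
    static = set(obj._static_attrs)
    return [value for key, value in vars(obj).items() if key not in static]


def non_array_leaves(leaves):
    """Recurse into lists/tuples (as pytree flattening would); collect str leaves."""
    bad = []
    for leaf in leaves:
        if isinstance(leaf, (list, tuple)):
            bad.extend(non_array_leaves(leaf))
        elif isinstance(leaf, str):
            bad.append(leaf)
    return bad


obj = ToyObjective()
print(non_array_leaves(flatten(obj)))  # []: _static_attrs lives on the class

# Assumed build step: extending the list stores a new list on the instance.
obj._static_attrs = obj._static_attrs + ["_built"]
obj._built = True
print(non_array_leaves(flatten(obj)))  # ['_name', '_built']: strings as leaves

# The fix from the comment above: mark _static_attrs itself as static.
obj._static_attrs = obj._static_attrs + ["_static_attrs"]
print(non_array_leaves(flatten(obj)))  # []
```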
Another issue, whose cause I now know but whose resolution I'm unsure of, though it is an edge case: for attributes like `normal`, where the vector is normalized when the attribute is set, `ShareParameters([coil1, coil2], {"normal": True})` works perfectly. If I instead do `ShareParameters([coil1, coil2], {"normal": np.array([1])})` to make only the index-1 component of the normal the same, the constraint itself works correctly, but when the attributes are set at the end of the optimization they are normalized in the setter. Since the 0th and 2nd components may differ (they were not shared), the resulting normalized vectors may have different index-1 components.
This is sort of a useless edge case, but I'm still unsure how to fully resolve a case like this.
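A small numeric illustration of the edge case above, using two hypothetical normal vectors that agree only in their index-1 component. Normalizing each vector (as the setter described above does) divides by different norms, so the shared component no longer matches.

```python
import numpy as np

# Two hypothetical coil normals that share only their index-1 component,
# as with ShareParameters(..., {"normal": np.array([1])}).
n1 = np.array([1.0, 2.0, 0.0])
n2 = np.array([3.0, 2.0, 0.0])
assert n1[1] == n2[1]  # the constraint is satisfied during optimization

# A normalizing setter rescales the whole vector, so the shared component
# ends up divided by different norms: 2 / sqrt(5) vs 2 / sqrt(13).
u1 = n1 / np.linalg.norm(n1)
u2 = n2 / np.linalg.norm(n2)
print(u1[1], u2[1])  # approx 0.894 vs 0.555
```

Sharing the full vector (`{"normal": True}`) avoids this, since identical vectors normalize identically.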
Also, don't forget to add the new objective to the docs.
@f0uriest if you have time to review, we can also try to get this in before the new pip release.