Add `ShareParameters` Linear Objective

Open dpanici opened this issue 1 year ago • 3 comments

If anyone wants to help with any of the items below, feel free; I just got motivated to do the initial implementation.

  • Adds the ShareParameters linear objective, which takes a list (of arbitrary length) of objects and a dict of params to fix. The objects:
    • [x] must all be the same type (check implemented)
    • [x] must have the same resolutions (i.e. equal thing.dimensions). This check might be too strict, since technically only the shared params need the same resolution, but I think I am OK with it, as there might be unintended behavior if other resolutions are unequal. If there is a good use case for the latter, I can think harder about it.

The objective then constrains the selected indices of each object's parameters to have the same values as in every other object.
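One way to picture this constraint is as a linear system A x = 0 on the stacked parameter vectors, with one row per shared index per additional object. The sketch below is only an illustration and not the PR's implementation: `share_rows` is a hypothetical helper, and it assumes each object's parameter is a flat array of the same length.

```python
import numpy as np


def share_rows(n_objects, size, indices=True):
    """Build A such that A @ concatenate(params) == 0 ties the chosen indices.

    Hypothetical helper for illustration only; not part of DESC.
    `indices=True` shares every index, otherwise pass a list of indices.
    """
    idx = np.arange(size) if indices is True else np.atleast_1d(indices)
    rows = []
    # Tie every later object to the first one: x_0[i] - x_k[i] = 0.
    for k in range(1, n_objects):
        for i in idx:
            row = np.zeros(n_objects * size)
            row[i] = 1.0
            row[k * size + i] = -1.0
            rows.append(row)
    return np.array(rows)


# Two objects with 3 coefficients each; share only index 2.
A = share_rows(n_objects=2, size=3, indices=[2])
x_equal = np.array([1.0, 2.0, 5.0, 7.0, 8.0, 5.0])  # index 2 matches
x_diff = np.array([1.0, 2.0, 5.0, 7.0, 8.0, 6.0])  # index 2 differs
residual_equal = A @ x_equal
residual_diff = A @ x_diff
```

In this toy setup the number of rows is (n_objects - 1) times the number of shared indices; the actual objective may count or arrange its rows differently.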

TODO:

  • [x] make sure implementation works for pytree objects such as CoilSet (tests added)
  • [x] add check that params lengths are the same across objects (technically, we should also check that if, say, p_l is being fixed between two equilibria, they are also the same TYPE of pressure profile, but I am not sure how to do that without very individualized logic like _check_type in coils.py)
  • [ ] Add test for fixing only specific indices of a param
  • [ ] Add more tests for coils (to test pytree objects; I am not sure of all the use cases right now, so I can't think of any good tests)

Resolves #1250

dpanici avatar Oct 24 '24 03:10 dpanici

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres        |     -0.54 +/- 6.69     | -3.15e-03 +/- 3.92e-02 |  5.84e-01 +/- 9.4e-03  |  5.87e-01 +/- 3.8e-02  |
| test_equilibrium_init_medres           |     +3.00 +/- 3.79     | +1.36e-01 +/- 1.72e-01 |  4.67e+00 +/- 1.2e-01  |  4.53e+00 +/- 1.2e-01  |
| test_equilibrium_init_highres          |     +1.81 +/- 2.77     | +9.32e-02 +/- 1.43e-01 |  5.25e+00 +/- 1.3e-01  |  5.16e+00 +/- 6.8e-02  |
| test_objective_compile_dshape_current  |     +0.89 +/- 2.17     | +3.17e-02 +/- 7.74e-02 |  3.60e+00 +/- 4.9e-02  |  3.57e+00 +/- 6.0e-02  |
| test_objective_compute_dshape_current  |     -0.38 +/- 2.55     | -1.31e-05 +/- 8.75e-05 |  3.41e-03 +/- 5.7e-05  |  3.43e-03 +/- 6.6e-05  |
| test_objective_jac_dshape_current      |     +0.54 +/- 11.63    | +1.66e-04 +/- 3.57e-03 |  3.09e-02 +/- 2.4e-03  |  3.07e-02 +/- 2.7e-03  |
| test_perturb_2                         |     -0.26 +/- 1.74     | -4.75e-02 +/- 3.12e-01 |  1.79e+01 +/- 2.0e-01  |  1.80e+01 +/- 2.4e-01  |
| test_proximal_jac_atf_with_eq_update   |     +0.67 +/- 0.70     | +1.02e-01 +/- 1.06e-01 |  1.53e+01 +/- 5.8e-02  |  1.52e+01 +/- 8.9e-02  |
| test_proximal_freeb_jac                |     -1.60 +/- 10.26    | -8.10e-02 +/- 5.19e-01 |  4.98e+00 +/- 3.6e-01  |  5.06e+00 +/- 3.7e-01  |
| test_solve_fixed_iter_compiled         |     -1.18 +/- 2.42     | -2.13e-01 +/- 4.38e-01 |  1.79e+01 +/- 3.1e-01  |  1.81e+01 +/- 3.1e-01  |
| test_LinearConstraintProjection_build  |     -1.20 +/- 1.85     | -1.05e-01 +/- 1.61e-01 |  8.61e+00 +/- 1.1e-01  |  8.71e+00 +/- 1.2e-01  |
| test_objective_compute_ripple_spline   |     +1.24 +/- 7.81     | +3.93e-03 +/- 2.47e-02 |  3.20e-01 +/- 2.1e-02  |  3.16e-01 +/- 1.3e-02  |
| test_objective_grad_ripple_spline      |     +0.13 +/- 3.60     | +1.63e-03 +/- 4.48e-02 |  1.24e+00 +/- 2.2e-02  |  1.24e+00 +/- 3.9e-02  |
| test_build_transform_fft_midres        |     +3.17 +/- 3.57     | +2.30e-02 +/- 2.59e-02 |  7.48e-01 +/- 2.3e-02  |  7.25e-01 +/- 1.3e-02  |
| test_build_transform_fft_highres       |     +1.56 +/- 2.82     | +1.57e-02 +/- 2.82e-02 |  1.02e+00 +/- 2.3e-02  |  1.00e+00 +/- 1.6e-02  |
| test_equilibrium_init_lowres           |     -0.77 +/- 3.87     | -3.40e-02 +/- 1.70e-01 |  4.36e+00 +/- 1.4e-01  |  4.40e+00 +/- 9.7e-02  |
| test_objective_compile_atf             |     -1.23 +/- 1.88     | -7.98e-02 +/- 1.22e-01 |  6.42e+00 +/- 9.8e-02  |  6.50e+00 +/- 7.3e-02  |
| test_objective_compute_atf             |     -3.03 +/- 13.55    | -2.63e-04 +/- 1.18e-03 |  8.43e-03 +/- 5.2e-04  |  8.69e-03 +/- 1.1e-03  |
| test_objective_jac_atf                 |     -0.07 +/- 2.93     | -1.05e-03 +/- 4.61e-02 |  1.57e+00 +/- 3.4e-02  |  1.57e+00 +/- 3.1e-02  |
| test_perturb_1                         |     +0.12 +/- 1.72     | +1.72e-02 +/- 2.54e-01 |  1.48e+01 +/- 2.1e-01  |  1.48e+01 +/- 1.4e-01  |
| test_proximal_jac_atf                  |     +1.12 +/- 1.34     | +8.56e-02 +/- 1.02e-01 |  7.72e+00 +/- 7.1e-02  |  7.64e+00 +/- 7.3e-02  |
| test_proximal_freeb_compute            |     +1.39 +/- 4.15     | +2.49e-03 +/- 7.42e-03 |  1.81e-01 +/- 5.5e-03  |  1.79e-01 +/- 5.0e-03  |
| test_solve_fixed_iter                  |     -1.08 +/- 1.68     | -3.32e-01 +/- 5.18e-01 |  3.05e+01 +/- 3.9e-01  |  3.09e+01 +/- 3.4e-01  |
| test_objective_compute_ripple          |     -0.30 +/- 0.80     | -8.16e-03 +/- 2.19e-02 |  2.75e+00 +/- 1.4e-02  |  2.76e+00 +/- 1.7e-02  |
| test_objective_grad_ripple             |     +1.74 +/- 1.42     | +8.68e-02 +/- 7.09e-02 |  5.08e+00 +/- 5.6e-02  |  4.99e+00 +/- 4.3e-02  |

github-actions[bot] avatar Oct 24 '24 03:10 github-actions[bot]

Codecov Report

:white_check_mark: All modified and coverable lines are covered by tests. :white_check_mark: Project coverage is 95.78%. Comparing base (faba7ec) to head (2ab0a87). :warning: Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1320      +/-   ##
==========================================
+ Coverage   95.77%   95.78%   +0.01%     
==========================================
  Files         101      101              
  Lines       27757    27794      +37     
==========================================
+ Hits        26583    26622      +39     
+ Misses       1174     1172       -2     
Files with missing lines Coverage Δ
desc/backend.py 89.69% <ø> (ø)
desc/objectives/__init__.py 100.00% <ø> (ø)
desc/objectives/linear_objectives.py 97.27% <100.00%> (+0.14%) :arrow_up:
desc/objectives/objective_funs.py 94.67% <ø> (ø)

... and 2 files with indirect coverage changes

codecov[bot] avatar Oct 24 '24 04:10 codecov[bot]

OUTDATED

This would help a lot with multi-GPU work. When I put different ForceBalance objectives on different GPUs, the optimizer thinks there are multiple different equilibria (because jax.device_put() creates a copy of the equilibrium, so although the copies are the same at that moment, their memory addresses are different). I tried to prevent this by overriding obj.things[0] = eq_replicated_to_all_devices, but JAX complains that the inputs of a jitted function should live on the same device (since we close over self, different portions of the objective live on different devices).

        eq = jax.device_put(eq, desc_config["sharding_replicated"])
        obj = ForceBalance(eq, grid=grid)
        obj = jax.device_put(obj, jax.devices("gpu")[i])
        # if the eq is also distributed across GPUs, then some internal logic that
        # checks if the things are different will fail, so we need to set the eq
        # to be the same manually
        obj._things[0] = eq
        obj.build(use_jit=use_jit)
        objs += (obj,)

This PR would make it possible to simply use separate equilibria and link their parameters.
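The identity problem described above can be illustrated without any GPUs. The sketch below is a plain-Python analogy, not DESC or JAX code: `Thing` and `unique_by_identity` are hypothetical stand-ins, `copy.deepcopy` plays the role of the per-device copy, and it assumes (as the comment suggests) that the objects are deduplicated by identity rather than by value.

```python
import copy


class Thing:
    """Stand-in for an optimizable object (hypothetical, not a DESC class)."""

    def __init__(self, params):
        self.params = list(params)


def unique_by_identity(things):
    """Deduplicate by object identity, keeping first occurrences."""
    seen, unique = set(), []
    for t in things:
        if id(t) not in seen:
            seen.add(id(t))
            unique.append(t)
    return unique


eq = Thing([1.0, 2.0])
eq_copy = copy.deepcopy(eq)  # plays the role of a per-device replica

same_values = eq.params == eq_copy.params  # the copies hold equal values
n_shared = len(unique_by_identity([eq, eq]))  # one object seen twice
n_copied = len(unique_by_identity([eq, eq_copy]))  # two distinct objects
```

Under that assumption, equal-valued copies count as two separate things, which is the situation a parameter-sharing constraint would address by tying their values together explicitly.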

YigitElma avatar Feb 07 '25 19:02 YigitElma

@dpanici

dpanici avatar Aug 20 '25 19:08 dpanici

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |    4.08 %    |     3.909e+03      |     4.068e+03      |    159.60    |       37.96        |       36.15        |
| test_proximal_jac_w7x_with_eq_update   |   -3.30 %    |     6.576e+03      |     6.359e+03      |   -217.15    |       163.10       |       162.09       |
| test_proximal_freeb_jac                |   -0.22 %    |     1.320e+04      |     1.317e+04      |    -29.46    |       84.33        |       83.73        |
| test_proximal_freeb_jac_blocked        |   -0.29 %    |     7.485e+03      |     7.463e+03      |    -22.08    |       72.78        |       74.31        |
| test_proximal_freeb_jac_batched        |    0.80 %    |     7.419e+03      |     7.479e+03      |    59.20     |       73.60        |       74.72        |
| test_proximal_jac_ripple               |   -0.64 %    |     3.465e+03      |     3.443e+03      |    -22.04    |       65.09        |       66.61        |
| test_proximal_jac_ripple_bounce1d      |    0.33 %    |     3.576e+03      |     3.588e+03      |    11.95     |       75.44        |       78.27        |
| test_eq_solve                          |   -4.75 %    |     2.071e+03      |     1.972e+03      |    -98.39    |       93.87        |       93.98        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

github-actions[bot] avatar Sep 15 '25 19:09 github-actions[bot]

I am getting a weird static-attr-related error from one of the tests that uses this with an optimizable collection:

import numpy as np
from desc.coils import CoilSet, FourierXYZCoil
from desc.objectives import (
    ObjectiveFunction,
    ShareParameters,
)

coils1 = CoilSet.linspaced_angular(FourierXYZCoil(), n=2)
coils2 = coils1.copy()

subobj = ShareParameters([coils1, coils2], {"X_n": True, "Y_n": True, "Z_n": [2]})
subobj.build()
obj = ObjectiveFunction(subobj)
obj.build()

# check dimensions
# dim_f should be 2 (for the 2 subcoils in each coilset) x 2 (for X_n, Y_n)
# x params["X_n"].size + 2 x 1 (Z_n) x 1 (bc only fixed idx=2)
assert subobj.dim_f == 2 * 2 * coils1.params_dict[0]["X_n"].size + 2
np.testing.assert_allclose(subobj.target, 0)

# check compute
np.testing.assert_allclose(obj.compute_unscaled(obj.x(coils1, coils2)), 0) # errors with message below

I am running in Jupyter, and if I run the cell above a second time, it works with no error... I am not sure at all what is happening. This objective has the same static attrs as FixParameters, and the problem seems to be within the static attr list itself, which makes no sense to me.

error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 16
     13 np.testing.assert_allclose(subobj.target, 0)
     15 # check compute
---> 16 np.testing.assert_allclose(obj.compute_unscaled(obj.x(coils1, coils2)), 0)
     18 # check the jacobian
     19 J = obj.jac_unscaled(obj.x(coils1, coils2))

    [... skipping hidden 5 frame]

File ~/miniconda3/envs/descenv/lib/python3.13/site-packages/jax/_src/pjit.py:731, in _infer_input_type(fun, dbg, explicit_args)
    729 except TypeError:
    730   arg_description = f"path {dbg.arg_names[i]}"  # pytype: disable=name-error
--> 731   raise TypeError(
    732     f"Error interpreting argument to {fun} as an abstract array."
    733     f" The problematic value is of type {type(x)} and was passed to"  # pytype: disable=name-error
    734     f" the function at {arg_description}.\n"
    735     "This typically means that a jit-wrapped function was called with a non-array"
    736     " argument, and this argument was not marked as static using the"
    737     " static_argnums or static_argnames parameters of jax.jit."
    738   ) from None
    739 if config.mutable_array_checks.value:
    740   _check_no_aliased_ref_args(dbg, avals, explicit_args)

TypeError: Error interpreting argument to <function ObjectiveFunction.compute_unscaled at 0x12d9637e0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path self[0]['_objectives'][0][0]['_static_attrs'][0].
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

Interestingly, if I don't build the subobj before wrapping it in the ObjectiveFunction, it works without error. I am really unsure why this would be the case, though; is setting _built in the subobj somehow causing issues?
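For reference, the generic form of this TypeError can be reproduced in a few lines. This is only a sketch of the general JAX behavior, not of the DESC pytree machinery that produced the traceback above: `scale` is a hypothetical function, and the example shows a string reaching a jitted function as a traced argument versus being marked static.

```python
import jax
import jax.numpy as jnp


def scale(x, mode):
    # `mode` is a plain Python string, so it cannot be traced as an array.
    return 2.0 * x if mode == "double" else x


try:
    jax.jit(scale)(jnp.ones(3), "double")
    raised = False
except TypeError:
    # Same class of error as in the traceback: a non-array argument was
    # passed to a jit-wrapped function without being marked static.
    raised = True

# Marking the string as static makes jit treat it as a compile-time constant.
scale_static = jax.jit(scale, static_argnames="mode")
result = scale_static(jnp.ones(3), mode="double")
```

In the traceback, the string sits at the path `self[0]['_objectives'][0][0]['_static_attrs'][0]`, which suggests it ended up among the traced leaves of the object rather than in its static data; why that happens only on the first run is not explained by this sketch.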

dpanici avatar Sep 21 '25 17:09 dpanici

@dpanici I had a similar error in the multi-device PR, and I ended up adding "_static_attrs" to the "_static_attrs" list, for both the objective function and _Objective. I also couldn't find the exact reason, but that seemed to solve the problem.

YigitElma avatar Sep 21 '25 17:09 YigitElma

Another issue, whose cause I now know but whose resolution I am unsure of (though it is an edge case): for attributes like normal, where we normalize the vector when the attribute is set, ShareParameters([coil1, coil2], {"normal":True}) works perfectly. However, if I use ShareParameters([coil1, coil2], {"normal":np.array([1])}) to make only the index-1 component of the normal the same, the constraint works correctly during the optimization, but when the attributes are set at the end of the optimization, they are normalized in the setter. The 0th and 2nd components may differ (as they were not shared), so the resulting normalized vectors may have different 1-components.

This is sort of a useless edge case, but I am still unsure how to fully resolve a case like this.
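A small numeric illustration of this edge case, using NumPy only. `set_normal` is a hypothetical stand-in that mimics a setter storing the unit vector; it is assumed behavior for illustration, not the actual DESC setter.

```python
import numpy as np


def set_normal(v):
    """Mimic a setter that stores the unit vector (assumed behavior)."""
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v)


# Unnormalized optimizer values: only index 1 is shared between the two.
raw1 = np.array([0.0, 1.0, 0.0])
raw2 = np.array([1.0, 1.0, 0.0])

n1 = set_normal(raw1)
n2 = set_normal(raw2)

shared_before = raw1[1] == raw2[1]  # the linear constraint is satisfied
shared_after = bool(np.isclose(n1[1], n2[1]))  # no longer equal after the setter
```

Here the index-1 components agree before normalization but not after it (1 versus 1/sqrt(2)), because the unshared components change each vector's norm.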

dpanici avatar Sep 22 '25 00:09 dpanici

Also don't forget to add the new objective to the docs

f0uriest avatar Oct 01 '25 13:10 f0uriest

@f0uriest if you have time to review, we can also try to get this in before the new pip release.

YigitElma avatar Nov 14 '25 16:11 YigitElma