Design suggestion for better type check of has-aux

Open NeilGirdhar opened this issue 3 years ago • 7 comments

Currently, various optimizers accept a generic Callable together with flags that specify how it should be used:

class LBFGS:
  fun: Callable
  value_and_grad: bool = False
  has_aux: bool = False

So the user has to initialize the optimizer carefully, making sure that these three fields are consistent with one another.

Instead, it might be a better user experience to have:

class LBFGS(Generic[T]):
  fun: Callable[[T], RealNumeric] | ValueAuxAndGrad[T] | HasAux[T]
  # No need for flags!

given the definitions:

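from dataclasses import dataclass
from typing import Any, Callable, Generic, Protocol, TypeVar

import jax
import numpy as np
import numpy.typing as npt

RealNumeric = jax.Array | npt.NDArray[np.floating[Any]] | float
T = TypeVar('T')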

class _ValueAuxAndGradProtocol(Protocol, Generic[T]):
    def __call__(self, x: T, /, *args: Any, **kwargs: Any
                 ) -> tuple[tuple[RealNumeric, Any], T]:
      ...

class _HasAuxProtocol(Protocol, Generic[T]):
    def __call__(self, x: T, /, *args: Any, **kwargs: Any
                 ) -> tuple[RealNumeric, Any]:
      ...

@dataclass
class ValueAuxAndGrad(Generic[T]):
  fun: _ValueAuxAndGradProtocol[T]
  # Add appropriate methods based on how you use this

@dataclass
class HasAux(Generic[T]):
  fun: _HasAuxProtocol[T]
  # Add appropriate methods based on how you use this

This way, the function parameters and return values are type-checked in all three cases. Also, there's no need for any flags. Instead of

LBFGS(f, has_aux=True, value_and_grad=True)

you can do

LBFGS(ValueAuxAndGrad(f))  # type-checked; fails if f doesn't return the appropriate tuple structure.

Even the simple use case would be type-checked:

LBFGS(f)  # type-checked; fails if f doesn't return an array or scalar float.

This would also allow transparent addition of ValueAndGrad one day, if it's desirable.
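
To make the idea concrete, here is a minimal sketch of how an optimizer could dispatch on the wrapper type instead of reading flags. It is not how jaxopt works today: `evaluate` is a hypothetical helper, the wrappers are simplified versions of the ones above, and a finite difference on plain Python floats stands in for `jax.grad` so that the sketch runs without JAX.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


@dataclass
class HasAux(Generic[T]):
    fun: Callable[..., tuple[float, Any]]  # returns (value, aux)


@dataclass
class ValueAuxAndGrad(Generic[T]):
    fun: Callable[..., tuple[tuple[float, Any], T]]  # returns ((value, aux), grad)


def evaluate(fun, x, eps=1e-6):
    """Hypothetical helper: normalize every accepted form to (value, aux, grad)."""
    if isinstance(fun, ValueAuxAndGrad):
        # The user supplies the gradient, so nothing needs to be differentiated.
        (value, aux), grad = fun.fun(x)
        return value, aux, grad
    if isinstance(fun, HasAux):
        value, aux = fun.fun(x)
        plain = lambda z: fun.fun(z)[0]  # drop aux before differentiating
    else:
        value, aux = fun(x), None
        plain = fun
    # Central finite difference: an assumption of this sketch, standing in for jax.grad.
    grad = (plain(x + eps) - plain(x - eps)) / (2 * eps)
    return value, aux, grad
```

The wrapper carries the information that the two booleans carry today, so an inconsistent combination cannot be expressed at the call site.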

What do you think?

NeilGirdhar avatar Oct 22 '22 08:10 NeilGirdhar

(Just discovered this library by the way, and it's a real gem!!)

NeilGirdhar avatar Oct 22 '22 08:10 NeilGirdhar

Thanks a lot, this is an interesting suggestion.

One advantage of has_aux and value_and_grad is that those are familiar keywords for JAX users. I'm a bit concerned that your approach would make JAXopt too "frameworkish".

But we are aware that we should improve error checking, i.e., make sure that fun's output is compatible with what has been specified in the has_aux and value_and_grad options.
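
As a rough sketch of what such a check could look like while keeping the existing flags: `check_output` below is a hypothetical helper, not part of JAXopt, and it assumes plain Python scalars so that it runs without JAX. The nesting it expects, `((value, aux), grad)` when both flags are set, follows the protocol definitions in the proposal above.

```python
from numbers import Real


def check_output(out, has_aux=False, value_and_grad=False):
    """Hypothetical validator: raise TypeError if `out` does not match the flags."""

    def is_pair(obj):
        return isinstance(obj, tuple) and len(obj) == 2

    if value_and_grad:
        if not is_pair(out):
            raise TypeError("value_and_grad=True expects (value, grad)")
        out, _grad = out  # keep the value part, which may still carry aux
    if has_aux:
        if not is_pair(out):
            raise TypeError("has_aux=True expects (value, aux)")
        out, _aux = out
    if not isinstance(out, Real):
        raise TypeError(f"expected a real scalar value, got {type(out).__name__}")
```

A check like this only fires when the function is first called, whereas the typed wrappers aim to surface the same mistake in a static type checker.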

mblondel avatar Oct 24 '22 12:10 mblondel

@mblondel That's fair enough, but the benefit of doing this in Jax is smaller than the benefit of doing this in jaxopt.

First, Jax only has one flag, whereas here you have two flags and a somewhat complicated interaction between them that requires reading the documentation carefully (if you only use one of the flags, which one are you allowed to use? It's not obvious without reading the docs). I proposed this change so that the code becomes self-documenting, instead of requiring the documentation to understand what it does.

And second, the typing problem doesn't exist in Jax since you can overload on the flags. However, there's no such thing as a class overload, so you can never have explicit typing for the function in jaxopt without doing something like this.

Finally, since you just generalized the has-aux/value-and-grad to all the optimizers, I thought that this would be a good time to make any architectural changes before people start using the flags.
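
To illustrate the second point, here is a sketch of what overloading a function on a flag looks like with `typing.overload` and `Literal`. The function `differentiate` is a hypothetical stand-in working on Python floats with a finite difference; it is not the actual signature or type stub of `jax.grad`.

```python
from typing import Any, Callable, Literal, overload


@overload
def differentiate(
    fun: Callable[[float], float], has_aux: Literal[False] = ...
) -> Callable[[float], float]: ...
@overload
def differentiate(
    fun: Callable[[float], tuple[float, Any]], has_aux: Literal[True]
) -> Callable[[float], tuple[float, Any]]: ...
def differentiate(fun, has_aux=False):
    """Hypothetical function: return a callable computing d(fun)/dx."""
    eps = 1e-6

    def wrapped(x):
        if has_aux:
            aux = fun(x)[1]
            value_of = lambda z: fun(z)[0]  # drop aux before differentiating
            grad = (value_of(x + eps) - value_of(x - eps)) / (2 * eps)
            return grad, aux
        return (fun(x + eps) - fun(x - eps)) / (2 * eps)

    return wrapped
```

With these overloads, a type checker can tie the expected return type of `fun` to the literal value of `has_aux`. A single annotation on a dataclass field such as `fun: Callable` cannot vary with sibling flag fields in the same way, which is the gap the wrapper types are meant to fill.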

NeilGirdhar avatar Oct 24 '22 13:10 NeilGirdhar

I'm not saying your idea is bad but I would like to understand its pros and cons and I would like JAXopt to integrate well in the JAX ecosystem...

If we adopt objects to represent functions, this could also potentially be useful so that functions know how to initialize the parameters.

mblondel avatar Oct 24 '22 13:10 mblondel

Thanks for starting the discussion in google/jax#12948. This would have the advantage of being able to use jax.* operators instead of creating new ones in JAXopt.

mblondel avatar Oct 24 '22 14:10 mblondel

@mblondel Yeah, I started the discussion, but I wasn't sure if I should leave it open since the benefits are smaller in Jax (as I explained in my last comment). I'll reopen and maybe someone from the Jax team will comment though. There may be benefits that I'm missing.

NeilGirdhar avatar Oct 24 '22 14:10 NeilGirdhar

This would be useful to "annotate" the returned Callable with its structure.

mblondel avatar Oct 24 '22 14:10 mblondel