Should get_namespace support more than arrays?
Background
Consider adding Array API support to a probability distribution library such as efax. Each probability distribution class contains a number of parameters, and it makes sense for all of them to come from the same namespace. Thus, a standard pattern for methods that use the parameters will be to query the namespace of self by passing all of the parameters to get_namespace.
This pattern is not unique to efax. I imagine it will pop up in SciPy's future distribution classes (which are being developed and will support the Array API). It could be added to any object exposing the "Jax PyTrees" interface (see the registry), or generally to any aggregate structure containing a homogeneous set of arrays.
Motivation
To simplify getting the namespace in functions that interact with aggregate structures containing homogeneous sets of arrays.
Example
Suppose that Distribution is an aggregate structure containing array-valued parameters. Instead of:
def f(x: Distribution, y: Distribution, z: Array):
    xp = x.get_namespace()  # Call method to get namespace.
    assert y.get_namespace() == xp  # Call method and check that it's the same namespace.
    assert get_namespace(z) == xp  # Check that it's the same namespace.
we would like to simply have:
def f(x: Distribution, y: Distribution, z: Array):
    xp = get_namespace(x, y, z)  # One simple line
Proposal
Extend get_namespace(o) to first call o.__namespace_arrays__(), if it exists, which returns an iterable of arrays that get_namespace can then resolve as before.
Thus, instead of aggregate structures providing a method that queries the namespace (like this function), we would have:
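from collections.abc import Iterable
from dataclasses import fields

class Distribution:  # assumed to be a dataclass whose fields are arrays
    def __namespace_arrays__(self) -> Iterable[Array]:
        return (getattr(self, field.name) for field in fields(self))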
A simple recursive extension of get_namespace is illustrated here.
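A minimal sketch of such a recursive extension, assuming a toy resolver in place of the compat library's real get_namespace. FakeArray, fake_xp and the Normal dataclass are hypothetical stand-ins; the resolver descends into any object that defines __namespace_arrays__ and asks everything else for __array_namespace__.

```python
from collections.abc import Iterable
from dataclasses import dataclass, fields
from types import ModuleType
from typing import Any, Optional

# Hypothetical stand-ins for an array library's namespace and array type.
fake_xp = ModuleType("fake_xp")


class FakeArray:
    """Minimal array-like object exposing the standard __array_namespace__ hook."""

    def __array_namespace__(self, api_version: Optional[str] = None) -> ModuleType:
        return fake_xp


def get_namespace(*xs: Any, api_version: Optional[str] = None) -> ModuleType:
    """Toy resolver: expands objects that define the proposed __namespace_arrays__
    hook (recursively), then asks each remaining array for its namespace."""
    namespaces = set()

    def visit(x: Any) -> None:
        if hasattr(x, "__namespace_arrays__"):
            for child in x.__namespace_arrays__():
                visit(child)  # aggregates may themselves contain aggregates
        else:
            namespaces.add(x.__array_namespace__(api_version=api_version))

    for x in xs:
        visit(x)
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one namespace, got {len(namespaces)}")
    return namespaces.pop()


@dataclass
class Normal:
    """Hypothetical distribution whose fields are all arrays."""

    mean: FakeArray
    std: FakeArray

    def __namespace_arrays__(self) -> Iterable[Any]:
        return (getattr(self, field.name) for field in fields(self))


# One call covers aggregates and plain arrays alike.
dist = Normal(mean=FakeArray(), std=FakeArray())
xp = get_namespace(dist, FakeArray())
```

Because the resolver recurses, an aggregate can list another aggregate among its __namespace_arrays__ and still resolve to a single namespace.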
Alternative proposal
One alternative is to support __array_namespace__ on all inputs to get_namespace. Thus, we would have
class Distribution:
    def __array_namespace__(self, api_version: str, use_compat: bool) -> ArrayNameSpace:
        return get_namespace(*(getattr(self, field.name) for field in fields(self)),
                             api_version=api_version,
                             use_compat=use_compat)
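The problem with this is that it complicates extending the parameter specification of get_namespace: every aggregate's __array_namespace__ would have to accept and forward any parameter that is added later.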
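To make the forwarding problem concrete, here is a toy sketch (not the compat library's implementation) in which the resolver forwards both of its options, api_version and use_compat, to every input's __array_namespace__. That forwarding is an assumption of the sketch; arrays following the standard only accept api_version. FakeArray, fake_xp and OlderDistribution are hypothetical.

```python
from dataclasses import dataclass, fields
from types import ModuleType
from typing import Any, Optional

fake_xp = ModuleType("fake_xp")  # hypothetical array namespace


class FakeArray:
    def __array_namespace__(self, api_version: Optional[str] = None,
                            use_compat: Optional[bool] = None) -> ModuleType:
        return fake_xp


def get_namespace(*xs: Any, api_version: Optional[str] = None,
                  use_compat: Optional[bool] = None) -> ModuleType:
    """Toy resolver for the alternative: every input, aggregate or not, is asked
    for its namespace via __array_namespace__, with all options forwarded."""
    namespaces = {x.__array_namespace__(api_version=api_version, use_compat=use_compat)
                  for x in xs}
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one namespace, got {len(namespaces)}")
    return namespaces.pop()


@dataclass
class Distribution:
    mean: FakeArray
    std: FakeArray

    # Has to accept, and pass on, every keyword that get_namespace sends down.
    def __array_namespace__(self, api_version: Optional[str] = None,
                            use_compat: Optional[bool] = None) -> ModuleType:
        return get_namespace(*(getattr(self, f.name) for f in fields(self)),
                             api_version=api_version, use_compat=use_compat)


@dataclass
class OlderDistribution:
    """Hypothetical aggregate written before use_compat existed."""

    mean: FakeArray

    def __array_namespace__(self, api_version: Optional[str] = None) -> ModuleType:
        return get_namespace(self.mean, api_version=api_version)
```

Calling get_namespace(OlderDistribution(FakeArray())) raises a TypeError about the unexpected use_compat keyword: an aggregate written before that option existed breaks as soon as the resolver starts passing it.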
In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). https://github.com/scikit-learn/scikit-learn/blob/19c068f64249f95f745962b42a4dd581c7738218/sklearn/utils/_array_api.py#L473
Could you do something like that in efax? Or, asked differently: aren't you going to end up with something like this sooner or later anyway, in which case it could also take care of dealing with "things that aren't arrays but contain them"?
How is the linked function related? It doesn't deal with aggregate structures, which is the motivation for this proposal.
get_namespace isn't actually part of the array API; it's part of the compat library. The array API defines x.__array_namespace__(). The compat library's get_namespace() (which is also called array_namespace()) is just a wrapper around calling this method that manually returns the compat-layer namespace when necessary. Maybe this should be made clearer in the documentation.
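A toy model of that wrapper, assuming two hypothetical namespaces and array types: the resolver returns whatever a compliant array's __array_namespace__ reports, and substitutes a compat namespace for a type it knows lacks the method. The real array_namespace() recognizes specific libraries (NumPy, PyTorch and others) rather than a single placeholder class.

```python
from types import ModuleType
from typing import Any, Optional

# Hypothetical namespaces: one reported natively by a compliant array type,
# one standing in for a compat-layer wrapper around a non-compliant library.
native_xp = ModuleType("native_xp")
compat_xp = ModuleType("compat_xp")


class CompliantArray:
    def __array_namespace__(self, api_version: Optional[str] = None) -> ModuleType:
        return native_xp


class NonCompliantArray:
    """Stands in for an array type without __array_namespace__."""


def get_namespace(*xs: Any, api_version: Optional[str] = None) -> ModuleType:
    namespaces = set()
    for x in xs:
        if isinstance(x, NonCompliantArray):
            namespaces.add(compat_xp)  # substitute the compat-layer namespace
        elif hasattr(x, "__array_namespace__"):
            namespaces.add(x.__array_namespace__(api_version=api_version))
        else:
            raise TypeError(f"{type(x).__name__} is not a supported array type")
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one namespace, got {len(namespaces)}")
    return namespaces.pop()
```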
I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?
> In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). scikit-learn/scikit-learn@19c068f/sklearn/utils/_array_api.py#L473
Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it should be generally useful and easy to implement.
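Skipping types could be as small as a filter in front of the namespace lookup. A sketch with a hypothetical skip_types option (the name and the default of None and str are assumptions, not array_api_compat's API), using a toy resolver and a stand-in FakeArray:

```python
from types import ModuleType
from typing import Any, Tuple

fake_xp = ModuleType("fake_xp")  # hypothetical array namespace


class FakeArray:
    def __array_namespace__(self, api_version=None) -> ModuleType:
        return fake_xp


def get_namespace(*xs: Any,
                  skip_types: Tuple[type, ...] = (type(None), str)) -> ModuleType:
    """Toy resolver with a hypothetical skip_types option: inputs of these
    types are ignored instead of being asked for a namespace."""
    arrays = [x for x in xs if not isinstance(x, skip_types)]
    namespaces = {x.__array_namespace__() for x in arrays}
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one namespace, got {len(namespaces)}")
    return namespaces.pop()


# Optional arguments such as None or a string flag no longer need pre-filtering.
xp = get_namespace(FakeArray(), None, "auto")
```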
> I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?
Yes, but it's just a question of convenience. Sometimes you have a method that accepts an aggregate object (say, self) and some arrays (say, x). I guess you could expand the aggregate object into its component arrays and pass them all, as in get_namespace(self.a, self.b, self.c, x). I'm proposing the convenience of get_namespace(self, x). It's just simpler.
> Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it should be generally useful and easy to implement.
If you want to, go for it. No strong feelings from my side. I have a slight preference for keeping the get_namespace in the compat library simple. I can see a future where it accumulates "all the useful things" from the various array consuming libraries and then becomes quite unwieldy.
The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array consuming library having a custom version of get_namespace that implements things that are convenient for it. efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.
> The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array consuming library having a custom version of get_namespace that implements things that are convenient for it.
Ah, right, that makes sense!
> efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.
Right, which is what I'm doing. The reason I suggested upstreaming aggregate structure support is in case there are ever functions that accept aggregate structure types from different libraries.
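A sketch of the library-local approach, showing why a shared protocol would matter across libraries. Everything here is hypothetical: base_get_namespace stands in for the compat library's array-only resolver, and the wrapper only knows how to expand its own Distribution type, so an aggregate from another library (ForeignParams) falls through to the array-only path and fails.

```python
from dataclasses import dataclass, fields
from types import ModuleType
from typing import Any

fake_xp = ModuleType("fake_xp")  # hypothetical array namespace


class FakeArray:
    def __array_namespace__(self, api_version=None) -> ModuleType:
        return fake_xp


def base_get_namespace(*xs: Any) -> ModuleType:
    """Toy stand-in for the compat library's resolver, which only knows arrays."""
    namespaces = {x.__array_namespace__() for x in xs}
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one namespace, got {len(namespaces)}")
    return namespaces.pop()


@dataclass
class Distribution:
    """Hypothetical aggregate type owned by this library."""

    mean: FakeArray
    std: FakeArray


def get_namespace(*xs: Any) -> ModuleType:
    """Library-local wrapper: expands this library's own aggregates, then delegates."""
    arrays = []
    for x in xs:
        if isinstance(x, Distribution):
            arrays.extend(getattr(x, field.name) for field in fields(x))
        else:
            arrays.append(x)
    return base_get_namespace(*arrays)


@dataclass
class ForeignParams:
    """Hypothetical aggregate from another library; the wrapper cannot see inside it."""

    loc: FakeArray
```

get_namespace(Distribution(FakeArray(), FakeArray()), FakeArray()) resolves fine, while get_namespace(ForeignParams(FakeArray())) raises AttributeError because ForeignParams has no __array_namespace__.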