Feature Request: Support for astropy.units.Quantity types in array annotations
Hi @patrick-kidger, thanks a lot for creating such an excellent library!
I'd love to propose adding support for astropy.units.Quantity types, which would be a natural extension of current capabilities.
In scientific computing (and astronomy in particular) we frequently work with quantities that have both numerical values and physical units (e.g., distances in meters, masses in kilograms, times in seconds). The astropy.units.Quantity class is the de facto standard for handling dimensional quantities in the Python astronomy community.
Currently, jaxtyping provides type safety for array shapes and dtypes, but there's no built-in way to annotate arrays that carry unit information. We are potentially missing out on a great opportunity for dimensional analysis at the type-checking level.
Proposed feature
Add support for Quantity types while preserving jaxtyping's excellent shape and dtype checking, e.g.:
```python
from astropy.units import Quantity
from jaxtyping import Float

def calculate_physics(
    velocities: Float[Quantity, "n_particles 3"],  # 3D velocities with units
    masses: Float[Quantity, " n_particles"],  # masses with units
) -> Float[Quantity, " n_particles"]:  # kinetic energy with units
    ...
    return energy
```
More specifically, it would be useful to catch errors like this one during type checking:
```python
def bad_physics(
    distance: Float[Quantity, " n"],  # expects length units
    mass: Float[Quantity, " n"],  # expects mass units
):
    # Type error: can't add length + mass.
    # (At runtime astropy already raises a UnitConversionError here.)
    return distance + mass
```
I'm not sure how straightforward or practical an implementation would be, but happy to brainstorm together!
I agree this sounds pretty interesting! I suspect something like this is doable.
When writing SomeDtype[SomeArray, some_shape], jaxtyping is actually happy with anything that duck-types as an array for the SomeArray object: specifically, we only require that it have .shape and .dtype attributes. So, if nothing else, we could probably smuggle in the units via the .dtype that jaxtyping checks.
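To make the duck-typing point concrete, here is a minimal pure-Python sketch of the "units inside the dtype" idea. It does not use jaxtyping itself: `UnitDtype`, `UnitfulArray` and `check_array` are hypothetical names, and `check_array` only mimics the `.shape`/`.dtype` contract described above rather than reproducing jaxtyping's actual logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UnitDtype:
    """Hypothetical dtype object: numeric type plus a smuggled-in unit."""

    base: str  # numeric type name, e.g. "float64"
    unit: str  # unit string, e.g. "m/s"

    def __str__(self) -> str:
        return f"{self.base}[{self.unit}]"


class UnitfulArray:
    """Hypothetical array-like exposing only .shape and .dtype."""

    def __init__(self, shape: tuple, base: str, unit: str):
        self.shape = tuple(shape)
        self.dtype = UnitDtype(base, unit)


def check_array(obj, dtypes: set, unit: str, shape: str) -> bool:
    """Hypothetical duck-typed check that reads only obj.shape and obj.dtype.

    Named axes must agree within this one call; a real checker also tracks
    them across a function's arguments, which is omitted here.
    """
    if not (hasattr(obj, "shape") and hasattr(obj, "dtype")):
        return False
    # The dtype carries both the numeric type and the unit.
    if getattr(obj.dtype, "base", None) not in dtypes:
        return False
    if getattr(obj.dtype, "unit", None) != unit:
        return False
    axes = shape.split()
    if len(axes) != len(obj.shape):
        return False
    bound = {}  # axis name -> size seen so far
    for axis, size in zip(axes, obj.shape):
        if axis.isdigit():  # fixed-size axis such as "3"
            if int(axis) != size:
                return False
        elif bound.setdefault(axis, size) != size:  # named axis such as "n"
            return False
    return True


velocities = UnitfulArray((5, 3), "float64", "m/s")
print(check_array(velocities, {"float32", "float64"}, "m/s", "n 3"))  # True
print(check_array(velocities, {"float32", "float64"}, "kg", "n 3"))  # False
```

The design choice here is that nothing outside `.shape` and `.dtype` is inspected, so any array library could participate as long as its dtype object carries the unit.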
One thing I do note is that your examples don't explicitly express the units in the type annotation, which I'm guessing is desirable. So, as a non-jaxtyping alternative, I imagine you could also wrap the annotations in something like the following:
```python
def calculate_physics(velocities: HasUnits[Float[Quantity, "n_particles 3"], "m/s"]): ...
```
where (untested):
```python
class _MetaHasUnits(type):
    def __instancecheck__(cls, other):
        # astropy Quantities expose their unit as `.unit`
        return isinstance(other, cls.quantity) and other.unit == cls.units

def _make_cls(quantity, units):
    # Build a throwaway class whose isinstance check is handled by the metaclass
    return _MetaHasUnits("_HasUnits", (), dict(quantity=quantity, units=units))

class HasUnits:
    def __class_getitem__(cls, item):
        quantity, units = item
        return _make_cls(quantity, units)
```
IIUC your use case has NumPy as the underlying array type. Nevertheless, I'll also tag @nstarman here, who I know has thought about at least JAX + units before.
Thanks for the tag @patrick-kidger.
@joanna-pk
Yes, jaxtyping should already enable shape checking on astropy.units.Quantity.
For unit/dimension information, I think it's a great idea to build a checker like @patrick-kidger suggested. I would tie it to the in-development metrology APIs (https://github.com/quantity-dev/metrology-apis), where Astropy and the other major units libraries are collaborating on bringing their APIs closer together.
If you're working with JAX and want unit checking now, check out https://unxt.readthedocs.io/en/latest/ !
I agree this would be a nice feature to have. However, unit handling and checking is non-trivial in general. Typically one wants to check for equivalence rather than exact correspondence, and in some cases there are fractional powers (m^(2/3), etc.) that require specifying a numerical precision for the equivalence check. I'm not sure how to handle this with types.
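One way to get equivalence rather than exact matching is to compare dimensions instead of unit strings. The pure-Python sketch below reuses the metaclass pattern suggested earlier in the thread, but every name in it (`dims`, `UNITS`, `DimQuantity`, `HasDimension`) is hypothetical and it does not use astropy. Units are reduced to exponents over base dimensions, stored as `fractions.Fraction`, so a rational power such as m^(2/3) compares exactly with no tolerance.

```python
from fractions import Fraction


def dims(**exponents) -> frozenset:
    """Hypothetical dimension vector: base dimension -> exact exponent.

    Exponents are stored as Fraction, so rational powers such as 2/3
    compare exactly; zero exponents are dropped.
    """
    exact = {name: Fraction(power) for name, power in exponents.items()}
    return frozenset((name, power) for name, power in exact.items() if power != 0)


# Hypothetical unit table mapping each unit symbol to its dimensions only;
# scale factors (m vs km) are irrelevant to an equivalence check.
UNITS = {
    "m": dims(length=1),
    "km": dims(length=1),
    "kg": dims(mass=1),
    "s": dims(time=1),
    "m/s": dims(length=1, time=-1),
    "km/s": dims(length=1, time=-1),
    "m^(2/3)": dims(length="2/3"),
}


class DimQuantity:
    """Hypothetical stand-in for a unitful array: a value plus a unit symbol."""

    def __init__(self, value, unit: str):
        self.value = value
        self.unit = unit


class _MetaHasDimension(type):
    def __instancecheck__(cls, other):
        # Accept any known unit with matching dimensions (km/s passes for m/s).
        return (
            isinstance(other, DimQuantity)
            and other.unit in UNITS
            and UNITS[other.unit] == UNITS[cls.unit]
        )


class HasDimension:
    def __class_getitem__(cls, unit: str):
        # Build a throwaway class whose isinstance check uses the metaclass.
        return _MetaHasDimension("_HasDimension", (), {"unit": unit})


print(isinstance(DimQuantity(3.0, "km/s"), HasDimension["m/s"]))  # True
print(isinstance(DimQuantity(3.0, "kg"), HasDimension["m/s"]))  # False
```

Exact fractions only help when exponents stay rational: a float exponent such as `2 / 3` converts to a different `Fraction` than 2/3, so dimensions derived from floating-point arithmetic would still need rounding (e.g. `Fraction.limit_denominator`) or an explicit tolerance, which is the difficulty raised above.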
For context, Astropy already provides a decorator for validating units, astropy.units.quantity_input (https://docs.astropy.org/en/stable/api/astropy.units.quantity_input.html#quantity-input), though it validates at runtime and does not rely on type annotations.