scico
Coherent policy needed on array types
This issue is related to both typing and explicit array conversion and creation (the latter being closely related to the subject of #93). At the moment we have an inconsistent mix of `Array` and `JaxArray` type specifications within our code. A coherent policy is needed as a guide to cleaning up the existing usage and to ensure that the inconsistency doesn't persist into new code.
The simplest policy is to require that all scico functions and methods take and return `JaxArray` types. Is this both feasible and desirable? If not, what are the counter-examples?
Make it all `JaxArray` and be done with it. But note that this will really mess with static typing, as jax arrays are effectively aliased to `Any`.
Note also that using `np.ndarray` may not be the "correct" way to annotate a function that takes an ndarray input; cf. the numpy typing module.
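As a rough illustration of the point about `np.ndarray`, the `numpy.typing` module provides `ArrayLike` for inputs and `NDArray` for outputs with a stated dtype. The function name `normalize` is hypothetical and chosen only for this sketch; it is not part of scico.

```python
import numpy as np
import numpy.typing as npt


def normalize(x: npt.ArrayLike) -> npt.NDArray[np.float64]:
    """Scale x to unit Euclidean norm (hypothetical helper, for illustration)."""
    # ArrayLike admits lists, scalars, and ndarrays, so convert explicitly.
    arr = np.asarray(x, dtype=np.float64)
    norm = np.linalg.norm(arr)
    # Return the input unchanged if it is all zeros, avoiding division by zero.
    return arr if norm == 0 else arr / norm


y = normalize([3.0, 4.0])  # a plain Python list is accepted
```

The annotation on the input is deliberately wider than the one on the output: callers may pass anything array-like, but always get an `ndarray` back.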
See also the jax policy on type checking: https://github.com/google/jax/issues/8224 (tl;dr: make everything alias to `Any`).
First, let's have a type, i.e. `ScicoArray = Union[BlockArray, DeviceArray]`, or whatever we want to call it.
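A minimal sketch of what such an alias might look like follows. Two assumptions are made here: the `BlockArray` class below is a hypothetical stand-in rather than the scico class, and `jax.Array` is used in place of `DeviceArray`, since recent JAX releases expose `jax.Array` as the unified array type. The helper `is_scico_array` is likewise hypothetical.

```python
from typing import Union

import jax
import jax.numpy as jnp


class BlockArray:
    """Hypothetical stand-in for the scico block array class."""

    def __init__(self, *blocks):
        # Store each block as a JAX array.
        self.blocks = tuple(jnp.asarray(b) for b in blocks)


# Assumption: jax.Array takes the place of DeviceArray in recent JAX releases.
ScicoArray = Union[BlockArray, jax.Array]


def is_scico_array(x: object) -> bool:
    """Runtime counterpart of the ScicoArray alias."""
    return isinstance(x, (BlockArray, jax.Array))
```

The alias only helps static checkers; the `isinstance` helper is what would be needed for any runtime validation.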
I think there are arrays that don't make sense as `ScicoArray`s. For example, the angle list of a projector. It is really an array-shaped parameter, no different from the pixel spacing of the projector (other than that it is an array). If we choose to make this a `ScicoArray` then we may as well make the pixel spacing a zero-dimensional `ScicoArray`, but that would obviously not be sensible. These cases can be easily identified and should only be numpy arrays or Python scalars.
Then there are operators, like `f(x) = a B x`, where `a` is a scalar, `B` is a matrix, and `x` is a vector.
In my opinion `B` and `x` should be `ScicoArray`s, but `a` should not be required to be a `ScicoArray`. In any case, operators should always return `ScicoArray`s (related to #93).
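The proposal for `f(x) = a B x` could be sketched roughly as below, with `jax.Array` standing in for `ScicoArray` (an assumption made for this example) and `scaled_matvec` a hypothetical function name. The scalar `a` is annotated as a plain `float`, while the result is always a JAX array.

```python
import jax
import jax.numpy as jnp


def scaled_matvec(a: float, B: jax.Array, x: jax.Array) -> jax.Array:
    """Evaluate f(x) = a B x, where a is a plain scalar parameter."""
    # The matrix-vector product is a JAX array; scaling by a Python float
    # keeps it one, so the return type does not depend on the type of a.
    return a * (B @ x)


B = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([1.0, 1.0])
y = scaled_matvec(2.0, B, x)  # a stays a Python float
```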
A good part of this issue has been addressed in #410, but the question of whether all scico functions should take and return jax arrays (as opposed to numpy arrays) is still open.
It's worth noting the availability of `jax.typing.ArrayLike` for any value that is safe to implicitly cast to a JAX array; this includes `jax.Array`, `numpy.ndarray`, as well as Python builtin numeric values (e.g. `int`, `float`, etc.) and numpy scalar values (e.g. `numpy.int32`, `numpy.float64`, etc.).
Also, see section "JAX Typing Best Practices" in jax.typing.
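One way this could look in practice, following the general pattern of annotating inputs as `ArrayLike` and outputs as `jax.Array`, is sketched below. The function `soft_threshold` is a hypothetical example written for this comment, not an existing scico function.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def soft_threshold(x: ArrayLike, lam: ArrayLike) -> jax.Array:
    """Hypothetical example: elementwise soft thresholding."""
    # Convert once at the boundary; everything after this is a jax.Array.
    x = jnp.asarray(x)
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - lam, 0.0)


# numpy array input with a Python scalar parameter
out_np = soft_threshold(np.array([-2.0, 0.5, 3.0]), 1.0)
# jax array input with a numpy scalar parameter
out_jax = soft_threshold(jnp.array([-2.0, 0.5, 3.0]), np.float32(1.0))
```

Both calls return a `jax.Array`, so callers downstream never need to check which input type was passed.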
Suggestion: pre- and post-processing functions should take numpy or jax arrays, while functions expected to be used within optimization problems etc. should take jax arrays.