scico

Coherent policy needed on array types

Open • bwohlberg opened this issue 3 years ago • 5 comments

This issue relates both to typing and to explicit array conversion and creation (the latter being closely related to the subject of #93). At present our code contains an inconsistent mix of Array and JaxArray type specifications. A coherent policy is needed, both as a guide to cleaning up the existing usage and to ensure that the inconsistency doesn't persist into new code.

The simplest policy is obviously to require that all scico functions and methods take and return JaxArray types. Is this both feasible and desirable? If not, what are the counter-examples?

bwohlberg • Jan 18 '22 22:01

Make it all JaxArray and be done with it. But note that this will really mess with static typing, since JAX array types are effectively aliased to Any.

Note also that using np.ndarray may not be the "correct" way to annotate a function that takes an ndarray input; cf. the numpy.typing module.
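For reference, here is a minimal sketch of the numpy.typing style of annotation. The function name center is hypothetical, chosen only for illustration. npt.ArrayLike describes any input that can be converted to an array (lists, scalars, arrays), while npt.NDArray[np.float64] documents both the container type and the dtype of the result.

```python
import numpy as np
import numpy.typing as npt


def center(x: npt.ArrayLike) -> npt.NDArray[np.float64]:
    """Hypothetical example: subtract the mean from x.

    The input may be anything array-like; the output is always a
    float64 NumPy array.
    """
    arr = np.asarray(x, dtype=np.float64)
    return arr - arr.mean()


y = center([1, 2, 3])  # a plain list is accepted as ArrayLike
```

A type checker treats npt.NDArray[np.float64] as a parametrized np.ndarray, so this is more informative than a bare np.ndarray annotation.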

See also the JAX policy on type checking: https://github.com/google/jax/issues/8224 (tl;dr: make everything alias to Any).

lukepfister • Jan 18 '22 23:01

First, let's define a type, e.g. ScicoArray = Union[BlockArray, DeviceArray], or whatever we want to call it.
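A minimal sketch of such an alias, assuming it is named ScicoArray as above. So that the snippet does not depend on scico, BlockArray here is a hypothetical stand-in class (a tuple of JAX arrays), not the real scico implementation, and jax.Array is used in place of DeviceArray, which has been superseded by jax.Array in current JAX releases. The runtime check uses a tuple of types because isinstance does not accept a typing.Union on older Python versions.

```python
from typing import Union

import jax
import jax.numpy as jnp


class BlockArray:
    """Hypothetical stand-in for scico's BlockArray: a tuple of JAX arrays."""

    def __init__(self, *blocks):
        self.blocks = tuple(jnp.asarray(b) for b in blocks)


# Hypothetical alias name taken from the discussion; jax.Array replaces DeviceArray.
ScicoArray = Union[BlockArray, jax.Array]


def is_scico_array(x) -> bool:
    """Return True if x is a BlockArray or a JAX array."""
    return isinstance(x, (BlockArray, jax.Array))


ba = BlockArray(jnp.ones(2), jnp.zeros(3))
```

Under this alias, NumPy arrays, lists, and Python scalars are deliberately not ScicoArrays.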

I think there are arrays that don't make sense as ScicoArrays. One example is the list of angles of a projector: it is an array-shaped parameter, no different from the projector's pixel spacing other than being an array. If we choose to make the angles a ScicoArray, then we may as well make the pixel spacing a zero-dimensional ScicoArray, which would obviously not be sensible. Such cases are easily identified, and these parameters should only be numpy arrays or Python scalars.
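A minimal sketch of this distinction, using a hypothetical Projector class (not a scico API) that holds only configuration. The array-shaped angles parameter is stored as a NumPy array, just as the scalar pixel spacing is stored as a Python float; neither is treated as a ScicoArray.

```python
import numpy as np


class Projector:
    """Hypothetical projector used only to illustrate parameter types."""

    def __init__(self, angles, pixel_spacing: float):
        # Array-shaped parameter: kept as a plain NumPy array.
        self.angles = np.asarray(angles, dtype=np.float64)
        # Scalar parameter: kept as a plain Python float.
        self.pixel_spacing = float(pixel_spacing)


proj = Projector(np.linspace(0, np.pi, 180, endpoint=False), pixel_spacing=0.5)
```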

Then there are operators such as f(x) = a B x, where a is a scalar, B is a matrix, and x is a vector. In my opinion B and x should be ScicoArrays, but a should not be required to be a ScicoArray. In any case, operators should always return ScicoArrays (related to #93).
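A minimal sketch of the f(x) = a B x case under the proposed convention, with a hypothetical function name scaled_matvec. B and x are annotated as JAX arrays, a is allowed to be a plain Python number, and the result is always a JAX array, because the product with a JAX array produces one.

```python
import jax
import jax.numpy as jnp


def scaled_matvec(a: float, B: jax.Array, x: jax.Array) -> jax.Array:
    """Hypothetical operator computing f(x) = a B x.

    B and x are expected to be JAX arrays; the scalar a need not be.
    """
    return a * (B @ x)


B = jnp.eye(3)
x = jnp.arange(3.0)
y = scaled_matvec(2.0, B, x)  # a is a Python float, y is a jax.Array
```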

tbalke • Jan 24 '22 19:01

A good part of this issue has been addressed in #410, but the question of whether all scico functions should take and return jax arrays (as opposed to numpy arrays) is still open.

bwohlberg • May 03 '23 00:05

It's worth noting the availability of jax.typing.ArrayLike, which the JAX documentation describes as intended for any value that is safe to implicitly cast to a JAX array; this includes jax.Array, numpy.ndarray, as well as Python builtin numeric values (e.g. int, float, etc.) and numpy scalar values (e.g. numpy.int32, numpy.float64, etc.).

Also see the "JAX Typing Best Practices" section of the jax.typing documentation.
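As I understand the JAX recommendation, inputs are annotated with ArrayLike and outputs with jax.Array, with an explicit jnp.asarray conversion at the function boundary. A minimal sketch of that pattern, using a hypothetical soft_threshold function (not a scico API) that computes sign(x) max(|x| - t, 0) elementwise:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def soft_threshold(x: ArrayLike, t: ArrayLike) -> jax.Array:
    """Hypothetical elementwise soft thresholding.

    Accepts anything ArrayLike (jax.Array, numpy.ndarray, Python or NumPy
    scalars) and always returns a jax.Array.
    """
    x = jnp.asarray(x)  # explicit conversion at the boundary
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - t, 0.0)


y = soft_threshold(np.array([-3.0, 0.5, 2.0]), 1.0)  # NumPy in, JAX out
```

The asymmetry is the point: callers may pass NumPy data, but everything downstream of this function sees only JAX arrays.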

bwohlberg • May 09 '23 12:05

Suggestion: pre- and post-processing functions should take numpy or jax arrays, while functions expected to be used within optimization problems etc. should take jax arrays.
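A minimal sketch of how this split might look, with hypothetical function names. The pre-processing step accepts NumPy or JAX input via ArrayLike and returns a JAX array, while the objective term intended for use inside an optimizer is annotated to take JAX arrays only. The explicit isinstance check is an assumed way of enforcing that, not something the discussion settles on.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def preprocess(img: ArrayLike) -> jax.Array:
    """Hypothetical pre-processing: scale 8-bit data to [0, 1].

    Accepts NumPy or JAX arrays (or nested lists) as input.
    """
    return jnp.asarray(img, dtype=jnp.float32) / 255.0


def data_fidelity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Hypothetical objective term, 0.5 * ||x - y||^2, for use in an optimizer.

    Requires JAX arrays; the runtime check is an assumed enforcement choice.
    """
    if not (isinstance(x, jax.Array) and isinstance(y, jax.Array)):
        raise TypeError("data_fidelity expects jax.Array inputs")
    return 0.5 * jnp.sum((x - y) ** 2)


x = preprocess([[0, 255], [51, 102]])
loss = data_fidelity(x, jnp.zeros_like(x))
```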

bwohlberg • May 22 '23 17:05