scico
Document and test `input_shape` and `output_shape` policy
Proposed `input_shape` and `output_shape` policy:

- All instantiated `Operator`s must have `input_shape` and `output_shape` properties.
- For an `Operator` `H`, `H(x)` should throw an error if `x.shape` is not `H.input_shape`.
- `Operator`s should be written so that `H(x).shape` is `H.output_shape`; this will not be checked at runtime, but should be tested.
- `Operator`s should attempt to automatically deduce `input_shape` and `output_shape` from other arguments, and throw an error if the user requests an `input_shape` or `output_shape` that is not realizable.
- The user may need to specify `input_shape` and/or `output_shape`.
- The user may specify `input_shape` and/or `output_shape` even when not needed, to prevent them from being automatically deduced (for speed).
This policy should be described in the docs and implemented in the code.
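A minimal sketch of how an operator class could follow the first few points of this policy, assuming JAX. `ShapeCheckedOperator` and its constructor arguments are hypothetical (this is not scico's `Operator` API), and the `float32` dtype used for shape deduction is an assumption. The output shape is deduced with `jax.eval_shape`, which traces the function abstractly without doing any arithmetic; a user-supplied `output_shape` skips that step, and the realizability check from the fourth point is left out for brevity.

```python
import jax
import jax.numpy as jnp


class ShapeCheckedOperator:
    """Hypothetical sketch of the proposed shape policy (not scico's API)."""

    def __init__(self, eval_fn, input_shape, output_shape=None,
                 input_dtype=jnp.float32):
        self._eval_fn = eval_fn
        self.input_shape = tuple(input_shape)
        if output_shape is None:
            # Deduce output_shape by abstract evaluation: jax.eval_shape traces
            # eval_fn on a shape/dtype placeholder without computing anything.
            placeholder = jax.ShapeDtypeStruct(self.input_shape, input_dtype)
            output_shape = jax.eval_shape(eval_fn, placeholder).shape
        # A user-supplied output_shape is trusted (deduction is skipped);
        # H(x).shape == H.output_shape is left to the test suite.
        self.output_shape = tuple(output_shape)

    def __call__(self, x):
        # Enforce the input-shape contract on every evaluation.
        if x.shape != self.input_shape:
            raise ValueError(
                f"expected input of shape {self.input_shape}, got {x.shape}"
            )
        return self._eval_fn(x)


# Usage: column sums of a 3x4 array.
H = ShapeCheckedOperator(lambda x: jnp.sum(x, axis=0), input_shape=(3, 4))
print(H.output_shape)  # (4,)
assert H(jnp.ones((3, 4))).shape == H.output_shape  # the kind of check to test
```

Calling `H(jnp.ones((4, 3)))` raises a `ValueError`, while the `H(x).shape == H.output_shape` property is only asserted, in line with the "tested, not checked at runtime" point.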
Why this? While some operators can conceptually work with different input sizes, trying to allow this creates downstream problems, e.g., what is the adjoint of a sum operator (one that sums all the elements of a vector) applied to the scalar 1.0? How does one initialize an optimization routine without knowing input shapes for operators?
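The sum-operator question can be made concrete with JAX's `jax.linear_transpose`, which will not build a transpose without example inputs. For a real linear map the transpose is the adjoint, so in this sketch (the helper `sum_adjoint` is hypothetical) the same scalar `1.0` maps to arrays of different shapes depending on the declared input shape, which is exactly the information an `input_shape` property would supply.

```python
import jax
import jax.numpy as jnp


def sum_adjoint(input_shape):
    """Adjoint of x -> jnp.sum(x), which only exists once input_shape is fixed."""
    # jax.linear_transpose needs example primals to pin down shape and dtype.
    primal = jnp.zeros(input_shape, dtype=jnp.float32)
    transpose = jax.linear_transpose(jnp.sum, primal)
    # linear_transpose returns one cotangent per primal argument; unpack it.
    return lambda y: transpose(y)[0]


one = jnp.array(1.0, dtype=jnp.float32)
print(sum_adjoint((3,))(one))    # [1. 1. 1.]
print(sum_adjoint((2, 2))(one))  # a 2x2 array of ones
```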
💯 Also worth noting: functions are differentiated (autograd) and jitted for a particular input shape and must be recompiled for other shapes.
Some functions only need a recompile (jitted things), but iirc the adjoint will straight fail if it is applied to a different shape than the one used in the jax adjoint function.
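A small illustration of the recompilation point, assuming JAX: a Python side effect inside a `jax.jit`-decorated function runs only while JAX traces it, so counting those runs shows that a new input shape triggers a retrace and recompile, while a repeated shape hits the cache. The function `scale` and the counter `trace_count` are illustrative names only.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return 2.0 * x


scale(jnp.ones(3))  # first call: trace and compile for shape (3,)
scale(jnp.ones(3))  # same shape and dtype: cached, no retrace
scale(jnp.ones(4))  # new shape: retraced and recompiled
print(trace_count)  # 2
```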
> adjoint will straight fail if it is applied to a different shape than the one used in the jax adjoint function
I can confirm this is true. Also true if the dtype doesn't match. Issue on `input_dtype` coming soon.
(move this comment to your new issue whenever it shows up)
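The adjoint failure described above can be reproduced with `jax.linear_transpose`, which builds the transpose for the shape and dtype of the example primal it is given. In this sketch (`double`, `double_T`, and `try_adjoint` are hypothetical names), a cotangent with the wrong shape or the wrong dtype is rejected, while a matching one is accepted.

```python
import jax
import jax.numpy as jnp


def double(x):
    return 2.0 * x


# The transpose is tied to the example primal: shape (3,), dtype float32.
double_T = jax.linear_transpose(double, jnp.zeros(3, dtype=jnp.float32))


def try_adjoint(y):
    """Apply double_T to y and report whether JAX accepted the cotangent."""
    try:
        double_T(y)
        return "ok"
    except (TypeError, ValueError) as err:
        # JAX rejects a cotangent whose shape/dtype differs from the output;
        # it raises TypeError for this, ValueError is caught defensively.
        return f"failed: {type(err).__name__}"


print(try_adjoint(jnp.ones(3, dtype=jnp.float32)))    # ok
print(try_adjoint(jnp.ones(4, dtype=jnp.float32)))    # shape mismatch
print(try_adjoint(jnp.ones(3, dtype=jnp.complex64)))  # dtype mismatch
```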
We discussed having strict `input_dtype` checking a few times. It gets a little messy, though, for complex-valued functions. For example, if $F$ is a DFT, then it is naturally $F : \mathbb{C} \to \mathbb{C}$. You can make it $F : \mathbb{R} \to \mathbb{C}$, but that's a bad choice for most problems, like MRI, where you have a non-symmetric sampling pattern in Fourier space; the adjoint gets all screwy. But if you have something like $F : \mathbb{C} \to \mathbb{C}$, then, with strict checking, something like $F(|x|)$ doesn't work without an explicit cast to complex.

So right now there is strict type checking on the adjoint, but not on the forward evaluation. This policy is documented here.
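The current arrangement (loose forward evaluation, strict adjoint) can be sketched for a DFT in JAX. Everything here is illustrative: `N`, `F`, `F_adj`, and `F_strict` are hypothetical names, not scico's implementation. `jax.linear_transpose` gives the plain transpose $F^T$, so the Hermitian adjoint is formed as $F^H y = \overline{F^T \bar{y}}$; for the unnormalized DFT this equals $N \cdot \mathrm{ifft}(y)$, which the last line checks numerically. `F_strict` shows why strict forward checking makes $F(|x|)$ fail without an explicit cast.

```python
import jax
import jax.numpy as jnp

N = 8  # hypothetical signal length
x_example = jnp.zeros(N, dtype=jnp.complex64)  # declares F : C^N -> C^N


def F(x):
    # Forward evaluation is not dtype-checked: jnp.fft.fft promotes real input.
    return jnp.fft.fft(x)


# The transpose is pinned to complex64 inputs/outputs of shape (N,).
_F_T = jax.linear_transpose(F, x_example)


def F_adj(y):
    # Hermitian adjoint F^H y = conj(F^T conj(y)); linear_transpose rejects
    # any y whose shape or dtype differs from F's complex64 output.
    return jnp.conj(_F_T(jnp.conj(y))[0])


def F_strict(x):
    # Hypothetical strict forward check, for comparison only.
    if x.dtype != x_example.dtype:
        raise TypeError(f"expected dtype {x_example.dtype}, got {x.dtype}")
    return F(x)


x = (jnp.arange(N) - 3.5).astype(jnp.complex64)
print(F(jnp.abs(x)).dtype)  # complex64: the loose forward accepts real input
# F_strict(jnp.abs(x)) would raise TypeError; an explicit cast is needed:
y = F_strict(jnp.abs(x).astype(jnp.complex64))
# For the unnormalized DFT, F^H y = N * ifft(y), so F_adj(F(x)) recovers N * x.
print(jnp.allclose(F_adj(F(x)), N * x, atol=1e-3))  # True
```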