einops
Helper function for mapping array shapes into dict
My sense is that in many cases the size of a new axis should match the size of an existing axis on a different tensor.
I wonder if a helper function that uses the same syntax as the rest of einops to extract axis sizes into a dict would work well?
e.g.,
>>> einops.sizes(input, 'b h w c')
{'b': 32, 'h': 192, 'w': 192, 'c': 3}
This could be naturally extended into a multi-argument version that verifies consistent sizes, e.g.,
>>> einops.sizes(input, 'b h w c_in', weights, 'w h c_in c_out')
{'b': 32, 'h': 192, 'w': 192, 'c_in': 3, 'c_out': 16}
The alternative is manual unpacking of the shape, e.g., b_size, h_size, w_size, c_size = input.shape
This is also pretty readable, but maybe a little harder to use reliably. For example, if you only care about the size of the batch axis, you would be tempted to write b_size, *_ = input.shape or b_size = input.shape[0], neither of which includes the explicit shape assertion. And there's no easy way to check sizes for multiple arguments.
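To make the trade-off concrete, here is a minimal sketch of what a single-array version of the helper might look like. The name `sizes` is the hypothetical function proposed above, not an existing einops API, and `FakeArray` is only a stand-in for any object exposing a `.shape` tuple. The sketch assumes plain space-separated axis names (no parentheses, ellipsis, or skipped axes).

```python
from typing import Dict, Tuple


class FakeArray:
    """Stand-in for any array type exposing a .shape tuple (illustration only)."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


def sizes(x, pattern: str) -> Dict[str, int]:
    """Hypothetical helper: map axis names in `pattern` to the sizes in x.shape."""
    names = pattern.split()
    shape = tuple(x.shape)
    if len(names) != len(shape):
        raise ValueError(
            f"pattern {pattern!r} names {len(names)} axes, "
            f"but shape {shape} has {len(shape)}"
        )
    result: Dict[str, int] = {}
    for name, size in zip(names, shape):
        # A name repeated within one pattern must refer to a single size.
        if result.setdefault(name, size) != size:
            raise ValueError(f"axis {name!r} has conflicting sizes in {shape}")
    return result


images = FakeArray((32, 192, 192, 3))
print(sizes(images, "b h w c"))  # {'b': 32, 'h': 192, 'w': 192, 'c': 3}

# Star-unpacking silently accepts an array of the wrong rank ...
flat = FakeArray((32, 110592))
b_size, *_ = flat.shape

# ... whereas the pattern-based helper rejects it.
try:
    sizes(flat, "b h w c")
except ValueError as err:
    print("rejected:", err)
```

The rank check is what the one-line unpacking idioms lack: they return a batch size even when the array does not have the layout the code assumes.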
I should have checked! I see this already exists as einops.parse_shape.
Perhaps the multi-argument version would be a nice addition? This could be spelled in either of two forms:
- einops.sizes(input, 'b h w c_in', weights, 'w h c_in c_out')
- einops.sizes([input, weights], 'b h w c_in, w h c_in c_out')
a bit late comment :)
There is a function that does extraction for a single array: https://github.com/arogozhnikov/einops/blob/ed9aafd209900cbd604675af3dd39f11d6446973/einops/einops.py#L535. It can skip dimensions (and in general was designed for extracting dimensions to pass to other operations).
However, I've found myself manually unpacking shapes in almost all cases (and the function is currently not mentioned in the docs).
Dict outputs are a bit heavy on code when you use them downstream (compared to b_size, h_size, w_size, c_size = input.shape).
Checking multiple arguments is potentially useful, but I need to analyze actual use cases.
Getting back to the proposal, the second form is more readable:
einops.sizes([input, weights], 'b h w c_in, w h c_in c_out')
but it overlaps with the ability to pass a list of tensors to the rearrange function, so the first option is preferred (or the tensors should be separate arguments), like
einops.sizes(input, weights, 'b h w c_in, w h c_in c_out')
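As a rough sketch of that last spelling (tensors as separate arguments, followed by one comma-separated pattern), the consistency check could look like the following. This is not an einops function: `sizes` and `FakeArray` are hypothetical names used for illustration, the parsing assumes plain space-separated axis names, and it only covers arrays whose shapes are concrete Python integers (the symbolic case is a separate question).

```python
from typing import Dict, Tuple


class FakeArray:
    """Stand-in for any array type exposing a .shape tuple (illustration only)."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


def sizes(*args) -> Dict[str, int]:
    """Hypothetical: sizes(tensor_1, ..., tensor_n, 'pattern_1, ..., pattern_n')."""
    *tensors, pattern = args
    patterns = [part.split() for part in pattern.split(",")]
    if len(patterns) != len(tensors):
        raise ValueError(
            f"got {len(tensors)} tensors but {len(patterns)} patterns in {pattern!r}"
        )
    result: Dict[str, int] = {}
    for tensor, names in zip(tensors, patterns):
        shape = tuple(tensor.shape)
        if len(names) != len(shape):
            raise ValueError(f"pattern {' '.join(names)!r} does not match shape {shape}")
        for name, size in zip(names, shape):
            # The first occurrence of a name fixes its size; later ones must agree.
            known = result.setdefault(name, size)
            if known != size:
                raise ValueError(f"axis {name!r}: expected size {known}, got {size}")
    return result


inputs = FakeArray((32, 192, 192, 3))
weights = FakeArray((192, 192, 3, 16))
print(sizes(inputs, weights, "b h w c_in, w h c_in c_out"))
# {'b': 32, 'h': 192, 'w': 192, 'c_in': 3, 'c_out': 16}

# A weights array with a different c_in is rejected.
bad_weights = FakeArray((192, 192, 4, 16))
try:
    sizes(inputs, bad_weights, "b h w c_in, w h c_in c_out")
except ValueError as err:
    print("rejected:", err)
```

Keeping the pattern as the final argument means any number of tensors can be passed positionally without wrapping them in a list, which avoids the overlap with list inputs to rearrange.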
The caveat is shape checking, which for symbolic tensors would require baking the check into the graph.
- tf: there is an assert, so this is doable
- keras: officially dead as a framework; need to check that TF asserts work properly there
- mxnet: seemingly does not have an assert or any other way to build a check into the graph
- jax: seems to have no special support (https://github.com/google/jax/issues/2273), but maybe some additional tools are available?
@shoyer thoughts?
JAX does not have symbolic shapes. You can just raise a normal exception.
(Sent by email in reply to Alex Rogozhnikov's comment above, dated Sat, Aug 29, 2020, 11:02 PM.)