BUG: [jax2tf] Strided pooling with polymorphic shape sometimes fails.
Pooling operations sometimes fail to convert. It looks like a None dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed' inputs to ensure that striding evenly divides the input size.
Here's a minimal colab notebook replication: https://colab.corp.google.com/drive/1FX99EPcaX-1mAVnpnpUQwkkR0WZ_6o3h#scrollTo=NzSMa2VWhzWq
TypeError: in user code:
File "<ipython-input-2-e6d49ad967e7>", line 26, in None *
lambda inputs: converted_infer_fn(
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs *
return fun(*args, **kwargs)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun *
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = _call_wrapped_with_new_constant_cache(fun, in_vals,
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache *
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = fun.call_wrapped(*in_vals)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped *
ans = self.f(*args, **dict(self.params, **kwargs))
File "<ipython-input-23-a74605bc8077>", line 11, in infer_fn *
pooled = nn.pooling.avg_pool(
File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool *
y = pool(inputs, 0., lax.add, window_shape, strides, padding)
File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool *
y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window *
return monoid_reducer(operand, window_dimensions, window_strides, padding,
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum *
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind *
return self.bind_with_trace(find_top_trace(args), args, params)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace *
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl *
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 553, in _reduce_window *
tf_padding = pads_to_padtype(operand.shape, window_dimensions, window_strides, padding)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 115, in pads_to_padtype *
pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4527, in padtype_to_pads *
out_shape = _ceil_divide(in_shape, window_strides)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4512, in _ceil_divide *
return -np.floor_divide(np.negative(x1), x2)
TypeError: bad operand type for unary -: 'NoneType'
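For anyone skimming the traceback: the last frame computes a ceiling division by negating, floor-dividing, and negating again, and it is the unary minus that rejects a None dimension. Below is a minimal pure-Python sketch of that idiom. The helper name `ceil_divide` is illustrative only; the real JAX helper operates on NumPy arrays of dimension sizes.

```python
def ceil_divide(x, y):
    """Ceiling division via the identity ceil(x / y) == -floor(-x / y)."""
    return -((-x) // y)


# Known dimension sizes work as expected.
print(ceil_divide(16, 4))  # 4
print(ceil_divide(17, 4))  # 5

# An unknown (None) dimension fails on the unary minus, matching the
# TypeError at the bottom of the traceback.
try:
    ceil_divide(None, 4)
except TypeError as err:
    print(err)
```

This is only meant to show why a None entry in the input shape surfaces as a unary-minus error rather than as a shape error.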
What jax/jaxlib version are you using?
Which accelerator(s) are you using?
- [X] CPU
- [ ] GPU
- [X] TPU
Additional System Info
No response
This seems to happen only with enable_xla=False.
@marcvanzee PTAL
This may have been fixed incidentally by a very recent change #11816. Can you please try again at HEAD?
Ah, that's great! Still getting an error, but the message has changed.
TypeError: in user code:
File "<ipython-input-2-e6d49ad967e7>", line 26, in None *
lambda inputs: converted_infer_fn(
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs *
return fun(*args, **kwargs)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun *
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = _call_wrapped_with_new_constant_cache(fun, in_vals,
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache *
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = fun.call_wrapped(*in_vals)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped *
ans = self.f(*args, **dict(self.params, **kwargs))
File "<ipython-input-3-587f6ee978a3>", line 13, in infer_fn *
pooled = nn.pooling.avg_pool(
File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool *
y = pool(inputs, 0., lax.add, window_shape, strides, padding)
File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool *
y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window *
return monoid_reducer(operand, window_dimensions, window_strides, padding,
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum *
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind *
return self.bind_with_trace(find_top_trace(args), args, params)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace *
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl *
File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 562, in tf_pool *
op = tf.reshape(op, (1,) + operand_shape + (1,))
TypeError: Failed to convert elements of (1, 1, 16*t, 1, 1) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
Oh, and things seem to have gotten worse in one regard... Previously I noticed that conversion succeeds if stride >= pool_size, but that no longer works with the new change applied.
Looking briefly at the code+error, the reshape here is a bit strange, as the tensor already has batch and channel axes.
Hi Tom,
Thanks for filing the bug! I removed the batch + channel axes logic because I thought this was not part of the op: looking at the operational semantics of XLA::ReduceWindow, the definition of operands is "A sequence of N multi-dimensional arrays of types T_0,..., T_{N-1}, each representing the base area on which the window is placed."
However, it seems it actually is possible to add batch and feature dimensions. In fact, this seems to be what Flax does as well! (so all Flax Modules using pooling now fail)
I think this is actually quite a good argument for adding some end-to-end Flax --> jax2tf (enable_xla=False) tests, which would catch this. I wrote some tooling in converter_eval a while ago, but I think we should rewrite it as actual tests. I filed #11872 for this.
In any case, I will look at adding back support for batch and channel dimensions.
I actually found a different bug in our current implementation when doing average pooling with "SAME" padding. It is related to the way TF computes average pooling, and will lead to different outputs than JAX (https://github.com/google/jax/issues/11874).
@sdenton4 PTAL at that issue since you are doing average pooling with "SAME" padding as well, so I think the bug affects your code.
I just thought of a solution to that problem using manual padding whenever we encounter SAME padding. Will try implementing that fix together with this one.
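One way to read the manual-padding idea, as a rough 1-D sketch in plain Python: compute the SAME padding amounts explicitly, zero-pad the input, then run an unpadded (VALID) average pool so that every window is divided by the full window size. All function names here are hypothetical, and this is not necessarily how the eventual fix is implemented.

```python
def same_pads(in_size, window, stride):
    """Low/high padding for SAME, where output size is ceil(in_size / stride)."""
    out_size = -(-in_size // stride)
    total = max((out_size - 1) * stride + window - in_size, 0)
    lo = total // 2
    return lo, total - lo


def avg_pool_valid(xs, window, stride):
    """Average pool with no padding; every window holds `window` elements."""
    return [sum(xs[i:i + window]) / window
            for i in range(0, len(xs) - window + 1, stride)]


def avg_pool_same_via_manual_padding(xs, window, stride):
    """SAME average pool built from explicit zero padding plus a VALID pool."""
    lo, hi = same_pads(len(xs), window, stride)
    padded = [0.0] * lo + list(xs) + [0.0] * hi
    return avg_pool_valid(padded, window, stride)


print(same_pads(5, 2, 2))  # (0, 1)
print(avg_pool_same_via_manual_padding([4.0] * 5, window=2, stride=2))  # [4.0, 4.0, 2.0]
```

Because the zeros are added before pooling, the last window averages one real sample with one padded zero, so the padding counts toward the denominator.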
Oh, wow; interesting bug on the TF side. Probably worth filing something with them?
It certainly affects me, but the TF behavior seems not too bad for my use-case (averaging embeddings over time). Treating the padding as signal effectively creates a downward bias on the mean. Dealing with the noisier average, without fake data, is probably preferable for me...
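To make the bias concrete, here is a tiny plain-Python sketch of a single edge window that holds one real sample and one padded position. The function name and the `count_padding` flag are illustrative, not part of either library's API. It assumes, as I understand the linked issue, that one convention divides by the full window size and the other divides only by the number of real samples.

```python
def window_mean(real_values, window_size, count_padding):
    """Mean of one pooling window, with padded positions contributing zero.

    count_padding=True divides by the full window size (padding as signal);
    count_padding=False divides by the number of real samples only.
    """
    denom = window_size if count_padding else len(real_values)
    return sum(real_values) / denom


# Edge window of size 2 containing a single real sample with value 4.0.
print(window_mean([4.0], window_size=2, count_padding=True))   # 2.0
print(window_mean([4.0], window_size=2, count_padding=False))  # 4.0
```

Counting the padding halves the edge value here, which is the downward bias. Excluding it keeps the level right but averages over fewer samples.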
If there's anything I can help with feel free to send a ping, of course. I've got a couple other things on the front burners right now, but as you know this is all important to me, so happy to pitch in where I can.
> If there's anything I can help with feel free to send a ping, of course.
No worries, I am taking a look at this issue and the other one. I'll let you know when I have a fix!
@sdenton4 could you please try https://github.com/google/jax/pull/11913 on your code and see if it fixes things?