
BUG: [jax2tf] Strided pooling with polymorphic shape sometimes fails.

Open sdenton4 opened this issue 2 years ago • 5 comments

Description

Pooling operations sometimes fail to convert. It looks like a None dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed' inputs to ensure that striding evenly divides the input size.

Here's a minimal colab notebook replication: https://colab.corp.google.com/drive/1FX99EPcaX-1mAVnpnpUQwkkR0WZ_6o3h#scrollTo=NzSMa2VWhzWq

TypeError: in user code:

    File "<ipython-input-2-e6d49ad967e7>", line 26, in None  *
        lambda inputs: converted_infer_fn(
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs  *
        return fun(*args, **kwargs)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =               _call_wrapped_with_new_constant_cache(fun, in_vals,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =         fun.call_wrapped(*in_vals)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped  *
        ans = self.f(*args, **dict(self.params, **kwargs))
    File "<ipython-input-23-a74605bc8077>", line 11, in infer_fn  *
        pooled = nn.pooling.avg_pool(
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool  *
        y = pool(inputs, 0., lax.add, window_shape, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool  *
        y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window  *
        return monoid_reducer(operand, window_dimensions, window_strides, padding,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum  *
        window_dilation=tuple(window_dilation))
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind  *
        return self.bind_with_trace(find_top_trace(args), args, params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace  *
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl  *
        **params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 553, in _reduce_window  *
        tf_padding = pads_to_padtype(operand.shape, window_dimensions, window_strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 115, in pads_to_padtype  *
        pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4527, in padtype_to_pads  *
        out_shape = _ceil_divide(in_shape, window_strides)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/lax.py", line 4512, in _ceil_divide  *
        return -np.floor_divide(np.negative(x1), x2)

    TypeError: bad operand type for unary -: 'NoneType'
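The innermost frame is the SAME-padding computation. Below is a rough, simplified single-dimension sketch of that arithmetic in plain Python. The names `ceil_divide` and `same_pads` are illustrative and are not JAX's API. The output length is `ceil(n / stride)`, and the total padding is whatever the last window needs in order to fit. If you pass `None` as a dimension size, the sketch raises the same `TypeError`. That `None` is presumably the unknown dimension from the TF input's static shape.

```python
# Simplified, one-dimensional re-implementation of the SAME-padding
# arithmetic used by lax.padtype_to_pads (illustration only).

def ceil_divide(x, y):
    # ceil(x / y) computed as -((-x) // y), the same trick as _ceil_divide.
    # Negating x raises TypeError when x is None (an unknown dimension).
    return -((-x) // y)


def same_pads(in_size, window, stride):
    """Return (low, high) padding for SAME pooling along one dimension."""
    out_size = ceil_divide(in_size, stride)
    total = max(0, (out_size - 1) * stride + window - in_size)
    # Any odd leftover element of padding goes to the high side.
    return total // 2, total - total // 2


print(same_pads(64, window=3, stride=1))   # (1, 1)
print(same_pads(10, window=5, stride=2))   # (1, 2)
print(same_pads(64, window=3, stride=16))  # (0, 0): stride >= window, 64 % 16 == 0

try:
    same_pads(None, window=3, stride=16)  # an unknown (None) dimension
except TypeError as err:
    print(err)  # bad operand type for unary -: 'NoneType'
```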

What jax/jaxlib version are you using?

v0.3.15

Which accelerator(s) are you using?

  • [X] CPU
  • [ ] GPU
  • [X] TPU

Additional System Info

No response

sdenton4, Aug 08 '22 22:08

It seems this only happens with enable_xla=False.

@marcvanzee PTAL

gnecula, Aug 11 '22 13:08

This may have been fixed incidentally by a very recent change #11816. Can you please try again at HEAD?

gnecula, Aug 11 '22 13:08

Ah, that's great! Still getting an error, but the message has changed.

TypeError: in user code:

    File "<ipython-input-2-e6d49ad967e7>", line 26, in None  *
        lambda inputs: converted_infer_fn(
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 302, in fun_no_kwargs  *
        return fun(*args, **kwargs)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 534, in _interpret_fun  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =               _call_wrapped_with_new_constant_cache(fun, in_vals,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 688, in _call_wrapped_with_new_constant_cache  *
        out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] =         fun.call_wrapped(*in_vals)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/linear_util.py", line 168, in call_wrapped  *
        ans = self.f(*args, **dict(self.params, **kwargs))
    File "<ipython-input-3-587f6ee978a3>", line 13, in infer_fn  *
        pooled = nn.pooling.avg_pool(
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 72, in avg_pool  *
        y = pool(inputs, 0., lax.add, window_shape, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/flax/linen/pooling.py", line 52, in pool  *
        y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 79, in reduce_window  *
        return monoid_reducer(operand, window_dimensions, window_strides, padding,
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/_src/lax/windowed_reductions.py", line 127, in _reduce_window_sum  *
        window_dilation=tuple(window_dilation))
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 324, in bind  *
        return self.bind_with_trace(find_top_trace(args), args, params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/core.py", line 327, in bind_with_trace  *
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/jax2tf.py", line 977, in invoke_impl  *
        **params)
    File "/tmp/lyra_notebook.par/google3/third_party/py/jax/experimental/jax2tf/impl_no_xla.py", line 562, in tf_pool  *
        op = tf.reshape(op, (1,) + operand_shape + (1,))

    TypeError: Failed to convert elements of (1, 1, 16*t, 1, 1) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

sdenton4, Aug 11 '22 14:08

Oh, and things seem to have gotten worse in one regard... Previously, conversion succeeded when stride >= pool_size, but that no longer works with the new change applied.

sdenton4, Aug 11 '22 14:08

From a brief look at the code and the error, the reshape here seems a bit strange, since the tensor already has batch and channel axes.

sdenton4, Aug 11 '22 14:08

Hi Tom,

Thanks for filing the bug! I removed the batch + channel axes logic because I thought it was not part of the op: in the operational semantics of XLA::ReduceWindow, the definition of operands is "A sequence of N multi-dimensional arrays of types T_0,..., T_{N-1}, each representing the base area on which the window is placed."

However, it seems it is actually possible to add batch and feature dimensions. In fact, this seems to be what Flax does as well! (So all Flax modules that use pooling now fail.)
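Flax's pooling passes `window_dimensions` and `window_strides` that cover every input axis, with a window size and stride of 1 on the batch and feature axes. A non-XLA lowering could therefore reuse those axes rather than prepend and append new ones. Prepending and appending is what produced the rank-5 `(1, 1, 16*t, 1, 1)` shape in the traceback above. The sketch below shows the bookkeeping using plain tuples. The helper names are hypothetical, and it ignores padding: a real implementation would also need to check that the padding on those axes is (0, 0).

```python
def split_window_dims(window_dims, strides):
    """Hypothetical helper: split full-rank reduce_window parameters.

    Treats the first axis as batch and the last axis as feature/channel when
    both their window size and stride are 1 (as in Flax-style pooling).
    Returns (has_batch, spatial_window, spatial_strides, has_channel).
    """
    assert len(window_dims) == len(strides)
    has_batch = window_dims[0] == 1 and strides[0] == 1
    lead = 1 if has_batch else 0
    # Only claim a channel axis if at least one spatial axis would remain.
    has_channel = (len(window_dims) - lead > 1
                   and window_dims[-1] == 1 and strides[-1] == 1)
    end = len(window_dims) - (1 if has_channel else 0)
    return (has_batch, tuple(window_dims[lead:end]),
            tuple(strides[lead:end]), has_channel)


def pool_input_shape(operand_shape, window_dims, strides):
    """Hypothetical helper: shape for an NWC/NHWC-style pooling op.

    Adds a batch or channel axis only when the operand does not already
    have one, instead of unconditionally wrapping the shape in (1,) ... (1,).
    """
    has_batch, window, stride, has_channel = split_window_dims(window_dims, strides)
    shape = tuple(operand_shape)
    if not has_batch:
        shape = (1,) + shape
    if not has_channel:
        shape = shape + (1,)
    return shape, window, stride


# Flax-style call: batch and channel axes already present, window/stride 1 on them.
print(pool_input_shape((1, 64, 1), (1, 3, 1), (1, 2, 1)))
# Bare 1-D operand: batch and channel axes are added.
print(pool_input_shape((64,), (3,), (2,)))
```

Both calls yield an operand shape of (1, 64, 1) with a spatial window of (3,) and stride of (2,). A spatial axis whose window and stride are both 1 (and that has no padding) is a no-op for the reduction, so treating it as a batch or channel axis does not change the result.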

I think this is actually quite a good argument for adding some end-to-end Flax --> jax2tf (enable_xla=False) tests, which would catch this. I wrote some tooling in converter_eval a while ago, but I think we should turn it into actual tests. I filed #11872 for this.

In any case, I will look at adding back support for batch and channel dimensions.

marcvanzee, Aug 12 '22 08:08

I actually found a different bug in our current implementation when doing average pooling with "SAME" padding. It is related to the way TF computes average pooling, and will lead to different outputs than JAX (https://github.com/google/jax/issues/11874).
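Here is the difference in concrete terms. With SAME padding, TF's average pooling divides each window's sum by the number of real (non-padded) elements. The JAX result, as described in this thread, treats padding as zeros and divides by the full window size. The plain-Python sketch below implements both conventions; the function name and the `count_include_pad` flag are made up for this illustration.

```python
def avg_pool_1d_same(x, window, stride, count_include_pad):
    """1-D average pooling with SAME padding, in plain Python.

    count_include_pad=True: padded positions count as zeros and the sum is
    divided by the full window size (the JAX behaviour discussed here).
    count_include_pad=False: only real elements are averaged (TF's behaviour).
    """
    n = len(x)
    out_size = -((-n) // stride)  # ceil(n / stride)
    total_pad = max(0, (out_size - 1) * stride + window - n)
    low = total_pad // 2
    result = []
    for i in range(out_size):
        start = i * stride - low
        # Keep only indices that fall inside the unpadded input.
        vals = [x[j] for j in range(start, start + window) if 0 <= j < n]
        denom = window if count_include_pad else len(vals)
        result.append(sum(vals) / denom)
    return result


x = [3.0, 3.0, 3.0, 3.0]
print(avg_pool_1d_same(x, window=3, stride=1, count_include_pad=True))   # [2.0, 3.0, 3.0, 2.0]
print(avg_pool_1d_same(x, window=3, stride=1, count_include_pad=False))  # [3.0, 3.0, 3.0, 3.0]
```

The two conventions differ only in windows that overlap the padding, which is why the mismatch shows up at the edges of the input.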

@sdenton4 PTAL at that issue since you are doing average pooling with "SAME" padding as well, so I think the bug affects your code.

marcvanzee, Aug 12 '22 11:08

I just thought of a solution to that problem using manual padding whenever we encounter SAME padding. Will try implementing that fix together with this one.
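The manual-padding idea could look roughly like this:

1. Compute the SAME padding amounts.
2. Pad the input with zeros explicitly.
3. Pool with VALID padding.

With VALID padding, every window holds exactly `window` elements, so the implicit-padding mismatch goes away. In a TF lowering, this would presumably correspond to `tf.pad` followed by `tf.nn.avg_pool(..., padding="VALID")`. This is not necessarily how the actual fix was implemented. The sketch stays in plain Python, and its helper names are hypothetical.

```python
def same_pads(n, window, stride):
    """(low, high) SAME padding for one dimension (see lax.padtype_to_pads)."""
    out_size = -((-n) // stride)  # ceil(n / stride)
    total = max(0, (out_size - 1) * stride + window - n)
    return total // 2, total - total // 2


def avg_pool_1d_valid(x, window, stride):
    # VALID: windows never leave the input, so each one has `window` elements
    # and dividing by `window` is always correct.
    return [sum(x[i:i + window]) / window
            for i in range(0, len(x) - window + 1, stride)]


def avg_pool_same_via_manual_padding(x, window, stride):
    """SAME average pooling that counts padding as zeros, built from VALID."""
    low, high = same_pads(len(x), window, stride)
    padded = [0.0] * low + list(x) + [0.0] * high
    return avg_pool_1d_valid(padded, window, stride)


print(avg_pool_same_via_manual_padding([3.0, 3.0, 3.0, 3.0], window=3, stride=1))
# [2.0, 3.0, 3.0, 2.0]
```

After explicit padding, the VALID output length equals `ceil(n / stride)`, which matches the SAME output length. This sketch reproduces the padding-counted-as-zeros average without relying on the backend's own SAME handling.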

marcvanzee, Aug 12 '22 11:08

Oh, wow; interesting bug on the TF side. Probably worth filing something with them?

It certainly affects me, but the TF behavior seems not too bad for my use-case (averaging embeddings over time). Treating the padding as signal effectively creates a downward bias on the mean. Dealing with the noisier average, without fake data, is probably preferable for me...

If there's anything I can help with feel free to send a ping, of course. I've got a couple other things on the front burners right now, but as you know this is all important to me, so happy to pitch in where I can.

sdenton4, Aug 12 '22 19:08

If there's anything I can help with feel free to send a ping, of course.

No worries, I am taking a look at this issue and the other one; I'll let you know when I have a fix!

marcvanzee, Aug 14 '22 09:08

@sdenton4 could you please try https://github.com/google/jax/pull/11913 on your code and see if it fixes things?

marcvanzee, Aug 15 '22 07:08