tensorflow-onnx
feat: add support for ExtractImagePatches
Closes: https://github.com/onnx/tensorflow-onnx/issues/436
This rewrite is based on this comment: https://github.com/onnx/tensorflow-onnx/issues/436#issuecomment-993313423 with changes to make it more general and translatable into tf2onnx.
Equivalent TensorFlow function and automated test script:
```python
import tensorflow as tf
import numpy as np
from hypothesis import given, strategies as st, settings, assume


def our_extract_image_patches(sizes, strides, rates, padding):
    # TensorFlow's constraints.
    assert sizes[0] == 1 and sizes[3] == 1
    assert strides[0] == 1 and strides[3] == 1
    assert rates[0] == 1 and rates[3] == 1
    assert padding in ["SAME", "VALID"]

    # Extract size.
    [_, size_rows, size_cols, _] = sizes

    @tf.function
    def function(tensor):
        # Input shape of [N, H, W, C].
        tensor_shape = tensor.shape

        # Transpose and reshape to [N * C, H, W, 1].
        tensor = tf.transpose(tensor, perm=[0, 3, 1, 2])
        tensor = tf.reshape(tensor, [
            tensor_shape[0] * tensor_shape[3],
            tensor_shape[1],
            tensor_shape[2],
            1,
        ])

        # Convolve with identity kernel into [N * C, ?H, ?W, K].
        k = size_rows * size_cols
        kernel = tf.reshape(tf.eye(k), [size_rows, size_cols, 1, k])
        convolution = tf.nn.conv2d(tensor, kernel, strides=strides, padding=padding, dilations=rates)

        # Reshape into [N, C, ?H, ?W, K].
        reshaped = tf.reshape(convolution, [
            tensor_shape[0],
            tensor_shape[3],
            convolution.shape[1],
            convolution.shape[2],
            k,
        ])

        # Transpose and reshape into [N, ?H, ?W, C * K].
        patches = tf.transpose(reshaped, perm=[0, 2, 3, 4, 1])
        return tf.reshape(patches, [
            tensor_shape[0],
            convolution.shape[1],
            convolution.shape[2],
            tensor_shape[3] * k,
        ])

    return function


def tf_extract_image_patches(sizes, strides, rates, padding):
    @tf.function
    def function(tensor):
        return tf.image.extract_patches(
            tensor,
            sizes=sizes,
            strides=strides,
            rates=rates,
            padding=padding,
        )

    return function


@settings(max_examples=5000)
@given(
    st.lists(st.integers(min_value=1, max_value=20), min_size=4, max_size=4),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.sampled_from(["VALID", "SAME"]),
)
def test_equal(shape, size_rows, size_cols, stride_rows, stride_cols, dil_rows, dil_cols, padding):
    sizes = [1, size_rows, size_cols, 1]
    strides = [1, stride_rows, stride_cols, 1]
    rates = [1, dil_rows, dil_cols, 1]

    try:
        tensor = tf.cast(tf.reshape(tf.range(np.prod(shape)), shape), dtype=tf.float32)
        tfs = tf_extract_image_patches(sizes, strides, rates, padding)(tensor)
        if 0 in tfs.shape:
            # We cannot handle operations that produce empty outputs.
            assume(False)
    except ValueError:
        # Ignore input if TensorFlow would fail.
        assume(False)
        return

    ours = our_extract_image_patches(sizes, strides, rates, padding)(tensor)
    assert tf.reduce_all(tf.math.equal(tfs, ours)).numpy()
```
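The test skips inputs for which `tfs.shape` contains a 0. Along each spatial axis, that size follows TensorFlow's usual convolution padding rules. A minimal sketch, where `tf_output_size` is a hypothetical helper written for illustration (for very small VALID inputs TensorFlow may reject the input instead of returning an empty tensor; this sketch simply clamps to 0):

```python
def tf_output_size(n, k, s, d, padding):
    """Output length along one spatial axis for TF-style Conv2D / ExtractImagePatches.

    n: input length, k: patch (kernel) size, s: stride, d: rate (dilation).
    """
    effective_k = (k - 1) * d + 1  # span covered by one dilated patch
    if padding == "SAME":
        return -(-n // s)  # ceil(n / s): SAME pads so every stride step fits
    if padding == "VALID":
        # ceil((n - effective_k + 1) / s), clamped at 0 when the patch does not fit.
        return max(0, -(-(n - effective_k + 1) // s))
    raise ValueError(f"unsupported padding: {padding!r}")


# A 20-row patch cannot fit in 5 rows with VALID padding, so that axis is empty.
print(tf_output_size(5, 20, 1, 1, "VALID"))  # 0
print(tf_output_size(5, 20, 1, 1, "SAME"))   # 5
```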
Output from `pytest convolve.py --hypothesis-show-statistics` (no failures):

```
convolve.py::test_equal:

  - during generate phase (70.74 seconds):
    - Typical runtimes: ~ 1-14 ms, of which < 1ms in data generation
    - 5000 passing examples, 0 failing examples, 3913 invalid examples

  - Stopped because settings.max_examples=5000
```
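The core trick in the script above (fold the channels into the batch dimension, then convolve with a one-hot "identity" kernel so that each output channel copies one tap of the window) can also be checked without TensorFlow. The NumPy sketch below is an illustration under our own assumptions: `pad_amounts`, `conv2d_nhwc`, `patches_direct` and `patches_via_conv` are hypothetical helpers written for this example, and the padding arithmetic mirrors TensorFlow's SAME/VALID rules as we understand them (any odd extra SAME padding goes at the end).

```python
import numpy as np


def pad_amounts(n, k, s, d, padding):
    """(before, after) zero padding along one axis, following TF's SAME/VALID rules."""
    if padding == "VALID":
        return 0, 0
    effective_k = (k - 1) * d + 1           # span covered by one dilated window
    out = -(-n // s)                        # SAME output length: ceil(n / s)
    total = max((out - 1) * s + effective_k - n, 0)
    return total // 2, total - total // 2   # any odd extra pixel goes at the end


def conv2d_nhwc(x, w, strides, rates, padding):
    """Plain cross-correlation with TF Conv2D semantics.

    x: [N, H, W, Cin], w: [kh, kw, Cin, Cout]; strides/rates are (rows, cols).
    """
    kh, kw, _, cout = w.shape
    (sh, sw), (dh, dw) = strides, rates
    ph = pad_amounts(x.shape[1], kh, sh, dh, padding)
    pw = pad_amounts(x.shape[2], kw, sw, dw, padding)
    xp = np.pad(x, [(0, 0), ph, pw, (0, 0)])
    oh = (xp.shape[1] - ((kh - 1) * dh + 1)) // sh + 1
    ow = (xp.shape[2] - ((kw - 1) * dw + 1)) // sw + 1
    out = np.zeros((x.shape[0], oh, ow, cout), dtype=x.dtype)
    for a in range(kh):
        for b in range(kw):
            # Input pixels seen by kernel tap (a, b) at every output position.
            tap = xp[:, a * dh : a * dh + (oh - 1) * sh + 1 : sh,
                     b * dw : b * dw + (ow - 1) * sw + 1 : sw, :]
            out += tap @ w[a, b]            # [N, oh, ow, Cin] @ [Cin, Cout]
    return out


def patches_direct(x, ksize, strides, rates, padding):
    """Reference ExtractImagePatches: gather each dilated window explicitly."""
    (kh, kw), (sh, sw), (dh, dw) = ksize, strides, rates
    ph = pad_amounts(x.shape[1], kh, sh, dh, padding)
    pw = pad_amounts(x.shape[2], kw, sw, dw, padding)
    xp = np.pad(x, [(0, 0), ph, pw, (0, 0)])
    span_h, span_w = (kh - 1) * dh + 1, (kw - 1) * dw + 1
    oh = (xp.shape[1] - span_h) // sh + 1
    ow = (xp.shape[2] - span_w) // sw + 1
    out = np.empty((x.shape[0], oh, ow, kh * kw * x.shape[3]), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            win = xp[:, i * sh : i * sh + span_h : dh, j * sw : j * sw + span_w : dw, :]
            out[:, i, j, :] = win.reshape(x.shape[0], -1)  # flattened as (row, col, channel)
    return out


def patches_via_conv(x, ksize, strides, rates, padding):
    """The script's approach: fold channels into the batch, convolve with a one-hot kernel."""
    n, h, w, c = x.shape
    kh, kw = ksize
    k = kh * kw
    per_channel = x.transpose(0, 3, 1, 2).reshape(n * c, h, w, 1)    # [N*C, H, W, 1]
    kernel = np.eye(k, dtype=x.dtype).reshape(kh, kw, 1, k)           # output channel t copies tap t
    conv = conv2d_nhwc(per_channel, kernel, strides, rates, padding)  # [N*C, oH, oW, K]
    oh, ow = conv.shape[1], conv.shape[2]
    conv = conv.reshape(n, c, oh, ow, k).transpose(0, 2, 3, 4, 1)     # [N, oH, oW, K, C]
    return conv.reshape(n, oh, ow, k * c)


x = np.random.default_rng(0).standard_normal((2, 9, 7, 3))
for padding in ("VALID", "SAME"):
    a = patches_direct(x, (3, 2), (2, 1), (1, 2), padding)
    b = patches_via_conv(x, (3, 2), (2, 1), (1, 2), padding)
    print(padding, a.shape, np.allclose(a, b))
```

Because the kernel is one-hot, every output value is a single input pixel multiplied by 1, so the two versions should agree exactly, including the `K`-major, channel-minor ordering of the last axis.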
Thank you for putting that solution into this PR, it looks great!
Rewriters are designed to rewrite the ONNX graph after we transform each TF op into the corresponding ONNX op. Each rewriter searches the ONNX graph for a given pattern; once the pattern is matched, the ONNX ops involved are replaced with other ops to optimize later inference.
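To make the match-and-replace idea concrete, here is a toy rewriter over a made-up graph representation. `Node`, `compose` and `remove_inverse_transposes` are hypothetical and unrelated to tf2onnx's actual rewriter interface; the pattern (two consecutive Transposes that cancel out) is just a simple example of a matched subgraph being replaced.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str          # also used as the node's single output name
    op_type: str
    inputs: list
    attrs: dict = field(default_factory=dict)


def compose(p, q):
    """Transpose(perm=p) followed by Transpose(perm=q) equals Transpose(perm=[p[i] for i in q])."""
    return [p[i] for i in q]


def remove_inverse_transposes(nodes):
    """Toy rewriter: match Transpose -> Transpose whose perms cancel and bypass both.

    Graph outputs are not tracked here, so a matched pair must not produce one.
    """
    by_name = {n.name: n for n in nodes}
    consumers = {}
    for n in nodes:
        for i in n.inputs:
            consumers.setdefault(i, []).append(n.name)

    removed, rename = set(), {}
    for second in nodes:
        if second.op_type != "Transpose" or second.name in removed:
            continue
        first = by_name.get(second.inputs[0])
        # Pattern: a Transpose whose output feeds only this second Transpose.
        if first is None or first.op_type != "Transpose" or first.name in removed:
            continue
        if consumers.get(first.name) != [second.name]:
            continue
        perm = compose(first.attrs["perm"], second.attrs["perm"])
        if perm == list(range(len(perm))):  # the pair is a no-op
            removed |= {first.name, second.name}
            rename[second.name] = first.inputs[0]

    def resolve(name):
        # Follow chains of bypassed pairs back to a surviving tensor.
        while name in rename:
            name = rename[name]
        return name

    return [Node(n.name, n.op_type, [resolve(i) for i in n.inputs], n.attrs)
            for n in nodes if n.name not in removed]


graph = [
    Node("t1", "Transpose", ["x"], {"perm": [0, 3, 1, 2]}),
    Node("t2", "Transpose", ["t1"], {"perm": [0, 2, 3, 1]}),
    Node("y", "Relu", ["t2"]),
]
print([(n.name, n.inputs) for n in remove_inverse_transposes(graph)])  # [('y', ['x'])]
```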
In this case, ExtractImagePatches is simply a TF op that tf2onnx does not support yet, so your implementation should go into nn.py rather than into a new rewriter. Please add it to nn.py, the same way support for any new TF op is added.
Please feel free to refer to this comment for more details.
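As a rough picture of what such a converter has to emit, the sketch below lists the ONNX node sequence that the TensorFlow function in the script maps onto, adapted to ONNX's NCHW `Conv` layout. It is pure Python and does not touch tf2onnx: `lower_extract_image_patches` is a hypothetical helper returning `(op_type, attributes, output_shape)` tuples for a static NHWC input shape, and mapping TF's `SAME` padding to ONNX `auto_pad="SAME_UPPER"` is an assumption (a real converter might compute explicit `pads` instead).

```python
def lower_extract_image_patches(shape, sizes, strides, rates, padding):
    """Hypothetical plan of ONNX nodes for ExtractImagePatches on a static NHWC shape."""
    assert padding in ("SAME", "VALID")
    n, h, w, c = shape
    _, kh, kw, _ = sizes
    _, sh, sw, _ = strides
    _, dh, dw, _ = rates
    k = kh * kw

    def out_len(length, ksz, s, d):
        effective = (ksz - 1) * d + 1
        if padding == "SAME":
            return -(-length // s)                             # ceil(length / s)
        return max(0, -(-(length - effective + 1) // s))       # VALID

    oh, ow = out_len(h, kh, sh, dh), out_len(w, kw, sw, dw)
    conv_attrs = {
        "kernel_shape": [kh, kw],
        "strides": [sh, sw],
        "dilations": [dh, dw],
        "auto_pad": "SAME_UPPER" if padding == "SAME" else "VALID",
    }
    return [
        ("Transpose", {"perm": [0, 3, 1, 2]}, [n, c, h, w]),
        ("Reshape", {}, [n * c, 1, h, w]),
        # Conv weight input (not shown): shape [K, 1, kh, kw], one-hot tap per output channel.
        ("Conv", conv_attrs, [n * c, k, oh, ow]),
        ("Reshape", {}, [n, c, k, oh, ow]),
        ("Transpose", {"perm": [0, 3, 4, 2, 1]}, [n, oh, ow, k, c]),
        ("Reshape", {}, [n, oh, ow, k * c]),
    ]


plan = lower_extract_image_patches([1, 8, 8, 3], [1, 3, 3, 1], [1, 2, 2, 1], [1, 1, 1, 1], "VALID")
for op_type, attrs, out_shape in plan:
    print(op_type, out_shape)
```

The final Transpose puts the patch index `K` before the channel `C`, which matches the channel-innermost ordering produced by the script's `perm=[0, 2, 3, 4, 1]`.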
Hi @fatcat-z,
> Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op.
I'm not entirely sure this is true. From my understanding, the rewriters are run before each operation is converted into an ONNX operation:
https://github.com/onnx/tensorflow-onnx/blob/01520291227d46c48615300b0a73436dbe3c6610/tf2onnx/tfonnx.py#L616, where line 622 appears to perform the conversion. There do appear to be late rewriters that run after the mapping occurs, but in general the rewriting and optimization steps seem to be separate.
I chose to implement this as a rewriter to avoid duplicating the construction of the Conv2D node, but if you would still prefer this to be implemented in nn.py, please let me know.
No, the graphs_from_tf() function converts the TF graph into an ONNX graph, meaning each TF op has already been converted into an ONNX op where possible. Afterwards, process_parsed_graph() is called to run the rewriters and optimizations.
Yes, please implement this as an op in nn.py instead of creating a new rewriter. Thanks.
Is this new operator going to be merged into main?
I don't think this operator can be considered new, but I'm aiming to get the requested changes done sometime this week.
Sorry for the delay! I've implemented this operation inside of nn.py instead of as a rewriter. Let me know if anything else needs to be changed!