keras-cv

PR proposal: 3D deformable convolution

Open SimonBiggs opened this issue 3 years ago • 30 comments

I am proposing to include a 3D deformable convolution in keras-cv.

Background

A recent paper used deformable convolutions as a key building block to obtain a state-of-the-art CNN:

  • https://github.com/opengvlab/internimage
  • https://arxiv.org/abs/2211.05778

Deformable convolutions have also previously been shown to make standard convolutional networks more robust to multi-scale problems: https://github.com/kastnerkyle/deform-conv

There was an implementation PR for them in TensorFlow Addons: https://github.com/tensorflow/addons/pull/2196#issuecomment-1351957574

There was also an idea for how to make that particular PR smaller: https://github.com/tensorflow/addons/pull/2196#issuecomment-884275962

There is also an efficient, Apache-2.0-licensed TensorRT implementation of a 2D variant for PyTorch in mmcv:

  • https://github.com/open-mmlab/mmcv/blob/46eb9ec5d07ea344ed43056d007a7eb71dc3ee98/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp
  • https://github.com/open-mmlab/mmcv/blob/46eb9ec5d07ea344ed43056d007a7eb71dc3ee98/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu

In my particular task I am segmenting 3D medical CT and MRI scans to help support radiotherapy cancer treatments. These scans are often "multi-scale", with differing voxel (3D pixel) sizes and slice thicknesses. A standard approach is to resample all data to an isotropic voxel size. However, I hypothesise that using a 3D deformable convolution instead, combined with a few other components, could be part of an efficient new state-of-the-art approach.
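
For contrast, the standard isotropic-resampling step mentioned above amounts to something like the following. This is a minimal sketch assuming a (Z, Y, X) NumPy volume with per-axis spacing in mm; `resample_isotropic_nearest` is a hypothetical helper, and nearest-neighbour lookup is used only for brevity (in practice linear or spline interpolation, e.g. `scipy.ndimage.zoom`, would be more typical).

```python
import numpy as np


def resample_isotropic_nearest(volume, spacing, target=1.0):
    """Resample a (Z, Y, X) volume with per-axis voxel spacing (mm) onto an
    isotropic grid of `target` mm, using nearest-neighbour lookup.

    Hypothetical helper for illustration only.
    """
    spacing = np.asarray(spacing, dtype=float)
    extent = np.array(volume.shape) * spacing              # physical size in mm
    new_shape = np.maximum(1, np.round(extent / target).astype(int))
    indices = []
    for n_new, n_old, step in zip(new_shape, volume.shape, spacing):
        centres = (np.arange(n_new) + 0.5) * target        # new voxel centres (mm)
        src = np.floor(centres / step).astype(int)          # source voxel containing each centre
        indices.append(np.clip(src, 0, n_old - 1))
    return volume[np.ix_(*indices)]
```

With 3 mm slices and 1 mm in-plane spacing, resampling to 1 mm triples the number of slices, which is the kind of overhead the hypothesis above hopes a deformable kernel could sidestep.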

Proposal

That we carry out an initial implementation of a 3D deformable convolution within keras-cv. Afterwards, someone else could extend it to the 2D case. I would be willing to spearhead the task, but first, let's discuss an appropriate and acceptable approach here.

SimonBiggs, Dec 14 '22

/cc @innat

bhack, Dec 14 '22

@innat and @bhack, what do you believe our next steps should be?

SimonBiggs, Dec 15 '22

/cc @ianstenbit @tanzhenyu

bhack, Dec 15 '22

As a follow-up to my comment on https://github.com/tensorflow/addons/pull/2196#issuecomment-884275962

I've written my own TensorFlow addon that implements exactly the proposed method. A lot of work remains, because I didn't write test code and only supported the subset of parameters I needed. It works for me and gives a speed boost over the original because I don't need the NHWC/NCHW transposition, but it's not as fast as it should be. Some investigation likely remains to make sure everything is as optimized as it should be (maybe there is some flag to add somewhere). I'm very confident in my OpenCL writing, but I don't usually do CUDA. I'm open to sharing my code if it helps.

axeldavy, Dec 15 '22

There was an old compositional implementation of a 2d deformable conv: https://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers/convolution/deformable_conv.html

But I have not benchmarked it under XLA or the new bridge to understand how efficiently it is lowered.

bhack, Dec 15 '22

> I'm open to sharing my code if it helps.

@axeldavy, that would be amazing if you could. Would you be open to putting it up on a repo under an Apache-2.0 license?

> There was an old compositional implementation of a 2D deformable conv

I believe that implementation is a copy of the following:

https://github.com/kastnerkyle/deform-conv

SimonBiggs, Dec 15 '22

> I believe that implementation is a copy of the following:

Yes, it is probably derived from it, but with some refactoring.

It would be interesting if you could try to jit_compile a pure TF version of that one in a Colab, just to understand the performance and the emitted code.

bhack, Dec 15 '22

Also, just a heads up, in case it's an issue, I can't actually interact with the following discussion thread:

https://github.com/openxla/xla/discussions/17#discussioncomment-4412225

SimonBiggs, Dec 15 '22

> It would be interesting if you could try to jit_compile a pure TF version of that one in a Colab, just to understand the performance and the emitted code.

@axeldavy, might you be interested in comparing your implementation with a tf.function(jit_compile=True) version of the following code?

~~https://github.com/kastnerkyle/deform-conv~~ --> https://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers/convolution/deformable_conv.html

Edit: see https://github.com/keras-team/keras-cv/issues/1140#issuecomment-1353882096

SimonBiggs, Dec 15 '22

I still see NumPy code there, which was removed in the tensorlayer flavour.

bhack, Dec 15 '22

We had a V2 version at https://github.com/smallsunsun1/deformable_conv_v2 but I've not verified the impl.

bhack, Dec 15 '22

> We had a V2 version at https://github.com/smallsunsun1/deformable_conv_v2 but I've not verified the impl.

Unfortunately no license file...

SimonBiggs, Dec 15 '22

Yes, but it is just for benchmarking, so I am not worried about the license as long as it is not used in a PR:

https://github.com/RuaHU/keras_DCNv2

bhack, Dec 16 '22

I am open to including a deformable conv layer in KerasCV, probably as an experimental API for now. We don't have bandwidth to build this in-house right now but are open to a contribution @SimonBiggs.

I do wonder whether deformable convolution in general should eventually be upstreamed to core Keras, given that it may also have value in time-domain convolutions (at least I would guess it might; I haven't seen any papers about this).

@tanzhenyu wdyt?

ianstenbit, Dec 16 '22

> given that it may also have value in time-domain convolutions (at least I would guess it might; I haven't seen any papers about this).

Here's an example paper for time-domain with 3D deformable convolutions: https://arxiv.org/abs/2004.02803

SimonBiggs, Dec 16 '22

@axeldavy @SimonBiggs @bhack Yeah, I'm interested in knowing whether this can be expressed with native TF ops. If it works and it's really just a matter of kernel performance, maybe that's something XLA can help with. Otherwise, hosting a CUDA kernel seems to contradict our core value, which is that it should run on GPU/TPU without code changes.

tanzhenyu, Dec 16 '22

Here is the code I'm using: https://github.com/axeldavy/tensorflow_addons/tree/add-deformable-interpolation

The CUDA code can be seen here: https://github.com/axeldavy/tensorflow_addons/blob/add-deformable-interpolation/tensorflow_addons/custom_ops/layers/cc/kernels/deformable_interpolation_op_gpu.cu.cc

I don't know enough of XLA to know if it can be efficiently expressed with it.

I haven't benchmarked the code either, or checked that the generated assembly is OK. Changing the border handling might help performance. For example, in OpenCL it is typically much better to always sample (using address clamping for when we are outside) and replace the read content if we were outside, rather than reading conditionally. This is because the generated assembly is of better quality (the memory access is issued earlier, which reduces latency). Similarly, it would be better to know the number of channels at compile time, to enable loop unrolling. I don't know whether CUDA recompiles internally to optimize that.
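
The border-handling idea above can be illustrated with bilinear sampling in NumPy: a branchy reference that only reads in-bounds pixels, and a version that always reads with clamped indices and zeroes the out-of-bounds taps afterwards. This is a sketch assuming zero padding outside the image; it only demonstrates that the two strategies give the same values, not the GPU latency benefit, which was not measured in this thread.

```python
import numpy as np


def bilinear_conditional(img, y, x):
    """Reference: read a tap only if it lies inside the image (branchy)."""
    H, W = img.shape
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    fy, fx = y - y0, x - x0
    total = 0.0
    for dy, wy in ((0, 1 - fy), (1, fy)):
        for dx, wx in ((0, 1 - fx), (1, fx)):
            yy, xx = y0 + dy, x0 + dx
            if 0 <= yy < H and 0 <= xx < W:  # conditional read
                total += wy * wx * img[yy, xx]
    return total


def bilinear_clamped(img, ys, xs):
    """Vectorised: always read with clamped indices, then zero outside taps."""
    H, W = img.shape
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    fy, fx = ys - y0, xs - x0
    out = np.zeros_like(ys, dtype=float)
    for dy, wy in ((0, 1 - fy), (1, fy)):
        for dx, wx in ((0, 1 - fx), (1, fx)):
            yy, xx = y0 + dy, x0 + dx
            inside = (yy >= 0) & (yy < H) & (xx >= 0) & (xx < W)
            # Unconditional read at a clamped (always valid) address.
            vals = img[np.clip(yy, 0, H - 1), np.clip(xx, 0, W - 1)]
            out += np.where(inside, wy * wx * vals, 0.0)
    return out
```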

The difference between my approach and the usual DCNv2 implementations is that the number of sampled positions is arbitrary, and there are no initial offsets: if you want to initialize to a 3x3 grid instead of all positions at 0, you need to include it in the input sampling offsets. The cost is proportional to the number of positions.
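
Since the positions start at 0, a regular 3x3 starting pattern has to come from the offsets themselves. A minimal sketch of such a base-offset vector, assuming (hypothetically) that the offsets tensor interleaves `(dy, dx)` per position in row-major order; the layout the fork actually expects is not stated in the thread, and `dcn_base_offsets` is a hypothetical helper.

```python
import numpy as np


def dcn_base_offsets(kernel_size=3, dilation=1):
    """Offsets of a regular kernel_size x kernel_size grid around the sampling
    point, flattened row-major as [dy0, dx0, dy1, dx1, ...].

    Assumed (dy, dx)-interleaved layout; check the kernel before relying on it.
    """
    r = (kernel_size - 1) // 2
    steps = np.arange(-r, r + 1) * dilation
    dy, dx = np.meshgrid(steps, steps, indexing="ij")
    return np.stack([dy.ravel(), dx.ravel()], axis=-1).reshape(-1).astype(np.float32)
```

With `num_positions = 9` and `num_groups = 1`, this vector could be passed as `bias_initializer=tf.keras.initializers.Constant(dcn_base_offsets(3))` to the zero-initialised offsets `Conv2D`, although the `bias_regularizer` in the example would then pull the offsets back toward zero rather than toward the grid; adding the constant to the convolution output avoids that. Going from 4 to 9 positions also multiplies the sampling work by 9/4.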

Here is an example of how to use it:

```python
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.regularizers import l2

# DeformableInterpolation comes from the add-deformable-interpolation branch of
# axeldavy/tensorflow_addons linked above. `x` is the input feature map (NHWC)
# and `num_filters` the number of output filters; both are defined elsewhere.

# Standard DCNv2 would be 9 positions and 1 group (all features share the same offsets).
num_positions = 4
num_groups = 1
# Compute offsets for the positions, initialised to zero (unlike standard DCN).
offsets = Conv2D(2 * num_positions * num_groups, 1,
                 padding='same', use_bias=True,
                 kernel_initializer='zeros',
                 kernel_regularizer=l2(5e-3), bias_regularizer=l2(5e-3),
                 name='conv_offsets_1')(x)
# Compute DCNv2 masks.
mask = Conv2D(num_positions * num_groups, 1,
              padding='same', use_bias=True, kernel_initializer='zeros',
              kernel_regularizer=l2(5e-3), bias_regularizer=l2(5e-3),
              name='conv_mask_1', activation='sigmoid')(x)
# Retrieve features at the different positions and stack them.
x = DeformableInterpolation(offset_groups=num_groups, num_positions=num_positions,
                            use_mask=True, name='dcn_interp_1')([x, offsets, mask])
# Apply the convolution.
x = Conv2D(num_filters, 1, padding='same', use_bias=True,
           kernel_initializer='he_normal',
           kernel_regularizer=l2(5e-3), name='dcn_conv_1')(x)
```

One other advantage I see to this approach is that it seems easy to implement DCNv3 with it: change the sigmoid to a softmax, use more than one group, and replace the final Conv2D call with two convolutions, if I'm not mistaken (one that uses groups and reduces the number of features, and a normal one). Also, the code works with float16.
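
The first DCNv3 change listed above, normalising the modulation scalars with a softmax over each group's sampling positions instead of an independent sigmoid per position, can be sketched in NumPy. The group-major channel layout (group `g` owns channels `g*P` to `(g+1)*P - 1`) is an assumption, not something the fork specifies, and `dcnv3_modulation` is a hypothetical name.

```python
import numpy as np


def dcnv3_modulation(logits, num_groups, num_positions):
    """Softmax over the sampling positions of each group (DCNv3-style mask).

    logits: (..., num_groups * num_positions), assumed group-major.
    Returns an array of the same shape whose entries sum to 1 within each group.
    """
    *spatial, channels = logits.shape
    assert channels == num_groups * num_positions
    z = logits.reshape(*spatial, num_groups, num_positions)
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return (e / e.sum(axis=-1, keepdims=True)).reshape(*spatial, channels)
```

In the Keras example this would mean dropping `activation='sigmoid'` from the mask `Conv2D`, reshaping its output to `(..., num_groups, num_positions)` and applying `tf.nn.softmax(..., axis=-1)`.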

axeldavy, Dec 16 '22

Thanks @axeldavy, it's Friday afternoon here in Aus, so I won't get back to this until next week. But thank you. Thank you very much.

SimonBiggs, Dec 16 '22

> maybe that's something XLA can help with.

If XLA cannot speed up the operator enough, or cannot compile it, we could always check with upstream why it cannot be lowered efficiently:

https://github.com/openxla/xla/discussions/17#discussioncomment-4412225

bhack, Dec 16 '22

It's not just about training with TensorFlow, though. One issue with deformable convolution is getting it to work with inference solutions (TensorRT, etc.). Ideally there should be an ONNX operator for it, with corresponding implementations in TensorFlow, TensorRT, etc.

axeldavy, Dec 16 '22

XLA has custom calls for when the high-level op is not lowered efficiently or the specific backend is not supported:

https://www.tensorflow.org/xla/custom_call

bhack, Dec 16 '22

So, I might have messed something up with the benchmark: I am just using timeit, plus running the function once beforehand as a warm-up. And, gee whiz, it looks like XLA absolutely wins the day:

CPU timings:

  • mmcv (PyTorch + C++): 8 s
  • tensorlayer (tensorflow + jit): 1 s

GPU timings:

  • mmcv (PyTorch + CUDA): 16 ms
  • tensorlayer (tensorflow + jit): 6 ms
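
The measurement approach described above (one warm-up call, then timeit) can be written as a small generic helper. This is a sketch, not the code from the colab; the `sync` hook is an added assumption, because GPU frameworks launch work asynchronously and a timer can otherwise stop before the kernels finish. The thread does not say whether the colab synchronised, so this is a general pitfall rather than a diagnosis.

```python
import timeit


def benchmark(fn, *args, warmup=1, repeat=5, number=10, sync=lambda out: out):
    """Return the best mean time per call (seconds) over `repeat` runs.

    warmup: untimed calls first, so tracing / XLA compilation and device
            setup are not counted in the measurement.
    sync:   applied to every result so that asynchronous device work has
            finished before the timer stops (identity by default).
    """
    for _ in range(warmup):
        sync(fn(*args))
    timer = timeit.Timer(lambda: sync(fn(*args)))
    return min(timer.repeat(repeat=repeat, number=number)) / number
```

For TensorFlow one could pass `sync=lambda t: t.numpy()` (copying to host waits for the result), and for PyTorch on CUDA `sync=lambda _: torch.cuda.synchronize()`.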

Here is the colab: https://colab.research.google.com/drive/1TNIUJN4W95V81VkWq0n2TksKisnfbZTR?usp=sharing

Keen for feedback if there's anything I have done wrong. Otherwise, if it's okay with the team, I'll get to work writing a 3D modulated deformable convolution in pure Python using native TF ops, based on tensorlayer's implementation, and hopefully provide it as a PR.
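
To make the plan concrete, here is a minimal NumPy reference sketch of a modulated 3D deformable convolution (DCNv2-style: learned offsets plus a mask), composed from a gather and trilinear interpolation, the same structure as the tensorlayer 2D code. It is not the tensorlayer implementation and not a keras-cv API. The single-volume channels-last layout, the `(dz, dy, dx)` offset order, stride 1 and zero padding are all assumptions; a Keras layer would express the same steps with native TF ops (for example `tf.gather_nd`) so that it can be jit-compiled.

```python
import itertools

import numpy as np


def trilinear_sample(x, coords):
    """Sample x of shape (D, H, W, C) at fractional coords (..., 3) ordered (z, y, x).

    Corners outside the volume contribute zero (zero padding): each corner is
    read with a clamped index and then masked out if it fell outside.
    """
    sizes = np.array(x.shape[:3])
    base = np.floor(coords).astype(np.int64)   # lower corner of the 2x2x2 cell
    frac = coords - base                        # fractional part in [0, 1)
    out = np.zeros(coords.shape[:-1] + (x.shape[3],), dtype=np.float64)
    for corner in itertools.product((0, 1), repeat=3):
        step = np.array(corner)
        idx = base + step
        # Trilinear weight: product over axes of frac (upper) or 1 - frac (lower).
        weight = np.prod(np.where(step == 1, frac, 1.0 - frac), axis=-1)
        inside = np.all((idx >= 0) & (idx < sizes), axis=-1)
        clamped = np.clip(idx, 0, sizes - 1)
        values = x[clamped[..., 0], clamped[..., 1], clamped[..., 2]]
        out += (weight * inside)[..., None] * values
    return out


def deformable_conv3d(x, offsets, mask, weight, kernel_size=(3, 3, 3)):
    """Modulated 3D deformable convolution for one volume, stride 1, 'same' size.

    x:       (D, H, W, C_in), channels-last (assumed layout)
    offsets: (D, H, W, K, 3) learned (dz, dy, dx) per kernel position, K = kd*kh*kw
    mask:    (D, H, W, K) modulation scalars, e.g. sigmoid outputs
    weight:  (K, C_in, C_out)
    """
    D, H, W, _ = x.shape
    # Regular kernel grid p_k centred on the output voxel (-1, 0, 1 for size 3).
    axes = [np.arange(k) - (k - 1) / 2 for k in kernel_size]
    kernel_grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
    # Output voxel positions p, shape (D, H, W, 3).
    out_pos = np.stack(
        np.meshgrid(np.arange(D), np.arange(H), np.arange(W), indexing="ij"), axis=-1)
    # Sampling locations p + p_k + delta_p_k, shape (D, H, W, K, 3).
    coords = out_pos[..., None, :] + kernel_grid + offsets
    sampled = trilinear_sample(x, coords) * mask[..., None]   # (D, H, W, K, C_in)
    # Weighted sum over kernel positions and input channels.
    return np.einsum("dhwkc,kco->dhwo", sampled, weight)
```

With zero offsets and a mask of ones it reduces to an ordinary 3x3x3 convolution with zero padding, which is a convenient correctness check for any TF port.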

Thank you @bhack for helping me through this!

SimonBiggs, Dec 20 '22

Sounds good. Feel free to add a deformable conv3d in this library

tanzhenyu, Dec 20 '22

Also, @axeldavy, I tried to build and install your fork: https://colab.research.google.com/drive/1TNIUJN4W95V81VkWq0n2TksKisnfbZTR?usp=sharing#scrollTo=EdjOCYULwXv1

But I didn't have success. Keen to know what I did wrong there. It would be nice to be able to also benchmark your implementation.

SimonBiggs, Dec 20 '22

Your compilation failure doesn't seem related to my changes. I would suggest just applying my last commit to a branch that works for you.

axeldavy, Dec 20 '22

I'm going on leave for the next few weeks, so I won't be in a position to look at this again until the 3rd quarter of January. If someone else is keen to work on it, I'm completely happy with that; otherwise I'll get back to it next year.

SimonBiggs, Dec 20 '22

I hadn't noticed, but ONNX has recently received a new operator, GridSample, that is relevant to the topic: https://github.com/onnx/onnx/blob/main/docs/Operators.md

I think it is relevant because it is better if the proposed implementation can be exported to a good-enough ONNX graph. Using an operator like GridSample (available in PyTorch but not yet in TensorFlow) removes the need to do the bilinear interpolation manually, as in https://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers/convolution/deformable_conv.html

EDIT: In case my message isn't clear enough: I suggest adding TensorFlow and XLA support for a function similar to GridSample, which could be used to implement deformable convolution but is also useful for other applications, such as optical flow, registration, etc.
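
To use a GridSample-style operator for deformable convolution, each sampling location (pixel position plus learned offset, with any base kernel grid folded into the offsets) has to be converted to the normalised [-1, 1] coordinates such operators take. A minimal NumPy sketch, assuming the ONNX/PyTorch convention that the grid's last axis is ordered (x, y) and that `align_corners` selects pixel-centre versus pixel-edge normalisation; `pixel_to_grid` and `deformable_grid` are hypothetical names.

```python
import numpy as np


def pixel_to_grid(coord, size, align_corners):
    """Map pixel coordinates along an axis of length `size` to [-1, 1]."""
    coord = np.asarray(coord, dtype=float)
    if align_corners:
        # -1 and 1 are the centres of the first and last pixels.
        return 2.0 * coord / (size - 1) - 1.0
    # -1 and 1 are the outer edges of the first and last pixels.
    return (2.0 * coord + 1.0) / size - 1.0


def deformable_grid(offsets, align_corners=False):
    """Build a grid-sample grid from per-position pixel offsets.

    offsets: (H, W, P, 2) offsets (dy, dx) for P sampling positions per pixel,
             already including any base kernel grid.
    Returns (H, W, P, 2) normalised coordinates, last axis ordered (x, y).
    """
    height, width = offsets.shape[:2]
    ys, xs = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    y = ys[..., None] + offsets[..., 0]   # absolute row of each sample
    x = xs[..., None] + offsets[..., 1]   # absolute column of each sample
    return np.stack([pixel_to_grid(x, width, align_corners),
                     pixel_to_grid(y, height, align_corners)], axis=-1)
```

The (H, W, P, 2) grid can be reshaped to (H, W*P, 2) for a 4D grid-sample call with `padding_mode='zeros'`, and the sampled (C, H, W*P) features reshaped back to (C, H, W, P) before the mask and 1x1 convolution, mirroring the DeformableInterpolation + Conv2D split used earlier in the thread.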

axeldavy, Dec 29 '22

https://github.com/tensorflow/models/issues/7381

https://github.com/tensorflow/tensorflow/issues/56225

bhack, Dec 29 '22

So after taking three weeks of leave, it seems my plate is now exceptionally full. I'm going to have to put this in my backlog for the time being.

SimonBiggs, Jan 17 '23

This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot], Jan 30 '24