
merge 2nd order derivative of CutlassMLP

Open SusanLiu0709 opened this issue 1 year ago • 11 comments

Hi @Tom94,

Second order derivatives are a common operation in 3D reconstruction (neural SDF methods): training often uses an Eikonal loss, in which the first order derivative of the SDF with respect to the input is part of the loss (to better learn surface normals), so a second order derivative has to be computed during backpropagation. We (@SusanLiu0709 and @yanglixiao1994) implemented 2nd order derivative support for the classes CutlassMLP and NetworkWithInputEncoding.
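
For reference, an Eikonal term in PyTorch typically looks like the following sketch (plain PyTorch, independent of this PR; sdf_net stands for any differentiable SDF network):

    import torch

    def eikonal_loss(sdf_net, points):
        # points: (N, 3) query positions; enable gradients w.r.t. the input.
        points = points.requires_grad_(True)
        sdf = sdf_net(points)

        # 1st order derivative of the SDF w.r.t. the input (used to supervise surface normals).
        # create_graph=True keeps this gradient differentiable, so minimizing the loss below
        # requires a 2nd order derivative (backward backward) through sdf_net.
        grad = torch.autograd.grad(
            sdf, points, grad_outputs=torch.ones_like(sdf), create_graph=True)[0]

        # Eikonal constraint: the gradient of a signed distance field has unit norm.
        return ((grad.norm(dim=-1) - 1.0) ** 2).mean()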

Overall

There are major changes in files below:

  • network_with_input_encoding.h: adds the backward_backward_input_impl function of class NetworkWithInputEncoding, whose dataflow involves the interaction of the 2 modules, encoding and network, as shown in Figure 1.
  • cutlass_mlp.h: adds the definitions involved in the 2nd order derivative of the network class CutlassMLP, including backward_backward_input_impl and prepare_backward_variables.
  • cutlass_mlp.cu: adds the implementation of backward_backward_input_impl and the necessary kernels and functions, including compute_dL2dw, compute_dL2_ddL1doutput, comput_dL2dinput, etc. Its detailed dataflow is shown in Figure 2.
  • identity.h: adds 2nd order derivative support for the identity encoding, for the case where class Network is used on its own and the encoding defaults to identity.

Theoretical Derivation

Class NetworkWithInputEncoding

Class NetworkWithInputEncoding contains 2 modules: an encoder and a network. The encoder can be set to identity, grid (hash encoding), etc., and the network is a multi-layer perceptron consisting of several hidden layers with activations. The dataflow of its forward, backward and backward backward passes is visualized in Figure 1.

Previously, tiny-cuda-nn already supported the forward and backward passes of class NetworkWithInputEncoding. The backward backward pass consists of 3 stages and involves both modules, the encoder and the network. Those 3 stages are: the 2nd order derivative of the encoder, the 2nd order derivative of the network, and a 1st order derivative that maps dL2dmlp_input back to dL2dinput.

For each 2nd order derivative (i.e. backward backward pass), the current module (encoder/network) requires 3 input params: input, dL2_ddL1dinput (the 2nd order derivative term), and dL1doutput (the 1st order derivative term). It produces 2 output variables: dL2_ddL1doutput, the 2nd order derivative term with respect to dL1doutput, and dL2_dinput, the 2nd order derivative with respect to input.

For the 1st order derivative involved in the 2nd order derivative (backward backward pass), the current module requires 3 input params, input, output and dL2doutput, and produces 1 output variable, dL2dinput.

We have implemented the backward backward pass of class NetworkWithInputEncoding in the function backward_backward_input_impl() in network_with_input_encoding.h.
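
In Python-style pseudocode, the three stages described above can be sketched roughly as follows. The names and signatures here are illustrative only; the actual implementation is the C++ function backward_backward_input_impl, which additionally handles CUDA streams, contexts and GPU matrices:

    def network_with_input_encoding_bwd_bwd(input, dL2_ddL1dinput, dL1doutput):
        # Quantities saved from the forward / 1st order backward passes:
        #   mlp_input     = encoding(input)    (encoded features fed to the MLP)
        #   dL1dmlp_input = 1st order gradient of L1 w.r.t. mlp_input

        # Stage 1: 2nd order derivative of the encoder.
        dL2_ddL1dmlp_input, dL2dinput = encoding.backward_backward(
            input, dL2_ddL1dinput, dL1dmlp_input)

        # Stage 2: 2nd order derivative of the network.
        dL2_ddL1doutput, dL2dmlp_input = network.backward_backward(
            mlp_input, dL2_ddL1dmlp_input, dL1doutput)

        # Stage 3: a 1st order derivative of the encoder maps dL2dmlp_input back
        # to the input space and is accumulated into dL2dinput
        # (the dashed amaranth line in Figure 1).
        dL2dinput += encoding.backward(input, mlp_input, dL2dmlp_input)

        return dL2_ddL1doutput, dL2dinput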

Figure 1. The overall forward, backward and backward backward dataflow of the encoding and the linear layers (including activation). Note that an additional 1st order derivative is involved in finishing the computation of the 2nd order term dL2dinput, marked as the dashed amaranth line.

Class CutlassMLP

Previously, tiny-cuda-nn already supports 2nd order derivative of hash encoding !69. And this pull request mainly focuses on implementing 2nd order derivative of the network module, class CutlassMLP (FullyFusedMLP not supported yet). And the simplified dataflow of the network module is visualized as Figure 2.

Figure 2. The forward, backward and backward backward dataflow of a single linear layer (including activation). Note that an additional 1st order derivative is involved in finishing the computation of the 2nd order term dL2dinput; it is not marked in this figure and is similar to the dashed amaranth line in Figure 1.

The detailed numeric derivation of the CutlassMLP 2nd order derivative is shown below:

Figure 3. The numeric derivation of the 2nd order derivative of a multi-layer perceptron network.
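
In condensed form, for a single layer z = Wx, y = σ(z) (bias omitted) and writing g for the incoming term dL2_ddL1dinput, the relations behind this derivation can be sketched as:

$$
\begin{aligned}
\text{forward:}\quad & z = W x, \qquad y = \sigma(z) \\
\text{backward:}\quad & \frac{\partial L_1}{\partial z} = \sigma'(z)\odot\frac{\partial L_1}{\partial y},\qquad
\frac{\partial L_1}{\partial x} = W^{\top}\frac{\partial L_1}{\partial z},\qquad
\frac{\partial L_1}{\partial W} = \frac{\partial L_1}{\partial z}\,x^{\top} \\
\text{backward backward:}\quad & \frac{\partial L_2}{\partial\,(\partial L_1/\partial y)} = \sigma'(z)\odot (W g) \\
& \frac{\partial L_2}{\partial W} = \Big(\sigma'(z)\odot\frac{\partial L_1}{\partial y}\Big)\, g^{\top}
 + \Big(\sigma''(z)\odot\frac{\partial L_1}{\partial y}\odot (W g)\Big)\, x^{\top} \\
& \frac{\partial L_2}{\partial x} = W^{\top}\Big(\sigma''(z)\odot\frac{\partial L_1}{\partial y}\odot (W g)\Big)
\end{aligned}
$$

Note that the σ'' terms vanish for ReLU.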

Implementation Details

Based on the numeric derivation of the 2nd order derivative, we designed and implemented the computing workflow shown in Figure 4.

Figure 4. The workflow of the implemented 2nd order derivative of CutlassMLP.
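
As a rough per-layer reference for the quantities produced by the kernels mentioned above (compute_dL2_ddL1doutput, compute_dL2dw, comput_dL2dinput), the following NumPy sketch mirrors the single-layer relations given earlier; it illustrates the math for one sample only and is not the actual CUDA kernel code:

    import numpy as np

    def layer_backward_backward(W, x, dL1doutput, g, sigma_p, sigma_pp):
        """Single-layer, single-sample reference of the backward backward relations.

        W:          (n_out, n_in) layer weights
        x:          (n_in,)       layer input
        dL1doutput: (n_out,)      1st order gradient w.r.t. the layer output y
        g:          (n_in,)       incoming dL2_ddL1dinput
        sigma_p, sigma_pp: callables returning sigma'(z) and sigma''(z)
        """
        z = W @ x
        Wg = W @ g

        # 2nd order term propagated to dL1doutput of this layer.
        dL2_ddL1doutput = sigma_p(z) * Wg

        # Gradient of the 2nd loss w.r.t. the layer weights.
        dL2dw = np.outer(sigma_p(z) * dL1doutput, g) \
              + np.outer(sigma_pp(z) * dL1doutput * Wg, x)

        # 2nd order term w.r.t. the layer input (zero for ReLU, since sigma'' = 0).
        dL2dinput = W.T @ (sigma_pp(z) * dL1doutput * Wg)

        return dL2_ddL1doutput, dL2dw, dL2dinput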

Visual Comparison

In order to further verify the correctness of the implemented 2nd order derivative of the network module, we conducted an experiment comparing the visual quality of results trained by PyTorch and tiny-cuda-nn (TCNN) respectively. The training is based on the open-source method NeuS, and the only difference between the PyTorch and TCNN versions is the definition of the SDF network. The training results are shown in Figure 5 and Figure 6, in which there is no obvious difference between the PyTorch and TCNN results.

Figure 5. Visual comparison of the trained results of PyTorch and tiny-cuda-nn (TCNN). The encoder of the SDF network is set to positional encoding, and the network has 3 hidden layers with ReLU activation and no activation on the output layer.

Figure 6. Visual comparison of the trained results of PyTorch and tiny-cuda-nn (TCNN). The encoder of the SDF network is set to hash encoding, and the network has 3 hidden layers with ReLU activation and no activation on the output layer.

Numeric Alignment with Pytorch

To verify the correctness of the implemented 2nd order derivative of CutlassMLP and NetworkWithInputEncoding, we implemented a toy test script defining a simple neural SDF supervised with an Eikonal loss. The sample code is outlined below:
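
The following is a minimal sketch of what such a test looks like (illustrative only, assuming the tinycudann Python extension is installed; the actual script in the PR, scripts/test_armadillo_numeric_align.py, is more complete and may differ in detail):

    import torch
    import tinycudann as tcnn

    # A small SDF network: hash-grid encoding followed by a CutlassMLP with ReLU
    # (configuration values are illustrative).
    sdf_net = tcnn.NetworkWithInputEncoding(
        n_input_dims=3, n_output_dims=1,
        encoding_config={"otype": "HashGrid"},
        network_config={
            "otype": "CutlassMLP", "activation": "ReLU",
            "output_activation": "None", "n_neurons": 64, "n_hidden_layers": 3,
        },
    ).cuda()

    points = torch.rand(2**14, 3, device="cuda", requires_grad=True)
    sdf = sdf_net(points).float()

    # 1st order derivative of the SDF w.r.t. the input; create_graph keeps it differentiable.
    grad = torch.autograd.grad(sdf, points, torch.ones_like(sdf), create_graph=True)[0]

    # Eikonal loss; its backward pass exercises backward_backward_input_impl.
    eikonal_loss = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
    eikonal_loss.backward()

    # Compare these gradients against a PyTorch reference implementation of the same network.
    print(sdf_net.params.grad)
    print(points.grad)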

Figure 7. Comparison between sampled gradients from the PyTorch and TCNN versions of NeuS. All numbers are sampled from the 1st and 2nd hidden layers during the first training iteration.

Figure 8. Weight distribution comparison between PyTorch and TCNN for the 1st hidden layer.

Figure 9. Weight distribution comparison between PyTorch and TCNN for the 2nd hidden layer.

Figure 10. Weight distribution comparison between PyTorch and TCNN for the 3rd hidden layer.

TODO

More details will be added soon:

  • upload NeuS+TCNN sample code.

SusanLiu0709 avatar Sep 18 '23 12:09 SusanLiu0709

Hi Susan, thank you (and @yanglixiao1994) very much for this contribution as well as its thorough derivation and testing!

I'd like to take the time to properly review it, but I'm already occupied for the next week or two. I'll get back to this PR afterwards -- thanks again.

Tom94 avatar Sep 25 '23 15:09 Tom94

Thanks for your wonderful work. So is this implementation almost the same as NeuS2? @SusanLiu0709

zebin-dm avatar Oct 10 '23 11:10 zebin-dm

I have tried your PR, but I cannot find the file Armadillo.ply. How can I generate it? @SusanLiu0709

zebin-dm avatar Oct 10 '23 12:10 zebin-dm

Thanks for your wonderful work. So is this implementation almost the same as NeuS2? @SusanLiu0709

Hi zebin-dm, we implemented the feature before NeuS2 was released. Currently we are studying the details of NeuS2 and checking whether it is feasible to merge NeuS2 into our implementation.

SusanLiu0709 avatar Oct 10 '23 19:10 SusanLiu0709

I have tried your PR, but I cannot find the file Armadillo.ply. How can I generate it? @SusanLiu0709

Thanks for testing :) @yanglixiao1994 may help to upload the test data and params soon.

SusanLiu0709 avatar Oct 10 '23 19:10 SusanLiu0709

I have tried your PR, but I cannot find the file Armadillo.ply. How can I generate it? @SusanLiu0709

Hi zebin, this is the testing Armadillo data (https://drive.google.com/file/d/1KfIkGcLkQOopnXoBLmkT55bBBNQu6nBm/view?usp=sharing). Actually, you can generate your own data (3D grid and corresponding SDF) according to https://github.com/SusanLiu0709/tiny-cuda-nn/blob/8c66716e59b94f73f918c058797e17368528c748/scripts/test_armadillo_numeric_align.py#L129

yanglixiao1994 avatar Oct 11 '23 03:10 yanglixiao1994

Thank you very much.

zebin-dm avatar Oct 11 '23 03:10 zebin-dm

Hi, thanks again for this PR! I requested a few changes in the C++ code that are required before I can merge. Please feel free to discuss if anything is unclear or if I missed something.

Once the changes are in, I'll go through the testing code (which I appreciate a lot, by the way) and give another round of feedback.

Hi Thomas @Tom94,

Sorry for the late reply. I was busy for the last 3 months, but I can work on the PR these days. Hope it's not too late.

Best, Susan

SusanLiu0709 avatar Apr 01 '24 10:04 SusanLiu0709

Thank you very much, looking forward to your nice work.

zebin-dm avatar Apr 01 '24 12:04 zebin-dm

Hi,

Using the network config below, I still get an error when calculating the second derivative.

File "/opt/conda/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward torch.autograd.backward( File "/opt/conda/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward _engine_run_backward( File "/opt/conda/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/opt/conda/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply return user_fn(self, *args) File "/opt/conda/lib/python3.9/site-packages/tinycudann-1.7-py3.9-linux-x86_64.egg/tinycudann/modules.py", line 145, in backward doutput_grad, params_grad, input_grad = ctx.ctx_fwd.native_tcnn_module.bwd_bwd_input( RuntimeError: DifferentiableObject::backward_backward_input_impl: not implemented error

 "network": {
            "otype": "CutlassMLP",
            "activation": "Sine",
            "output_activation": "None",
            "n_neurons": 16,
            "n_hidden_layers": 3
        },


Should these changes also be available in pytorch at the moment?

lucasdevries avatar May 16 '24 15:05 lucasdevries

For now, the 2nd order derivative of the activation function "Sine" is not supported. You can try with "ReLU" and "Softplus".

SusanLiu0709 avatar Jun 12 '24 22:06 SusanLiu0709