
CUDA V2

Open liruilong940607 opened this issue 2 months ago • 10 comments

I have heavily rewritten the CUDA part to make the code run more efficiently, with a cleaner code structure and some extra functionality.

New features:

  • Supports batch rendering in one pass: N cameras -> images with shape (N, H, W, C). This allows sampling evenly distributed cameras in every training iteration, which could potentially lead to better convergence; with it, we probably no longer need gradient accumulation. Also, when rendering GS, if some latency is acceptable, batched rendering can further boost FPS.
  • Full PyTorch Auto-Grad support: Provides a version of the rasterization & projection implementation whose gradients are fully managed by PyTorch Auto-Grad (see _rendering in experimental/cuda/__init__.py). This makes it easier to try out new ideas, since modifications can probably be done purely on the Python side. It can also serve as a gradient checker for the CUDA implementation.
  • Fused implementation with verified gradients: Every gradient in this PR is carefully checked against PyTorch Auto-Grad. By contrast, the current GSplat implementation has errors in its gradients: (1) the FoV clip is applied in the forward pass but not in the backward pass, so the gradients are incorrect for some GSs; (2) the backward pass suffers from a numerical issue because it reuses the remaining transmittance output from the forward pass, which can be very small (~1e-4). Double precision instead of float should be used for it.
  • Faster FWD and BWD (especially in the ND case): see the tables below for details.
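To make the batching convention in the first point concrete, here is a minimal pure-Python sketch of a batched pinhole projection: one call takes N cameras and P Gaussian centers and returns an N x P x 2 result with the camera dimension first, mirroring the (N, H, W, C) image layout. The function name `project_batched`, the intrinsics layout `Ks`, and the [[R, t], [0, 1]] world-to-camera form of `viewmats` are assumptions for illustration, not this PR's API.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def project_batched(
    means: Sequence[Sequence[float]],
    viewmats: Sequence[Matrix],
    Ks: Sequence[Matrix],
) -> List[List[List[float]]]:
    """Hypothetical sketch: project P world-space points into N cameras in one call.

    means:    P points (x, y, z) in world space
    viewmats: N world-to-camera 4x4 matrices of the form [[R, t], [0, 1]]
    Ks:       N 3x3 intrinsics [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
    Returns an N x P x 2 nested list of pixel coordinates (camera dimension first).
    """
    out: List[List[List[float]]] = []
    for viewmat, K in zip(viewmats, Ks):  # outer loop: one entry per camera
        per_cam: List[List[float]] = []
        for x, y, z in means:
            # World -> camera: p_cam = R @ p + t, read from rows 0..2 of the 4x4 viewmat.
            xc, yc, zc = (r[0] * x + r[1] * y + r[2] * z + r[3] for r in viewmat[:3])
            # Camera -> pixels; assumes every point is in front of the camera (zc > 0).
            u = K[0][0] * xc / zc + K[0][2]
            v = K[1][1] * yc / zc + K[1][2]
            per_cam.append([u, v])
        out.append(per_cam)
    return out
```

Each camera's work is independent, which is what makes a single batched pass possible; a GPU implementation would parallelize over cameras instead of looping in Python.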
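The transmittance issue in the third point can be reproduced in miniature. This sketch assumes the backward pass walks the Gaussians back to front and recovers each Gaussian's transmittance by dividing the stored final transmittance by (1 - alpha), as a front-to-back alpha compositor's backward pass commonly does. It emulates float32 rounding with Python's `struct` module; the function names and alpha values are hypothetical, chosen so that the final transmittance lands near 1e-4.

```python
import struct
from typing import Callable, List, Tuple


def to_f32(x: float) -> float:
    """Round a Python float (IEEE binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def to_f64(x: float) -> float:
    """Identity: Python floats are already binary64."""
    return x


def transmittance_roundtrip(
    alphas: List[float], rnd: Callable[[float], float]
) -> Tuple[float, float]:
    """Composite front to back, then recover each T_i from T_final by division.

    Returns (T_final, worst relative mismatch between recovered and forward T_i).
    `rnd` rounds every intermediate value to the emulated precision.
    """
    # Forward: T_0 = 1, T_{i+1} = T_i * (1 - alpha_i); keep each T_i as the reference.
    T = rnd(1.0)
    forward_T = []
    for a in alphas:
        forward_T.append(T)
        T = rnd(T * rnd(1.0 - rnd(a)))
    T_final = T

    # Backward: start from T_final only and undo one Gaussian at a time.
    worst = 0.0
    T = T_final
    for a, T_ref in zip(reversed(alphas), reversed(forward_T)):
        T = rnd(T / rnd(1.0 - rnd(a)))
        worst = max(worst, abs(T - T_ref) / T_ref)
    return T_final, worst


if __name__ == "__main__":
    alphas = [0.3, 0.7, 0.45, 0.9, 0.15] * 2  # hypothetical alphas; T_final ends near 1e-4
    for name, rnd in (("float32", to_f32), ("float64", to_f64)):
        t_final, err = transmittance_roundtrip(alphas, rnd)
        print(f"{name}: T_final={t_final:.3e}, worst relative error={err:.1e}")
```

On inputs like these, the float32 round trip should typically drift several orders of magnitude further from the forward values than the float64 one, which is the motivation for keeping this quantity in double precision.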

Gradient difference against PyTorch Auto-Grad (code in experimental/test.py):

Method     viewmats   scales   quats   means
GSplat     42.8834    2e-4     5e-4    4e-3
This PR    4e-3       2e-7     5e-6    3e-6
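The first gradient flaw listed earlier (FoV clip applied in the forward pass but not the backward pass) is the kind of error these comparisons catch. The scalar sketch below uses central finite differences in place of PyTorch Auto-Grad as the reference. `j02` is a hypothetical, simplified stand-in for one entry of the projection Jacobian in which x/z is clamped to ±lim; `fx = 100` and `lim = 1.3` are arbitrary illustration values, not taken from this PR.

```python
from typing import Callable


def clamp(v: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, v))


def j02(x: float, z: float, fx: float = 100.0, lim: float = 1.3) -> float:
    """Hypothetical stand-in for J[0][2] = -fx * tx / z^2, with the FoV clip on x/z."""
    tx = z * clamp(x / z, -lim, lim)  # the forward pass clips the normalized coordinate
    return -fx * tx / (z * z)


def dj02_dx_no_clip(x: float, z: float, fx: float = 100.0, lim: float = 1.3) -> float:
    """Backward that forgets the clip: always differentiates -fx * x / z^2."""
    return -fx / (z * z)


def dj02_dx_with_clip(x: float, z: float, fx: float = 100.0, lim: float = 1.3) -> float:
    """Backward that mirrors the forward: zero gradient wherever x/z was clipped."""
    return -fx / (z * z) if -lim < x / z < lim else 0.0


def finite_diff_dx(f: Callable[[float, float], float], x: float, z: float, h: float = 1e-6) -> float:
    """Central-difference reference gradient of f with respect to x."""
    return (f(x + h, z) - f(x - h, z)) / (2.0 * h)


if __name__ == "__main__":
    for x, z in ((1.0, 4.0), (8.0, 4.0)):  # x/z = 0.25 (inside the FoV), 2.0 (clipped)
        ref = finite_diff_dx(j02, x, z)
        print(f"x/z={x / z:.2f}  ref={ref:.4f}  "
              f"no_clip={dj02_dx_no_clip(x, z):.4f}  with_clip={dj02_dx_with_clip(x, z):.4f}")
```

Inside the FoV both backward versions agree with the reference, so a check that only samples Gaussians near the image center would miss the bug; only the clipped case separates them.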

Runtime Performance (in ms) for Forward (code in experimental/profile.py):

Forward, Channel=1     Batch=1   Batch=8   Batch=64
GSplat                 1.492     9.921     76.295
This PR                1.146     3.593     28.948

Forward, Channel=32    Batch=1   Batch=8   Batch=64
GSplat                 5.078     32.824    264.146
This PR                1.930     7.773     64.461

Runtime Performance for Backward (viewmats gradients disabled):

Backward, Channel=1     Batch=1   Batch=8   Batch=64
GSplat                  1.594     10.600    84.395
This PR                 1.562     7.409     60.402

Backward, Channel=32    Batch=1   Batch=8   Batch=64
GSplat                  10.242    80.935    653.627
This PR                 7.850     43.392    353.733

Runtime Performance for Backward (viewmats gradients enabled):

Backward, Channel=3     Batch=1   Batch=8   Batch=64
GSplat                  2.721     16.629    130.977
This PR                 2.183     10.697    85.817
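Relative speedups are easier to compare across these tables than raw times. The sketch below recomputes GSplat/this-PR ratios from the numbers reported above; because each ratio divides two times measured for the same configuration, it does not depend on the time unit. The dictionary keys are just labels chosen for this sketch.

```python
from typing import Dict, Tuple

# (GSplat, This PR) runtimes per batch size, copied from the tables above.
RUNTIMES: Dict[str, Dict[int, Tuple[float, float]]] = {
    "Forward, Channel=1": {1: (1.492, 1.146), 8: (9.921, 3.593), 64: (76.295, 28.948)},
    "Forward, Channel=32": {1: (5.078, 1.930), 8: (32.824, 7.773), 64: (264.146, 64.461)},
    "Backward, Channel=1, no viewmats grad": {1: (1.594, 1.562), 8: (10.600, 7.409), 64: (84.395, 60.402)},
    "Backward, Channel=32, no viewmats grad": {1: (10.242, 7.850), 8: (80.935, 43.392), 64: (653.627, 353.733)},
    "Backward, Channel=3, with viewmats grad": {1: (2.721, 2.183), 8: (16.629, 10.697), 64: (130.977, 85.817)},
}


def speedups(runtimes: Dict[str, Dict[int, Tuple[float, float]]]) -> Dict[str, Dict[int, float]]:
    """Speedup of this PR over GSplat (old time / new time) for every table cell."""
    return {
        name: {batch: old / new for batch, (old, new) in row.items()}
        for name, row in runtimes.items()
    }


if __name__ == "__main__":
    for name, row in speedups(RUNTIMES).items():
        cells = "  ".join(f"Batch={b}: {s:.2f}x" for b, s in row.items())
        print(f"{name:42s} {cells}")
```

For example, the Channel=32 forward case at Batch=64 comes out at roughly 4x, while the Channel=1 backward case at Batch=1 is close to parity.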

liruilong940607 · Apr 29 '24 21:04