gsplat
CUDA V2
I substantially rewrote the CUDA part to make the code run more efficiently, with a cleaner code structure and some extra functionality.
New features:
- Batch rendering in one pass: N cameras -> images with shape (N, H, W, C). This allows sampling evenly distributed cameras in every training iteration, which could lead to better convergence. With this, we probably don't need gradient accumulation anymore. Also, when rendering GS, if some latency is acceptable, batched rendering can further boost FPS.
- Fully PyTorch Auto-Grad version: Provides a version of the rasterization & projection implementation whose gradients are fully managed by PyTorch Auto-Grad (see `_rendering` in `experimental/cuda/__init__.py`). This makes it easier to try out new ideas, as the modifications can probably be done purely on the Python side. It can also serve as a gradient checker for the CUDA implementation.
- Fused implementation with verified gradients: Every gradient in this PR is carefully checked against PyTorch Auto-Grad. On the other hand, the current GSplat implementation has errors in the gradients, including:
  1. FoV clipping is applied in the forward pass but not in the backward pass. As a result, the gradients are incorrect for some GSs.
  2. The backward pass suffers from a numerical issue caused by using the remaining transmittance output by the forward pass, because the remaining transmittance can be very small (~1e-4). Double precision instead of float should be used for it.
- Faster FWD and BWD (esp. in the ND case): See the tables below for details.
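To make the second gradient issue listed above more concrete, here is a minimal, self-contained sketch in plain Python. It assumes (this is not confirmed by the PR text) that the backward pass recovers each intermediate transmittance by dividing the final transmittance by `(1 - alpha)`, back to front. The helper names `f32` and `round_trip`, and the alpha values, are hypothetical. float32 is emulated with `struct` so that no GPU or extra library is needed.

```python
import struct


def f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def f64(x):
    """Identity cast: Python floats are already float64."""
    return x


def round_trip(alphas, cast):
    """Composite front to back, then undo it back to front.

    Returns (final_T, recovered_T0). In exact arithmetic recovered_T0 is 1.
    """
    # Forward: T shrinks as every sample absorbs part of the light.
    T = cast(1.0)
    for a in alphas:
        T = cast(T * cast(1.0 - cast(a)))
    final_T = T
    # Backward (assumed scheme): start from the small final T and divide.
    for a in reversed(alphas):
        T = cast(T / cast(1.0 - cast(a)))
    return final_T, T


# Hypothetical opacities, chosen so the final transmittance is around 1e-4.
alphas = [0.10 + 0.02 * (i % 8) for i in range(50)]

final_T32, T0_32 = round_trip(alphas, f32)
final_T64, T0_64 = round_trip(alphas, f64)

print(f"final transmittance : {final_T64:.3e}")
print(f"float32 |T0 - 1|    : {abs(T0_32 - 1.0):.3e}")
print(f"float64 |T0 - 1|    : {abs(T0_64 - 1.0):.3e}")
```

The script only prints both round-trip errors so they can be compared; how large the float32 error gets depends on the inputs, and the actual kernel may accumulate error differently.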
Gradient diff results against PyTorch Auto-Grad (code in `experimental/test.py`):
Method | viewmats | scales | quats | means |
---|---|---|---|---|
GSplat | 42.8834 | 2e-4 | 5e-4 | 4e-3 |
This PR | 4e-3 | 2e-7 | 5e-6 | 3e-6 |
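As an illustration of what such a gradient check looks like, here is a small sketch in plain Python. It is not the code in `experimental/test.py`: it uses central finite differences as a stand-in for PyTorch Auto-Grad so that it runs without any dependency, and it checks a scalar front-to-back alpha-compositing function rather than the full rasterizer. All function names are hypothetical.

```python
def composite(colors, alphas):
    """Front-to-back alpha compositing of scalar colors."""
    out, T = 0.0, 1.0
    for c, a in zip(colors, alphas):
        out += c * a * T      # contribution weighted by transmittance
        T *= 1.0 - a          # light remaining after this sample
    return out


def analytic_grad(colors, alphas):
    """Hand-derived d(out)/d(alpha_i), accumulated back to front."""
    # Per-sample transmittance T_i = prod_{j<i} (1 - alpha_j).
    Ts, T = [], 1.0
    for a in alphas:
        Ts.append(T)
        T *= 1.0 - a
    grads = [0.0] * len(alphas)
    behind = 0.0  # sum over k > i of c_k * alpha_k * T_k
    for i in reversed(range(len(alphas))):
        grads[i] = colors[i] * Ts[i] - behind / (1.0 - alphas[i])
        behind += colors[i] * alphas[i] * Ts[i]
    return grads


def numeric_grad(f, xs, eps=1e-6):
    """Central finite differences, one input at a time."""
    grads = []
    for i in range(len(xs)):
        hi = list(xs)
        lo = list(xs)
        hi[i] += eps
        lo[i] -= eps
        grads.append((f(hi) - f(lo)) / (2.0 * eps))
    return grads


colors = [0.9, 0.2, 0.6, 0.4]
alphas = [0.3, 0.5, 0.1, 0.7]

ga = analytic_grad(colors, alphas)
gn = numeric_grad(lambda a: composite(colors, a), alphas)
max_diff = max(abs(x - y) for x, y in zip(ga, gn))
print(f"max |analytic - numeric| = {max_diff:.2e}")
```

A large value of `max_diff` would point to a bug in the hand-written backward, which is the same kind of signal the table above reports per input.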
Runtime Performance (in ms) for Forward (code in `experimental/profile.py`):
Forward, Channel=1 | Batch=1 | Batch=8 | Batch=64 |
---|---|---|---|
GSplat | 1.492 | 9.921 | 76.295 |
This PR | 1.146 | 3.593 | 28.948 |
Forward, Channel=32 | Batch=1 | Batch=8 | Batch=64 |
---|---|---|---|
GSplat | 5.078 | 32.824 | 264.146 |
This PR | 1.930 | 7.773 | 64.461 |
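To read the forward tables in terms of throughput, the short script below converts the timings into images per second and speedup ratios. The numbers are copied from the two tables above; the dictionary layout and the helper names `images_per_second` and `speedup` are only illustrative.

```python
# Forward timings in ms, keyed by (method, channels) then by batch size.
forward_ms = {
    ("GSplat", 1): {1: 1.492, 8: 9.921, 64: 76.295},
    ("This PR", 1): {1: 1.146, 8: 3.593, 64: 28.948},
    ("GSplat", 32): {1: 5.078, 8: 32.824, 64: 264.146},
    ("This PR", 32): {1: 1.930, 8: 7.773, 64: 64.461},
}


def images_per_second(method, channels, batch):
    """Rendered images per second for one batched forward pass."""
    return batch / forward_ms[(method, channels)][batch] * 1000.0


def speedup(channels, batch):
    """How many times faster this PR is than GSplat for the same setting."""
    return (forward_ms[("GSplat", channels)][batch]
            / forward_ms[("This PR", channels)][batch])


for channels in (1, 32):
    for batch in (1, 8, 64):
        print(
            f"C={channels:2d} B={batch:2d}  "
            f"this PR: {images_per_second('This PR', channels, batch):8.1f} img/s  "
            f"speedup vs GSplat: {speedup(channels, batch):.2f}x"
        )
```

This only covers the forward pass in isolation; end-to-end FPS also depends on work that the tables do not measure.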
Runtime Performance for Backward (viewmats gradients disabled):
Backward, Channel=1 | Batch=1 | Batch=8 | Batch=64 |
---|---|---|---|
GSplat | 1.594 | 10.600 | 84.395 |
This PR | 1.562 | 7.409 | 60.402 |
Backward, Channel=32 | Batch=1 | Batch=8 | Batch=64 |
---|---|---|---|
GSplat | 10.242 | 80.935 | 653.627 |
This PR | 7.850 | 43.392 | 353.733 |
Runtime Performance for Backward (viewmats gradients enabled):
Backward, Channel=3 | Batch=1 | Batch=8 | Batch=64 |
---|---|---|---|
GSplat | 2.721 | 16.629 | 130.977 |
This PR | 2.183 | 10.697 | 85.817 |