JAX support?

Open yklcs opened this issue 1 year ago • 5 comments

I'm working on a preliminary port of gsplat to JAX. It looks feasible if I reuse the CUDA kernels (mostly) as-is and heavily modify the bindings, but it would also require substantial changes to the Python-side code and the overall API. JAX's custom GPU FFI requires quite a bit of boilerplate.

I was wondering if there's any interest in merging JAX support into gsplat if I were to create a PR, or, even better, if there's a maintainer interested in collaborating to support JAX.

If not, I'll just create a hard fork. Thanks.

yklcs avatar May 08 '24 02:05 yklcs

I ended up creating my own implementation based on gsplat here: https://github.com/yklcs/jaxsplat

yklcs avatar May 13 '24 15:05 yklcs

@yklcs this is cool work! Are you planning to implement a full training pipeline in JAX?

This is something I'm very curious about, especially because splatfacto currently relies a lot on things like dynamic shapes and boolean masking (which are hard in JAX).

brentyi avatar May 13 '24 16:05 brentyi

Yes, dynamic shapes are a problem: as of right now, JIT doesn't work. gaussian_ids_sorted has length num_intersects, which is dynamic because it depends on num_tiles_hit. So a full pipeline would have to wait until that's fixed, unless running without JIT is acceptable.
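
As a small illustration of why data-dependent shapes break JIT, here is a generic JAX sketch (not code from gsplat or jaxsplat; the function name and the depth-based mask are made up for the example). Boolean-mask indexing produces an output whose length depends on the data, so it cannot be traced by jax.jit, whereas jnp.nonzero with a static size returns a padded, fixed-length result:

```python
import jax
import jax.numpy as jnp


@jax.jit
def visible_indices(depths):
    # Hypothetical visibility test: keep points with positive depth.
    mask = depths > 0.0
    # `size` fixes the output length at trace time; unused slots get fill_value.
    (idx,) = jnp.nonzero(mask, size=depths.shape[0], fill_value=-1)
    # The number of valid entries is returned as a value, not as a shape.
    count = jnp.sum(mask)
    return idx, count


depths = jnp.array([1.0, -2.0, 3.0, 0.5])
idx, count = visible_indices(depths)
print(idx, count)

# By contrast, `jax.jit(lambda d: d[d > 0.0])(depths)` raises an error,
# because the result length is only known once the values are.
```

The cost of this pattern is that every downstream consumer has to respect the padding, for example by checking `idx >= 0` or comparing against `count`.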

I'm not sure what the best way to remove the dynamic shape is. num_intersects is bounded above by num_tiles * num_points, which is probably too big to store. The tiling and binning approach may just be incompatible with statically known shapes. Maybe someone else has better ideas?
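
To put a rough number on that bound, here is a back-of-the-envelope calculation. The 16x16 pixel tile size, the 1920x1080 resolution, the one million points, and the 4 bytes per stored id are all illustrative assumptions, not values taken from this thread:

```python
def worst_case_intersects(width, height, num_points, tile_size=16):
    """Upper bound on num_intersects if every point touched every tile."""
    tiles_x = -(-width // tile_size)   # ceiling division
    tiles_y = -(-height // tile_size)
    num_tiles = tiles_x * tiles_y
    return num_tiles, num_tiles * num_points


num_tiles, max_intersects = worst_case_intersects(1920, 1080, 1_000_000)
bytes_needed = max_intersects * 4  # assuming one 4-byte id per intersection
print(num_tiles, max_intersects, bytes_needed / 1e9)
```

Under these assumptions the bound is 8160 tiles times one million points, about 8.2 billion entries or roughly 33 GB for a single id buffer, which is why allocating the worst case up front is not practical.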

yklcs avatar May 14 '24 08:05 yklcs

Yeah, it's an interesting problem!

It seems hard to make this useful without JIT. To make the shape static, could a MAX_INTERSECTS or MAX_AVG_INTERSECTS_PER_GAUSSIAN constant be good enough? If the number of intersects exceeds the constant:

  • Maybe Gaussians can be prioritized based on distance or alpha, and any "overflow" can just be ignored?
  • It seems possible to reduce memory usage by trading it for computation: perhaps the forward/backward can be done in multiple passes? After the Gaussians are sorted, it seems possible to chunk them by distance, rasterize each chunk separately, and then alpha-composite the results.

This could also be a feature and not a bug. Having num_intersects = num_tiles * num_points doesn't seem unreasonable (for example, if we have only large Gaussians), and choosing some well-defined behavior seems better than a spurious OOM.
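
A minimal JAX sketch of the fixed-capacity idea, assuming a per-Gaussian count array like num_tiles_hit is available. The function name, the -1 padding value, and the overflow counter are hypothetical, and this is not how gsplat or jaxsplat is implemented. It relies on the total_repeat_length argument of jnp.repeat to get a static output length:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_intersects")
def expand_gaussian_ids(num_tiles_hit, max_intersects):
    n = num_tiles_hit.shape[0]
    # Repeat each Gaussian id once per tile it hits. total_repeat_length
    # makes the output length static: extra slots are padded, excess is cut.
    ids = jnp.repeat(
        jnp.arange(n), num_tiles_hit, total_repeat_length=max_intersects
    )
    total = jnp.sum(num_tiles_hit)
    # Mark which slots hold real intersections rather than padding.
    valid = jnp.arange(max_intersects) < total
    ids = jnp.where(valid, ids, -1)
    # Report how many intersections did not fit, instead of running out of memory.
    overflow = jnp.maximum(total - max_intersects, 0)
    return ids, valid, overflow


hits = jnp.array([2, 0, 3])
ids_fit, valid_fit, overflow_fit = expand_gaussian_ids(hits, max_intersects=8)
ids_cut, valid_cut, overflow_cut = expand_gaussian_ids(hits, max_intersects=4)
print(ids_fit, overflow_fit)
print(ids_cut, overflow_cut)
```

In this sketch the overflow is simply whatever comes last in Gaussian index order. Prioritizing by distance or alpha, as suggested above, would mean sorting the Gaussians before the expansion so that the least important ones are the ones truncated.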

brentyi avatar May 14 '24 11:05 brentyi

JIT now works with jaxsplat: I took the simple approach of recalculating gaussian_ids in the backward pass. I'll see if there's a better approach later on; those ideas seem worth exploring.
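
For readers unfamiliar with the pattern, here is a toy sketch of recomputing an intermediate in the backward pass using jax.custom_vjp. It is not jaxsplat's actual code: the function names are hypothetical and jnp.sin stands in for the real intermediate. The point is that the forward pass saves only its statically shaped input as a residual, and the backward pass rebuilds the intermediate from it:

```python
import jax
import jax.numpy as jnp


def _intermediate(x):
    # Stand-in for a buffer we would rather not pass between passes.
    return jnp.sin(x)


@jax.custom_vjp
def loss(x):
    return jnp.sum(_intermediate(x) ** 2)


def loss_fwd(x):
    # Save only the input as the residual, not the intermediate.
    return loss(x), x


def loss_bwd(x, g):
    s = _intermediate(x)  # recomputed here instead of being stored
    # d/dx sum(sin(x)^2) = 2 * sin(x) * cos(x)
    return (g * 2.0 * s * jnp.cos(x),)


loss.defvjp(loss_fwd, loss_bwd)

x = jnp.array([0.1, 0.5, 1.0])
grad = jax.jit(jax.grad(loss))(x)
print(grad)
```

The trade-off is extra computation in the backward pass in exchange for not having to carry the intermediate across the forward/backward boundary.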

yklcs avatar May 15 '24 10:05 yklcs