
Make APIs as uniform as possible

Open frostedoyster opened this issue 1 year ago • 7 comments

At the moment, the APIs for C/C++/NumPy/torch, JAX, Julia and CUDA are all slightly different. We should discuss how far we want to go in making them uniform, and where we should instead give way to the idioms of each language/framework.

frostedoyster, Jan 19 '24 18:01

I'm open to anything really, since on our end we have to write wrappers anyhow to make this compatible with how we organize computations.

cortner, Jan 19 '24 22:01

We had a discussion about this today, here is a quick summary:

  • use separate classes/objects for spherical harmonics and solid harmonics instead of a "normalize" parameter
  • use functional-looking API for most things, using a custom __call__ for numpy and torch

Julia

Currently does something like this

basis = SolidHarmonics(10)

# one of
sph = basis(R)
sph = compute(basis, R)

Jax

Currently does something like this

sph = compute_spherical_harmonics(R, lmax=10, normalize=True)

We can change it to

calculator = SphericalHarmonics(lmax=10)
sph = compute_spherical_harmonics(calculator, R)
sph = calculator(R)

calculator = SolidHarmonics(lmax=10)
sph = compute_solid_harmonics(calculator, R)
sph = calculator(R)

Torch/Numpy

We can change these to

calculator = SphericalHarmonics(lmax=10)
sph = calculator(R)

# add this one
calculator = SolidHarmonics(lmax=10)
sph = calculator(R)

Luthaf, Apr 08 '24 12:04

Also would be nice to be very explicit about which part of the "what sphericart computes" article maps to which arguments/classes. Currently it's technically written down, but a bit hard to parse.

sirmarcel, Apr 08 '24 18:04

My only gripe with what you suggest is that the type of the calculator should determine whether it's spherical or solids - hence just "compute".
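A minimal Python sketch of the "just compute" idea, using `functools.singledispatch` as a stand-in for Julia-style dispatch on the calculator type. The class names follow the thread; the single `compute` entry point, the helper `_solid_l01`, the restriction to l <= 1, the list-of-tuples input and the orthonormal real-harmonics normalization are all assumptions made for illustration, not sphericart's actual API.

```python
import math
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical calculator types: the type alone says which quantity is computed.
@dataclass(frozen=True)
class SolidHarmonics:
    lmax: int


@dataclass(frozen=True)
class SphericalHarmonics:
    lmax: int


def _solid_l01(x, y, z, lmax):
    """Real solid harmonics for l <= 1, ordered (0,0), (1,-1), (1,0), (1,1)."""
    if lmax not in (0, 1):
        raise NotImplementedError("this sketch only covers lmax <= 1")
    values = [math.sqrt(1.0 / (4.0 * math.pi))]
    if lmax == 1:
        c1 = math.sqrt(3.0 / (4.0 * math.pi))
        values += [c1 * y, c1 * z, c1 * x]
    return values


@singledispatch
def compute(calculator, R):
    # Fallback for types without a registered implementation.
    raise TypeError(f"unsupported calculator type: {type(calculator).__name__}")


@compute.register
def _(calculator: SolidHarmonics, R):
    # Solid harmonics are evaluated directly on the Cartesian coordinates.
    return [_solid_l01(x, y, z, calculator.lmax) for x, y, z in R]


@compute.register
def _(calculator: SphericalHarmonics, R):
    # Spherical harmonics are the solid harmonics of the normalized vectors.
    out = []
    for x, y, z in R:
        r = math.sqrt(x * x + y * y + z * z)
        out.append(_solid_l01(x / r, y / r, z / r, calculator.lmax))
    return out
```

With this layout, `compute(SolidHarmonics(lmax=1), R)` and `compute(SphericalHarmonics(lmax=1), R)` share one function name, and the choice of quantity lives entirely in the calculator type.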

But as a general rule I'm not certain that unifying too much is even a good thing. Different languages and different frameworks make different usage conventions natural.

cortner, Apr 08 '24 19:04

For jax, I'd suggest following the e3x API, i.e., just having functions solid_harmonics and spherical_harmonics. compute_spherical_harmonics accepting some opaque calculator object as argument doesn't seem very intuitive and probably doesn't play nicely with jit.

For torch, the suggested approach is good -- the classes should have a forward(self, xyz) and nothing else.

sirmarcel, Apr 09 '24 08:04

The idea for jax was that we can use the calculator object to cache the state we need (right now we are using a hidden global cache). We are pretty confident it should be possible to make it work with the jit (maybe defining the state as a PyTree).

> My only gripe with what you suggest is that the type of the calculator should determine whether it's spherical or solids - hence just "compute".

This works really well in Julia, where type-based dispatch is king, but in jax in particular this pattern is much harder to integrate with the function transformations (jit, grad, vmap, …).

Luthaf, Apr 09 '24 10:04

Hm... does this state even need to be visible to jax/python at all if you're already calling out to custom code? Otherwise, you maybe can initialise it at tracing/"compile" time... it seems a bit clunky to carry around some state for this. I assume this needs to be updated/changed based on the requested spherical harmonic order? Or what is being cached here?

sirmarcel, Apr 09 '24 10:04