sphericart
Make APIs as uniform as possible
At the moment, the APIs for C/C++/NumPy/torch, JAX, Julia and CUDA are all slightly different. We should discuss how far we want to go in making them uniform, and where we should instead defer to the idioms of each language/framework.
I'm open to anything, really, since on our end we have to write wrappers anyhow to make this compatible with how we organize computations.
We had a discussion about this today, here is a quick summary:
- use separate classes/objects for spherical harmonics and solid harmonics instead of a "normalize" parameter
- use a functional-looking API for most things, with a custom __call__ for numpy and torch
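A minimal sketch of what the two points above could look like in Python, assuming numpy. The class names follow the proposal; the shared base class, the l_max <= 1 restriction, and the hard-coded l = 0, 1 formulas (standard real harmonics, assumed to be ordered by l² + l + m) are illustrative only and are not sphericart's implementation.

```python
import numpy as np

# Normalization constants of the real harmonics for l = 0 and l = 1.
C0 = 0.5 / np.sqrt(np.pi)          # Y_0^0
C1 = np.sqrt(3.0 / (4.0 * np.pi))  # Y_1^m for m = -1, 0, 1


class _Harmonics:
    """Hypothetical shared base: stores l_max and exposes __call__."""

    def __init__(self, l_max: int):
        if l_max > 1:
            # Toy backend: this sketch only implements l <= 1.
            raise NotImplementedError("sketch only supports l_max <= 1")
        self.l_max = l_max

    def __call__(self, xyz) -> np.ndarray:
        xyz = np.asarray(xyz, dtype=float)
        x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
        # Each l block is divided by |r|^l; for l <= 1 this is just |r| (or 1).
        r = self._radial_scale(xyz)
        cols = [np.full_like(x, C0)]
        if self.l_max >= 1:
            # Assumed ordering l^2 + l + m: (1, -1), (1, 0), (1, 1).
            cols += [C1 * y / r, C1 * z / r, C1 * x / r]
        return np.stack(cols, axis=-1)


class SphericalHarmonics(_Harmonics):
    """Y_l^m evaluated on the unit vector r / |r|."""

    def _radial_scale(self, xyz):
        return np.linalg.norm(xyz, axis=-1)


class SolidHarmonics(_Harmonics):
    """|r|^l Y_l^m: no division by |r|."""

    def _radial_scale(self, xyz):
        return np.ones(xyz.shape[0])
```

Both classes share one calling convention, so switching between spherical and solid harmonics only changes the line where the calculator is constructed.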
Julia
It currently does something like this:
basis = SolidHarmonics(10)
# one of
sph = basis(R)
sph = compute(basis, R)
JAX
It currently does something like this:
sph = compute_spherical_harmonics(R, lmax=10, normalize=True)
We can change it to
calculator = SphericalHarmonics(lmax=10)
sph = compute_spherical_harmonics(calculator, R)
sph = calculator(R)
calculator = SolidHarmonics(lmax=10)
sph = compute_solid_harmonics(calculator, R)
sph = calculator(R)
Torch/Numpy
We can change these to
calculator = SphericalHarmonics(lmax=10)
sph = calculator(R)
# add this one
calculator = SolidHarmonics(lmax=10)
sph = calculator(R)
It would also be nice to be very explicit about which part of the "what sphericart computes" article maps to which arguments/classes. This is technically written down already, but it is a bit hard to parse.
My only gripe with what you suggest: the type of the calculator should already determine whether it computes spherical or solid harmonics, so the function should just be called "compute".
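For comparison, a single compute entry point whose behavior is selected by the calculator's type can be approximated in Python with functools.singledispatch, which dispatches on the type of the first argument, much like Julia dispatch on one argument. This is a hypothetical sketch (pure Python, one point at a time, l_max <= 1 only, same assumed l² + l + m ordering), not a proposed sphericart API.

```python
import math
from functools import singledispatch

C0 = 0.5 / math.sqrt(math.pi)          # Y_0^0
C1 = math.sqrt(3.0 / (4.0 * math.pi))  # Y_1^m for m = -1, 0, 1


class SphericalHarmonics:
    def __init__(self, l_max: int):
        self.l_max = l_max


class SolidHarmonics:
    def __init__(self, l_max: int):
        self.l_max = l_max


def _real_harmonics(l_max, x, y, z, inv_r):
    """Toy kernel shared by both types; inv_r = 1/|r| or 1."""
    if l_max > 1:
        raise NotImplementedError("sketch only supports l_max <= 1")
    out = [C0]
    if l_max >= 1:
        out += [C1 * y * inv_r, C1 * z * inv_r, C1 * x * inv_r]
    return out


@singledispatch
def compute(calculator, xyz):
    # Fallback when no implementation is registered for this type.
    raise TypeError(f"no compute() for {type(calculator).__name__}")


@compute.register
def _(calculator: SphericalHarmonics, xyz):
    x, y, z = xyz
    return _real_harmonics(calculator.l_max, x, y, z, 1.0 / math.hypot(x, y, z))


@compute.register
def _(calculator: SolidHarmonics, xyz):
    x, y, z = xyz
    return _real_harmonics(calculator.l_max, x, y, z, 1.0)
```

This keeps one function name, but the dispatch happens on ordinary Python objects, which does not by itself say anything about how it interacts with JAX transformations.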
But as a general rule I'm not certain that unifying too much is even a good thing. Different languages and different frameworks make different usage conventions natural.
For JAX, I'd suggest following the e3x API, i.e., just having the functions solid_harmonics and spherical_harmonics. Having compute_spherical_harmonics accept an opaque calculator object as an argument doesn't seem very intuitive, and it probably doesn't play nicely with jit.
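A sketch of this functional alternative: two plain functions and no calculator object. The signatures spherical_harmonics(xyz, l_max) and solid_harmonics(xyz, l_max) are assumptions for illustration (not e3x's actual signatures), numpy stands in for jax.numpy, and only l_max <= 1 is implemented.

```python
import numpy as np

C0 = 0.5 / np.sqrt(np.pi)          # Y_0^0
C1 = np.sqrt(3.0 / (4.0 * np.pi))  # Y_1^m for m = -1, 0, 1


def _real_harmonics(xyz, l_max, inv_r):
    """Toy kernel for l <= 1; inv_r is 1/|r| (spherical) or 1 (solid)."""
    if l_max > 1:
        raise NotImplementedError("sketch only supports l_max <= 1")
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    cols = [np.full_like(x, C0)]
    if l_max >= 1:
        cols += [C1 * y * inv_r, C1 * z * inv_r, C1 * x * inv_r]
    return np.stack(cols, axis=-1)


def spherical_harmonics(xyz, l_max):
    """Hypothetical functional API: harmonics of the direction r / |r|."""
    xyz = np.asarray(xyz, dtype=float)
    return _real_harmonics(xyz, l_max, 1.0 / np.linalg.norm(xyz, axis=-1))


def solid_harmonics(xyz, l_max):
    """Hypothetical functional API: |r|^l Y_l^m, no normalization by |r|."""
    xyz = np.asarray(xyz, dtype=float)
    return _real_harmonics(xyz, l_max, 1.0)
```

In a JAX version, l_max fixes the output shape, so under jax.jit it would presumably have to be a static argument (for example via static_argnums); that part is not shown here.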
For torch, the suggested approach is good: the classes should have a forward(self, xyz) method and nothing else.
The idea for JAX was that we can use the calculator object to cache the state we need (right now we are using a hidden global cache). We are pretty confident it should be possible to make this work with jit (perhaps by defining the state as a PyTree).
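A sketch of the explicit-state variant, assuming for illustration that the cached state is a table of normalization prefactors K_l^m = sqrt((2l+1)/(4π) (l-m)!/(l+m)!) that depends only on l_max; what sphericart actually caches may differ. The frozen dataclass makes the state immutable once built, a property a PyTree-registered calculator would also want; the PyTree registration itself is not shown.

```python
import math
from dataclasses import dataclass, field


def _prefactors(l_max):
    """Assumed state: K_l^m for 0 <= m <= l <= l_max, flattened l-major."""
    return tuple(
        math.sqrt((2 * l + 1) / (4 * math.pi)
                  * math.factorial(l - m) / math.factorial(l + m))
        for l in range(l_max + 1)
        for m in range(l + 1)
    )


@dataclass(frozen=True)
class SphericalHarmonics:
    """Hypothetical calculator that carries its precomputed state explicitly."""

    l_max: int
    prefactors: tuple = field(init=False)

    def __post_init__(self):
        # Frozen dataclass: set the derived field once via object.__setattr__.
        object.__setattr__(self, "prefactors", _prefactors(self.l_max))
```

Each calculator owns its table, so two calculators with different l_max never share or invalidate each other's state, unlike a hidden global cache.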
Regarding letting the type of the calculator determine whether it computes spherical or solid harmonics (hence just "compute"): this works really well in Julia, where type-based dispatch is king, but in JAX in particular this pattern is much harder to integrate with the function transformations (jit, grad, vmap, …).
Hm... does this state even need to be visible to JAX/Python at all if you're already calling out to custom code? If not, maybe you can initialise it at tracing/"compile" time... it seems a bit clunky to carry state around for this. I assume it needs to be updated based on the requested spherical harmonic order? Or what exactly is being cached here?