pyscf-ipu
Do everything except one operation in float64
Transform the JAX graph to perform everything in float64 except a set of user-specified operations. This may not be possible; we need to think about what it would look like as a JAX graph transformation.
I had the idea to introduce a function decorator that will run an operation twice:
- a baseline run with all the floating-point inputs promoted to fp64
- a second run with the same inputs in fp32
Then report the difference (possibly to stdout, or saved to an npz file?). What I'm not sure about is how to dynamically annotate a graph to do this, but perhaps others have some insight into the JAX way to attempt it.
If that sounds like a useful step, I could draft a PR with the decorator.
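A minimal eager-mode sketch of the decorator described above, assuming JAX with `jax_enable_x64` turned on (float64 is disabled by default). `compare_precision` and `_cast_floats` are hypothetical names chosen for illustration, not an existing pyscf-ipu API:

```python
import functools

import jax

# JAX disables float64 by default; enable it before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def _cast_floats(tree, dtype):
    """Cast every floating-point array leaf of a pytree to `dtype`; leave other leaves as-is."""
    def cast(x):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_util.tree_map(cast, tree)


def compare_precision(fn):
    """Hypothetical decorator: run `fn` once with float inputs promoted to float64 and
    once with them demoted to float32, then print the max absolute difference.
    The float64 result is returned so callers keep the reference values."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        ref = fn(*_cast_floats(args, jnp.float64), **_cast_floats(kwargs, jnp.float64))
        low = fn(*_cast_floats(args, jnp.float32), **_cast_floats(kwargs, jnp.float32))
        # Per-output max |fp64 - fp32|, computed in float64.
        errs = jax.tree_util.tree_map(
            lambda r, l: jnp.max(jnp.abs(r - jnp.asarray(l, dtype=r.dtype))), ref, low)
        print(f"{fn.__name__}: max |fp64 - fp32| = {errs}")
        wrapper.last_error = errs  # kept for later inspection after the call
        return ref
    return wrapper


@compare_precision
def soft(x):
    return jnp.sum(jnp.exp(x))


x = jnp.linspace(0.0, 1.0, 1000)
out = soft(x)  # prints the fp64 vs fp32 discrepancy
```

This only works eagerly: under `jax.jit` the `print` would show tracers rather than values, and `jax.debug.print` is the usual way to report from inside traced code. Writing `wrapper.last_error` out with `numpy.savez` would cover the npz option.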
Regarding "I had the idea to introduce a function decorator that will run an operation twice":
Do you mean "operation=nanoDFT" or, e.g., "operation=einsum(eri, dm)"? I was thinking that we'd run nanoDFT twice: first with everything in float64, then a second time with a single decorated operation in float32. Is this also what you're considering?
I put together #110, which contains just the function decorator idea. The problem I haven't solved is how to inject it at the desired place within a larger compute graph. I think for that we may need a syntax to say "decorate function foo" and a compiler pass that does a find-and-replace on foo in the compute graph.
Or maybe this isn't the right way to approach the problem in JAX?