pyscf-ipu

Do everything except one operation in float64

Open AlexanderMath opened this issue 1 year ago • 3 comments

Transform the JAX graph to perform everything in float64 except a set of user-specified operations. This may not be possible; we need to think about what it would look like as a JAX graph transformation.

AlexanderMath, Sep 20 '23 11:09

I had the idea to introduce a function decorator that will run an operation twice:

  1. a baseline run with all the floating point inputs promoted to fp64
  2. a second run with fp32

Then report the difference between the two results (printed to stdout, or perhaps saved to an npz file?). What I'm not sure about is how to dynamically annotate a graph to do this, but perhaps others have some insight into the JAX way to attempt that.

If that sounds like a useful step I could draft a PR with the decorator.
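A minimal sketch of what such a decorator could look like, written against NumPy arrays as a stand-in so it runs anywhere. The names `compare_fp32_to_fp64` and `contract` are hypothetical and not taken from the repository, and reporting the maximum absolute difference is just one possible choice of metric. With JAX arrays the same idea would additionally need the `jax_enable_x64` flag switched on, otherwise float64 arrays are not available.

```python
import functools

import numpy as np


def _cast_floats(x, dtype):
    """Cast floating-point arrays to dtype; leave every other input untouched."""
    if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating):
        return x.astype(dtype)
    return x


def compare_fp32_to_fp64(fn):
    """Hypothetical decorator: run fn in float64 and in float32, report the gap."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        def run(dtype):
            cast_args = [_cast_floats(a, dtype) for a in args]
            cast_kwargs = {k: _cast_floats(v, dtype) for k, v in kwargs.items()}
            return np.asarray(fn(*cast_args, **cast_kwargs))

        baseline = run(np.float64)  # 1. baseline run, inputs promoted to fp64
        low = run(np.float32)       # 2. second run with fp32 inputs

        max_abs_err = float(np.max(np.abs(baseline - low.astype(np.float64))))
        wrapper.errors.append(max_abs_err)  # kept so it could be saved later
        print(f"{fn.__name__}: max |fp64 - fp32| = {max_abs_err:.3e}")
        return baseline  # the caller continues with the fp64 result

    wrapper.errors = []
    return wrapper


@compare_fp32_to_fp64
def contract(eri, dm):
    # Toy stand-in for an einsum(eri, dm)-style contraction.
    return np.einsum("ijkl,kl->ij", eri, dm)


rng = np.random.default_rng(0)
eri = rng.standard_normal((4, 4, 4, 4))
dm = rng.standard_normal((4, 4))
out = contract(eri, dm)
```

Returning the float64 result keeps the rest of the computation on the baseline path, so each decorated operation is measured in isolation rather than accumulating error downstream.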

hatemhelal, Sep 21 '23 08:09

> I had the idea to introduce a function decorator that will run an operation twice:

Does "operation" here mean the whole of nanoDFT, or a single call such as einsum(eri, dm)? I was thinking that we'd run nanoDFT twice: first with everything in float64, then a second time where a single decorated operation is in float32. Is this also what you're considering?
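To make the second reading concrete, here is a rough sketch using NumPy and a toy pipeline in place of nanoDFT. The names `in_float32`, `contract` and `toy_pipeline` are hypothetical. The pipeline runs twice in float64, and in the second run one chosen operation is wrapped so that it computes in float32 and hands float64 back.

```python
import numpy as np


def in_float32(fn):
    """Hypothetical wrapper: compute fn in float32, return float64 to the caller."""

    def wrapper(*args):
        low = [np.asarray(a, dtype=np.float32) for a in args]
        return np.asarray(fn(*low), dtype=np.float64)

    return wrapper


def contract(eri, dm):
    # The single operation whose precision is being studied.
    return np.einsum("ijkl,kl->ij", eri, dm)


def toy_pipeline(eri, dm, contract_fn):
    # Everything here stays in float64 except whatever contract_fn does inside.
    j = contract_fn(eri, dm)
    return np.trace(j @ dm)


rng = np.random.default_rng(0)
eri = rng.standard_normal((4, 4, 4, 4))
dm = rng.standard_normal((4, 4))

reference = toy_pipeline(eri, dm, contract)          # run 1: all float64
mixed = toy_pipeline(eri, dm, in_float32(contract))  # run 2: one op in float32
error = abs(reference - mixed)
print(f"end-to-end error from one float32 op: {error:.3e}")
```

Unlike a per-operation comparison, this measures how the float32 error of one operation propagates to the final output.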

AlexanderMath, Sep 21 '23 14:09

I put together #110, which contains just the function decorator idea. The problem I haven't solved is how to inject it into the desired place within a larger compute graph. I think for that we may need a syntax to say "decorate function foo" and a compiler pass that does a find and replace on foo in the compute graph.

Or maybe this isn't the right way to approach the problem in JAX?
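One alternative that avoids a compiler pass, sketched under strong assumptions: if `foo` is looked up by name on a module or namespace each time it is called, the "find and replace" can happen in plain Python before the graph is built. `decorate_by_name`, `ops`, `pipeline` and `log_calls` below are hypothetical names used only for illustration, and this does not rewrite an already-traced graph. Under `jax.jit` the patch would have to be active when the function is traced, since an already-compiled function would not pick it up.

```python
import contextlib
import types


@contextlib.contextmanager
def decorate_by_name(namespace, name, decorator):
    """Temporarily replace namespace.<name> with decorator(namespace.<name>)."""
    original = getattr(namespace, name)
    setattr(namespace, name, decorator(original))
    try:
        yield
    finally:
        setattr(namespace, name, original)  # always restore the original


# Stand-in for a module that defines the operation of interest.
ops = types.SimpleNamespace()
ops.foo = lambda x: x * 2.0


def pipeline(x):
    # foo is looked up on ops at call time, so the patch is visible here.
    return ops.foo(x) + 1.0


calls = []


def log_calls(fn):
    """Toy decorator standing in for a precision-comparison decorator."""

    def wrapper(*args, **kwargs):
        calls.append(args)
        return fn(*args, **kwargs)

    return wrapper


with decorate_by_name(ops, "foo", log_calls):
    patched_result = pipeline(3.0)  # foo is decorated inside the block

plain_result = pipeline(3.0)        # foo is back to the original here
```

The limitation is that code which imported `foo` directly (for example `from ops import foo`) holds its own reference and would not see the patch, which is where a graph-level transformation would still be needed.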

hatemhelal, Sep 26 '23 20:09