David Devoogdt
I don't really need a round trip, just code to convert a TensorFlow function to a JAX function f: R^(n×3) → R^m (i.e. an m-dimensional function of n 3D molecular coordinates)....
Seems good to me. Another option would be to add a configurable flag `SLOW_TF_VMAP = False`. If the flag is not set to `True`, raise an error within the batching function...
Some parts of my code were annoyingly slow, so I've reimplemented the batcher using the `tf.vectorized_map` function. Basically, I redo the `call_tf` on the vectorized TensorFlow function and pipe the...
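A user-level sketch of the same idea, assuming a TensorFlow function that handles a single (n, 3) configuration: vectorize it on the TensorFlow side with `tf.vectorized_map`, then pass the vectorized function to `call_tf`, so JAX issues one call for the whole batch rather than one per sample. The function names are hypothetical, and this only shows the pattern, not the reimplemented batching rule itself.

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf


def centroid_distances_tf(coords):
    """Hypothetical per-sample function: (n, 3) -> (n,)."""
    centroid = tf.reduce_mean(coords, axis=0, keepdims=True)
    return tf.norm(coords - centroid, axis=1)


def batched_centroid_distances_tf(batch):
    # tf.vectorized_map maps over the leading axis, (b, n, 3) -> (b, n),
    # vectorizing the ops where it can (it may fall back to a while_loop otherwise).
    return tf.vectorized_map(centroid_distances_tf, batch)


# A single call_tf on the vectorized TF function: one TF call per batch.
batched_centroid_distances_jax = jax2tf.call_tf(batched_centroid_distances_tf)

rng = np.random.default_rng(0)
batch = jnp.asarray(rng.normal(size=(4, 5, 3)).astype(np.float32))
distances = batched_centroid_distances_jax(batch)  # shape (4, 5)
```

The wrapped function can presumably also go under `jax.jit`, provided TensorFlow can compile the vectorized graph with XLA.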