David Devoogdt

3 comments by David Devoogdt

I don't really need a round trip, just code to convert a TensorFlow function to a JAX function f: R^(n×3) → R^m (i.e. an m-dimensional function of n 3D molecular coordinates)...
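A minimal sketch of such a one-way conversion, assuming JAX's `jax.experimental.jax2tf.call_tf` and TensorFlow 2.x are installed. `f_tf` is a hypothetical stand-in (squared distance of each atom from the origin, so m = n here); any TensorFlow function of an (n, 3) coordinate tensor would take its place.

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf


# Hypothetical TF function f: R^(n x 3) -> R^m; here m = n:
# squared distance of each atom from the origin.
def f_tf(coords):
    # coords arrives as a tf.Tensor of shape (n, 3)
    return tf.reduce_sum(coords * coords, axis=-1)


# call_tf wraps the TF callable so it accepts and returns JAX arrays.
f_jax = jax2tf.call_tf(f_tf)

coords = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)  # n = 4 atoms
out = f_jax(coords)  # JAX array of shape (4,)

# Reverse-mode gradients flow back through TensorFlow.
grads = jax.grad(lambda c: jnp.sum(f_jax(c)))(coords)  # equals 2 * coords
```

For molecular coordinates, the gradient is what you would use for forces, which is why reverse-mode support matters here.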

Seems good to me. Another option would be to add a configurable flag `SLOW_TF_VMAP = False`. If the flag is not set to `True`, raise an error in the batching function...
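The proposed opt-in could look roughly like this. It is a hypothetical sketch in plain NumPy: `flags.SLOW_TF_VMAP` and `batch_tf_call` are illustrative names, not existing JAX options or APIs, and the slow path is a simple per-element loop standing in for one TensorFlow call per batch element.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical configuration object holding the proposed flag.
# Not an existing JAX option; off by default, as in the proposal.
flags = SimpleNamespace(SLOW_TF_VMAP=False)


def batch_tf_call(fn, batched_args):
    """Hypothetical batching rule for a wrapped TF function.

    Maps `fn` over the leading axis of every argument, one call per batch
    element, but only if the user opted in via flags.SLOW_TF_VMAP.
    """
    if not flags.SLOW_TF_VMAP:
        raise NotImplementedError(
            "vmap of a TF function falls back to one call per batch element; "
            "set SLOW_TF_VMAP = True to allow this slow path."
        )
    batch_size = batched_args[0].shape[0]
    # Sequential fallback: call fn on each slice, then stack along axis 0.
    outs = [fn(*(arg[i] for arg in batched_args)) for i in range(batch_size)]
    return np.stack(outs)


def per_atom_sq_norm(coords):
    # Stand-in for the TF function: (n, 3) -> (n,)
    return np.sum(coords * coords, axis=-1)


batch = np.ones((5, 4, 3))

try:
    batch_tf_call(per_atom_sq_norm, [batch])
    raised = False
except NotImplementedError:
    raised = True  # default: error instead of a silent slow loop

flags.SLOW_TF_VMAP = True
result = batch_tf_call(per_atom_sq_norm, [batch])  # shape (5, 4), all 3.0
```

The point of the design is that the slow path is never taken silently: users hit an explicit error first and must opt in.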

Some parts of my code were annoyingly slow, so I've reimplemented the batcher using the `tf.vectorized_map` function. Basically, I redo the `call_tf` on the vectorized TensorFlow function and pipe the...
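A user-level sketch of the same idea, assuming TensorFlow 2.x and `jax.experimental.jax2tf.call_tf`: vectorize in TensorFlow first with `tf.vectorized_map`, then wrap the already-batched function with `call_tf`, so JAX makes a single call on the whole batch. This does not reproduce the author's batching rule (which hooks into `vmap`); `f_tf` and `f_tf_batched` are hypothetical.

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf


# Hypothetical per-molecule TF function: (n, 3) -> (n,)
def f_tf(coords):
    return tf.reduce_sum(coords * coords, axis=-1)


# Vectorize in TensorFlow over a leading batch axis: (b, n, 3) -> (b, n).
# tf.vectorized_map rewrites f_tf into batched ops instead of running
# one call per batch element.
def f_tf_batched(batch):
    return tf.vectorized_map(f_tf, batch)


# Redo call_tf on the already-vectorized TF function, so JAX issues a
# single call for the whole batch.
f_jax_batched = jax2tf.call_tf(f_tf_batched)

batch = jnp.arange(8 * 4 * 3, dtype=jnp.float32).reshape(8, 4, 3)
out = f_jax_batched(batch)  # shape (8, 4)
```

`tf.vectorized_map` falls back to a `tf.while_loop` for ops it cannot vectorize (its default is `fallback_to_while_loop=True`), so the speedup depends on the ops inside `f_tf`.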