lecture-jax
block until ready
I think this syntax
start = time.time()
result2 = jnp.cos(X)
jax.device_get(result2) # Force execution completion
second_time = time.time() - start
is nicer than .block_until_ready() for timing.
Could someone please test this and confirm that it's a suitable alternative? If so, we'll change it throughout.
CC @mmcky -- perhaps you could coordinate with @HumphreyYang or Bishmay?
I also prefer that we move away from %time and %%time because they cause problems with jupytext conversion.
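One way to drop the magics would be a small plain-Python timer. This is only a sketch: the helper name `timed` is hypothetical and not something that exists in the lectures, and the workload is a placeholder.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label="elapsed"):
    """Print the wall-clock time of the enclosed block (hypothetical helper)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the block raises, so a time is always reported.
        elapsed = time.perf_counter() - start
        print(f"{label}: {elapsed:.4f} s")


# Placeholder workload, just to show the usage pattern.
with timed("sum of squares"):
    total = sum(i * i for i in range(100_000))
```

For JAX code, the body of the `with` block would still need to force completion itself, since JAX dispatches work asynchronously.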
Many thanks @jstac, it's a very interesting change!
jax.device_get(result2) # Force execution completion
Are we trying to also measure the time it takes to get from the device to the RAM?
Below are the run-time differences between %%time + .block_until_ready() and the new approach:
One minor caveat is that the printed time would not be formatted like the output of %%time.
Oh, the timing is very different. I wonder why that is...
Thanks @HumphreyYang.
Hi @jstac,
I think the new approach will take extra time transferring data from the device to the RAM:
This would cost some time!
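To check this, a minimal side-by-side sketch might look like the following. The array shape, the PRNG seed, the warm-up call, and the use of `time.perf_counter` are my assumptions, not taken from the lectures.

```python
import time

import jax
import jax.numpy as jnp

# Assumed test input: a 1000 x 1000 array of standard normal draws.
X = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

# Warm up once so that neither measurement includes first-call overhead.
jnp.cos(X).block_until_ready()

# Approach 1: wait until the computation finishes; the result stays on the device.
start = time.perf_counter()
result1 = jnp.cos(X).block_until_ready()
t_block = time.perf_counter() - start

# Approach 2: device_get waits for the result and copies it to host memory,
# returning a NumPy array, so the measurement also includes the transfer.
start = time.perf_counter()
result2 = jax.device_get(jnp.cos(X))
t_get = time.perf_counter() - start

print(f"block_until_ready: {t_block:.6f} s")
print(f"device_get:        {t_get:.6f} s")
```

On a GPU I'd expect the gap between the two numbers to come mostly from the device-to-host copy. On a CPU-only install it may be too small to notice.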