lecture-jax

block until ready

Open jstac opened this issue 6 months ago • 3 comments

I think this syntax

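```python
import time
import jax
import jax.numpy as jnp
X = jnp.ones((1000, 1000))  # hypothetical input; X is not defined in the issue
start = time.time()
result2 = jnp.cos(X)
jax.device_get(result2)  # Force execution completion
second_time = time.time() - start
```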

is nicer than .block_until_ready() for timing.
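Here is a minimal side-by-side sketch of the two timing styles. It assumes a hypothetical 1000×1000 input `X`, since the issue does not say what `X` is. A warm-up call runs first so that neither measurement includes first-call compilation overhead. Both `.block_until_ready()` and `jax.device_get` wait for the asynchronously dispatched computation to finish, but `jax.device_get` additionally copies the result into host memory.

```python
import time

import jax
import jax.numpy as jnp

X = jnp.ones((1000, 1000))  # hypothetical input

# Warm-up: the first call compiles the operation for this shape/dtype.
jnp.cos(X).block_until_ready()

# Approach 1: wait until the device has finished computing.
start = time.time()
result1 = jnp.cos(X).block_until_ready()
first_time = time.time() - start

# Approach 2 (proposed): copy the result to the host, which also waits
# for the computation to finish before returning.
start = time.time()
result2 = jnp.cos(X)
jax.device_get(result2)
second_time = time.time() - start

print(f"block_until_ready: {first_time:.6f} s")
print(f"device_get:        {second_time:.6f} s")
```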

Could someone please test this and confirm that it's a suitable alternative? If so, we'll change it throughout.

CC @mmcky. Perhaps you could coordinate with @HumphreyYang or Bishmay?

I also prefer that we move away from %time and %%time because they cause problems with jupytext conversion.
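One way to time code without `%time`/`%%time` is a small plain-Python helper. The sketch below uses a hypothetical `time_call` function, which is not part of the lectures or of JAX. It uses `time.perf_counter` and calls `.block_until_ready()` on the result so that JAX's asynchronous dispatch does not cut the measurement short. It assumes `f` returns a single JAX array and reports the best of several runs, so the compile-heavy first run does not dominate.

```python
import time

import jax.numpy as jnp


def time_call(f, *args, repeats=3):
    """Hypothetical helper: best wall-clock time (seconds) of f(*args).

    block_until_ready() makes each run wait for the device to finish,
    so asynchronous dispatch does not make the timing look too short.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        f(*args).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


X = jnp.ones((1000, 1000))  # hypothetical input
elapsed = time_call(jnp.cos, X)
print(f"jnp.cos: {elapsed:.6f} s")
```

Because this is ordinary Python, it survives jupytext conversion without relying on IPython magics.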

jstac, May 23 '25 00:05

Many thanks @jstac, it's a very interesting change!

> `jax.device_get(result2)  # Force execution completion`

Are we also trying to measure the time it takes to copy the result from the device to host RAM?
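To answer this, the computation and the copy can be timed separately. The sketch below uses a hypothetical input and a warm-up call first. On a CPU-only machine the transfer number may be very small, since the data already lives in host memory.

```python
import time

import jax
import jax.numpy as jnp

X = jnp.ones((2000, 2000))      # hypothetical input
jnp.cos(X).block_until_ready()  # warm-up (excludes compilation)

start = time.perf_counter()
result = jnp.cos(X).block_until_ready()  # computation only
compute_time = time.perf_counter() - start

start = time.perf_counter()
host_result = jax.device_get(result)     # device-to-host copy only
transfer_time = time.perf_counter() - start

print(f"compute:  {compute_time:.6f} s")
print(f"transfer: {transfer_time:.6f} s")
```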

Below are the run-time differences between `%%time` + `.block_until_ready()` and the new approach:

[Image: run-time comparison of the two approaches]

One minor caveat is that the printed time would not be formatted the way `%%time` formats it.
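If matching the `%%time` output matters, a small formatter can approximate its `Wall time:` line. This is a hypothetical helper that only mimics the unit choice (s, ms, µs, ns with about three significant digits). It is not IPython's own implementation, and it ignores the `CPU times:` line and the minutes format used for long runs.

```python
import time


def format_wall_time(seconds):
    """Hypothetical helper: format a duration roughly like the
    'Wall time:' line of %%time (an approximation, not IPython's code)."""
    for unit, scale in [("s", 1.0), ("ms", 1e-3), ("µs", 1e-6)]:
        if seconds >= scale:
            return f"Wall time: {seconds / scale:.3g} {unit}"
    # Anything below one microsecond (including zero) is shown in ns.
    return f"Wall time: {seconds / 1e-9:.3g} ns"


start = time.perf_counter()
total = sum(range(1_000_000))  # some work to time
print(format_wall_time(time.perf_counter() - start))
```

For example, `format_wall_time(0.00193)` gives `"Wall time: 1.93 ms"`.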

HumphreyYang, May 23 '25 01:05

Oh, the timing is very different. I wonder why that is...

Thanks @HumphreyYang.

jstac, May 23 '25 03:05

Hi @jstac,

I think the new approach takes extra time because it also transfers the data from the device to host RAM:

[Image]

This would cost some time!
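Consistent with this explanation, the two calls return different kinds of objects. `.block_until_ready()` returns the same `jax.Array`, still on the device, whereas `jax.device_get` returns a NumPy array in host memory. Here is a small check, assuming a tiny hypothetical input:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.cos(jnp.ones((3, 3)))  # hypothetical small input

on_device = x.block_until_ready()  # waits; returns the same jax.Array (still on the device)
on_host = jax.device_get(x)        # waits, then copies the values into host memory

print(isinstance(on_device, jax.Array))  # True
print(isinstance(on_host, np.ndarray))   # True
```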

HumphreyYang, May 23 '25 03:05