score_sde

The JAX-based code on multi-host TPU

Open lucasliunju opened this issue 3 years ago • 3 comments

Hi Yang,

This is great work. I would like to ask whether this code can run on a multi-host TPU (such as v3-32). If so, could you give me some advice on how to change the code for it?

Thank you very much!

Yong

lucasliunju • Mar 22 '21 16:03

This code can be run directly on multi-host TPU without modification. Internally, we used multi-host TPU for most of our training.

yang-song • Mar 22 '21 17:03
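
As a rough illustration of why a data-parallel JAX program can need no changes on multi-host TPU, the sketch below shows the usual batch arithmetic: each host loads only its own share of the global batch and shapes it as (local devices, per-device batch). The helper name `local_batch_dims` is hypothetical, and the v3-32 topology (4 hosts, 8 cores each) is an assumption for the example. In a real program the two counts would typically come from `jax.device_count()` and `jax.local_device_count()`. This is not taken from the score_sde source.

```python
def local_batch_dims(global_batch_size, num_devices, num_local_devices):
    """Return (local devices, per-device batch size) for one host.

    In a JAX program the two counts would normally come from
    jax.device_count() and jax.local_device_count(); they are passed in
    explicitly here so the arithmetic can be checked without a TPU.
    """
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"Batch size {global_batch_size} is not divisible by "
            f"the number of devices {num_devices}."
        )
    per_device = global_batch_size // num_devices
    return (num_local_devices, per_device)


# Assumed v3-32 topology: 32 cores in total, 8 visible to each of 4 hosts.
dims = local_batch_dims(global_batch_size=128, num_devices=32, num_local_devices=8)
print(dims)
```

Because every host runs the same script and computes the same shapes from its own device counts, the per-host code path does not depend on how many hosts take part.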

Hi Yang,

Thanks so much for your help!

Yong

lucasliunju • Mar 23 '21 05:03

Hi Yang,

I still have a question about how to connect JAX to a multi-host TPU, since I could not find this part in the code.

My current code is:

from jax.config import config
# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://[ip address]:8470"

And I get this warning:

2021-03-23 16:47:20.111472: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:606] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.

Thank you very much!

Yong

lucasliunju • Mar 23 '21 16:03
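
One check that may help narrow this down, offered as a sketch and not as a diagnosis of the warning above: run a short script on each host and see whether JAX reports more than one process and more global devices than local ones. The helper `looks_multi_host` is hypothetical. `jax.process_count()`, `jax.device_count()` and `jax.local_device_count()` are existing JAX calls, though their availability depends on the installed JAX version. The `jax` import is guarded so the helper can still be exercised where JAX is not installed.

```python
def looks_multi_host(process_count, device_count, local_device_count):
    """Return True when this process sees devices owned by other hosts."""
    return process_count > 1 and device_count > local_device_count


try:
    import jax  # optional: only needed for the live check below
except ImportError:
    jax = None

if jax is not None:
    # Live check: run this on every host and compare the printed values.
    seen = looks_multi_host(
        jax.process_count(), jax.device_count(), jax.local_device_count()
    )
    print("process count:", jax.process_count())
    print("global devices:", jax.device_count())
    print("local devices:", jax.local_device_count())
    print("multi-host view:", seen)
```

If each host reports only its own local devices, the hosts are not taking part in one shared program, and that would be worth ruling out before looking further.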