
Can accelerate train a single model on multiple TPU VMs (not a TPU Pod)?

Open codetomakechange opened this issue 2 years ago • 7 comments

I know accelerate supports training on a TPU Pod (#1049 and #471), but I only have 5 TPU v3-8 VMs obtained from the TRC program. I want to use all 5 of these VMs to train a single model, but I could not find a way to make it work with accelerate. Can someone kindly tell me how to make this work, or is it just not possible?

codetomakechange avatar Jun 02 '23 03:06 codetomakechange
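To make the scope of the question concrete, here is a minimal sketch of the single-host path that accelerate already supports (with a toy model and dataset standing in for the real training script; this is not the issue author's code). It would run on one v3-8 VM via `accelerate launch train.py` after `accelerate config` is set up for TPU; the question is whether the same model can instead be trained across 5 separate v3-8 VMs.

```python
# train.py -- minimal single-host sketch (toy model/data, not the author's script)
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

def main():
    accelerator = Accelerator()  # on a TPU VM this picks up the torch-xla backend
    model = nn.Linear(16, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))
    dataloader = DataLoader(dataset, batch_size=8)

    # prepare() moves model/optimizer to the XLA device and shards the dataloader
    # across the processes spawned by `accelerate launch` (8 cores on one v3-8 VM).
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)  # use instead of loss.backward() so accelerate
                                    # can insert any backend-specific steps
        optimizer.step()

if __name__ == "__main__":
    main()
```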

cc @muellerzr

sgugger avatar Jun 02 '23 11:06 sgugger

@codetomakechange out of curiosity have you found a way to do this without accelerate launch? (aka native torch-xla?) If not that's okay, will look into this soon

muellerzr avatar Jun 02 '23 12:06 muellerzr

@muellerzr

@codetomakechange out of curiosity have you found a way to do this without accelerate launch? (aka native torch-xla?) If not that's okay, will look into this soon

I did look into the docs of torch-xla, and it seems that torch-xla does not support this either.

codetomakechange avatar Jun 02 '23 12:06 codetomakechange
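For reference, torch-xla's multiprocessing entry point targets the TPU cores attached to a single host, which is consistent with the observation above; there is no built-in notion of other, unrelated TPU VMs. A minimal single-host sketch in the style of the classic torch-xla examples of that era (toy model and random data as placeholders):

```python
# Minimal single-host torch-xla sketch: xmp.spawn() forks one process per local TPU core.
import torch
from torch import nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()  # the TPU core owned by this process
    model = nn.Linear(16, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(10):
        optimizer.zero_grad()
        x = torch.randn(8, 16, device=device)
        y = torch.randint(0, 2, (8,), device=device)
        loss = nn.functional.cross_entropy(model(x), y)
        loss.backward()
        # All-reduce gradients across the spawned processes and step the optimizer;
        # barrier=True forces the lazy XLA graph to execute each iteration.
        xm.optimizer_step(optimizer, barrier=True)

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)  # 8 cores on a single v3-8 VM
```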

In that case we won't until they do :)

muellerzr avatar Jun 02 '23 12:06 muellerzr

I'd recommend opening an issue on the xla repo

muellerzr avatar Jun 02 '23 12:06 muellerzr

You probably can't do this because the TPUs on different VMs are not connected. If you want more compute power, you need to use a VM with more TPU cores available on GCP, such as a v3-64.

carlesoctav avatar Jun 05 '23 11:06 carlesoctav

You probably can't do this because the TPUs on different VMs are not connected. If you want more compute power, you need to use a VM with more TPU cores available on GCP, such as a v3-64.

TPUs from different VMs are not connected internally, for sure. But these VMs are connected through the same LAN, so I think this scenario could be considered a distributed training environment.

I'd recommend opening an issue on the xla repo

Thank you for your advice. I tried to train on a single TPU VM, and it worked, but the training was unbearably slow. I used the xla profiler to debug and got lots of warnings telling me to open an issue and send the report. It seems that pytorch/xla still has some key operations unimplemented on TPUs as of now, so I had to give up on pytorch/xla. Maybe tensorflow or flax is the right choice for TPUs.

codetomakechange avatar Jun 08 '23 01:06 codetomakechange
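For anyone hitting the same slowdown: alongside the profiler mentioned above, torch-xla's metrics report is a related diagnostic that lists `aten::*` counters for operations that are not lowered to XLA and fall back to CPU, which is a common cause of very slow TPU training. A minimal sketch (called from inside the training process after a few steps):

```python
# Print torch-xla's internal metrics to spot ops falling back to CPU.
import torch_xla.debug.metrics as met

def report_xla_metrics():
    print(met.metrics_report())  # full report: counters plus compile/execute timings
    fallback_ops = [name for name in met.counter_names() if name.startswith("aten::")]
    if fallback_ops:
        print("Ops falling back to CPU:", fallback_ops)
```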