How to load Transformer model once using FSDP
📚 Documentation
@HamidShojanazeri, I'm following your FSDP example and swapped in a bigger model, google/flan-t5-xxl, and I'm a little unclear on what happens when the script starts up. I'm running on a server with 8 V100s, so I run the launch command as listed in the README.md file:
torchrun --nnodes 1 --nproc_per_node 8 T5_training.py
Next, I was having trouble downloading the model weights: I think that with 8 processes, each one was trying to download the weights and they were removing each other's file locks. So I changed the setup_model function so that only rank 0 downloads the weights and all the other processes then read from the local cache.
Finally, my big question for you: as the setup_model function is currently written, is it fair to say that we're loading a full copy of the model weights in every running process (in my case, 8 processes)? If so, how can we load the model only once and broadcast the weights to all the other processes? I ask because this will become a blocker at bigger model scales, since we'll eventually run out of CPU memory trying to do this.
Here's my modified setup_model function for reference:
import torch.distributed as dist
from transformers import T5ForConditionalGeneration, T5Tokenizer

def setup_model(model_name, model_max_length=512, cache_dir=None, rank=None):
    # TODO: is this loading the model on all processes?
    # 1) this seems time consuming, and 2) it seems like it would use way too much memory
    # ensure weights are only downloaded by one process
    if rank == 0:
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=model_max_length, cache_dir=cache_dir)
    dist.barrier()
    if rank != 0:
        # every other rank loads from the local cache that rank 0 just populated
        model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        # set model_max_length to avoid warnings
        tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=model_max_length, cache_dir=cache_dir)
    return model, tokenizer
I imagine this all gets easier and more memory efficient once we start saving the model in the formats you've specified in the model_checkpointing directory, but we have to get there in the first place.
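To make the question more concrete, here's roughly the pattern I was hoping exists. This is a completely untested sketch; I'm guessing that building the model on the meta device on the non-zero ranks and passing sync_module_states and param_init_fn to FSDP is the right way to get rank 0's weights broadcast to everyone else:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, T5ForConditionalGeneration

def setup_model_once(model_name, rank):
    if rank == 0:
        # only rank 0 materializes the full weights in CPU memory
        model = T5ForConditionalGeneration.from_pretrained(model_name)
    else:
        # every other rank builds just the architecture; parameters live on the
        # meta device and take no real memory
        config = AutoConfig.from_pretrained(model_name)
        with torch.device("meta"):
            model = T5ForConditionalGeneration(config)
    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,  # broadcast rank 0's parameters while sharding
        param_init_fn=None if rank == 0 else
            (lambda module: module.to_empty(device=torch.cuda.current_device(), recurse=False)),
    )
    return model

If something like this works, only one full copy of flan-t5-xxl ever sits in CPU memory, which is what I'm after.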
I should also note, in case it makes a difference, that I'm setting up the distributed process group (within T5_training.py) before calling setup_model, whereas your example calls setup_model before setting up the distributed process group.
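To be concrete about the ordering, my modified T5_training.py does roughly this (a paraphrase, not the exact code):

import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for every process it spawns
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    # the process group is created before the model is loaded, so the
    # dist.barrier() inside setup_model has a group to synchronize on
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    model, tokenizer = setup_model("google/flan-t5-xxl", cache_dir=None, rank=rank)
    # FSDP wrapping, data loading, and the training loop follow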