How to load Transformer model once using FSDP

Open ToddMorrill opened this issue 1 year ago • 0 comments

📚 Documentation

@HamidShojanazeri, I'm following your FSDP example with a bigger model swapped in, google/flan-t5-xxl, and I'm a little unclear on what happens when the script starts up. I'm running on a server with 8 V100s, so I use the launch command listed in the README.md file: torchrun --nnodes 1 --nproc_per_node 8 T5_training.py

Next, I had trouble downloading the model weights. I think that with 8 processes, each one was trying to download the weights and they were removing each other's file locks. I changed the setup_model function so that only rank 0 downloads the weights and all other processes then read from the local cache.

Finally, my big question for you: as the setup_model function is currently written, is it fair to say that we're loading a full copy of the model weights in every running process (in my case, 8 processes)? If so, how can we load the model once and broadcast the weights to all other processes? I ask because this becomes a blocker at bigger model scales, where we'll eventually run out of CPU memory.
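To put a rough number on the CPU memory concern, here is a back-of-the-envelope sketch. It assumes about 11 billion parameters for google/flan-t5-xxl and 4 bytes per parameter (float32, which I believe is the default load dtype when none is specified); both figures are assumptions rather than measurements, and the helper name is made up for illustration.

```python
def cpu_memory_gib(num_params, bytes_per_param, num_procs):
    """Estimate host memory (GiB) if every process holds a full copy of the weights."""
    return num_params * bytes_per_param * num_procs / 2**30


# Assumed values, not measured: ~11e9 parameters, float32 weights.
NUM_PARAMS = 11e9
BYTES_PER_PARAM = 4

per_process = cpu_memory_gib(NUM_PARAMS, BYTES_PER_PARAM, num_procs=1)
all_processes = cpu_memory_gib(NUM_PARAMS, BYTES_PER_PARAM, num_procs=8)

print(f"one copy:     {per_process:.1f} GiB")
print(f"eight copies: {all_processes:.1f} GiB")
```

Under these assumptions a single copy is roughly 41 GiB, so eight simultaneous copies would need several hundred GiB of host RAM before any sharding happens.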

Here's my modified setup_model function for reference:

import torch.distributed as dist
from transformers import T5ForConditionalGeneration, T5Tokenizer


def setup_model(model_name, model_max_length=512, cache_dir=None, rank=None):
    # TODO: is this loading the model on all processes?
    # 1) this seems time consuming, and 2) it seems like it would use way too much memory
    # Ensure the weights are downloaded by only one process: every rank except 0
    # waits here until rank 0 has populated the local cache.
    if rank != 0:
        dist.barrier()
    model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
    # set model_max_length to avoid warnings
    tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=model_max_length, cache_dir=cache_dir)
    if rank == 0:
        # rank 0 has finished downloading; release the waiting ranks
        dist.barrier()
    return model, tokenizer

I imagine this all gets easier and more memory efficient once we start saving the model in the formats you've specified in the model_checkpointing directory, but we have to get there in the first place.

I should also note, in case it makes a difference, that I'm setting up the distributed process group (within T5_training.py) before calling setup_model, whereas you call setup_model before setting up the distributed process group in your example.
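For reference, here is a minimal sketch of what "load once and broadcast" could look like with FSDP. It is untested on flan-t5-xxl and rests on several assumptions. Only rank 0 materializes real weights, and the other ranks build the model on the meta device, which allocates no storage. FSDP's sync_module_states=True then broadcasts rank 0's values while wrapping. The names build_for_rank, shard_with_fsdp, load_pretrained and build_from_config are hypothetical. sync_module_states, param_init_fn and device_id are existing FSDP constructor arguments. The sketch needs PyTorch 2.0 or later and a process group that is already initialized, as in my setup.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_for_rank(rank, load_pretrained, build_from_config):
    """Return real weights on rank 0 and a storage-free meta-device model elsewhere.

    Both callables are hypothetical stand-ins. For T5, load_pretrained would wrap
    T5ForConditionalGeneration.from_pretrained(...) and build_from_config would
    wrap T5ForConditionalGeneration(config).
    """
    if rank == 0:
        return load_pretrained()
    with torch.device("meta"):  # assumption: PyTorch >= 2.0
        return build_from_config()


def shard_with_fsdp(model, rank, local_rank, auto_wrap_policy=None):
    """Wrap with FSDP so that rank 0's weights are broadcast to the other ranks."""

    def materialize(module: nn.Module):
        # Allocate uninitialized GPU storage for a meta-device module; the values
        # are then overwritten by the broadcast from rank 0.
        module.to_empty(device=torch.device("cuda", local_rank), recurse=False)

    return FSDP(
        model,
        # Assumption: a transformer auto-wrap policy is passed in, so units are
        # moved to the GPU and sharded one at a time rather than all at once.
        auto_wrap_policy=auto_wrap_policy,
        device_id=local_rank,
        sync_module_states=True,  # broadcast parameters and buffers from rank 0
        param_init_fn=None if rank == 0 else materialize,
    )
```

A likely caveat is that T5 ties its embedding weights, and I'm not sure the ties survive meta-device construction without extra handling, so this may need adjusting for this model.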

ToddMorrill, Aug 01 '23 22:08