
Issue Loading FSDP wrapped module using FULL_STATE_DICT type.

Open hbikki opened this issue 2 years ago • 3 comments

🐛 Describe the bug

Hello, I am training a pretrained Hugging Face model, "t5-small". Using the torchsnapshot examples provided in the documentation, I am able to save/load a checkpoint with the LOCAL_STATE_DICT type, and I am also able to save the model checkpoint with FULL_STATE_DICT. However, when loading the FULL_STATE_DICT checkpoint I hit the issue below.

Versions: pytorch==2.0.0+cu117, torchx-nightly>=2023.3.15, torchsnapshot==0.1.0

Host details: the training below was tested on a single node with NPROC_PER_NODE=8.

Code:

Model training code:

import functools

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block


def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    torch.cuda.set_device(local_rank())  # local_rank() is a helper returning this process's LOCAL_RANK
    model = load_model("t5-small")  # helper that loads the pretrained Hugging Face model

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # <training loop>
    # <save_checkpoint>
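For reference, torchsnapshot's Snapshot.take/restore expect app_state to be a dictionary mapping string keys to stateful objects (anything exposing state_dict()/load_state_dict(), such as the FSDP-wrapped module). A minimal sketch of how app_state could be assembled in this setup; the "model" key is an assumed name, not taken from the report:

# Sketch only: app_state maps string keys to stateful objects; an
# nn.Module (including an FSDP-wrapped one) qualifies.
app_state = {"model": fsdp_model}  # "model" is an assumed key name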

Here stateDictType is set to StateDictType.FULL_STATE_DICT.

Related saving/loading code:

def save_checkpoint(self) -> None:
    # checkpoint.model is the FSDP-wrapped module
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot.take(path=str(save_dir), app_state=app_state)

def load_checkpoint(self) -> None:
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot(path=str(load_dir)).restore(app_state=app_state)
   

Error stack trace: https://pastebin.com/ih9qSbwR

.snapshot_metadata for the model on local rank: https://pastebin.com/t6grkKyX

Does anyone know how to resolve this? Thanks!

hbikki avatar May 03 '23 21:05 hbikki

/assigntome

xanderex-sid avatar Nov 01 '23 17:11 xanderex-sid

/assigntome

andrewashere avatar Nov 02 '23 22:11 andrewashere

/assigntome

markstur avatar Nov 06 '23 18:11 markstur

@svekars I see the deprecate label. Do you still think it is okay to keep this as a jit.trace and jit.script tutorial? I think I can get that done soon, but I noticed the deprecate label and was wondering if it was already decided to just archive it.

markstur avatar Nov 08 '23 06:11 markstur