torchsnapshot
Issue loading an FSDP-wrapped module using the FULL_STATE_DICT type
🐛 Describe the bug
Hello, I am working on training the pretrained Hugging Face model "t5-small". Using the torchsnapshot examples provided in the documentation, I am able to save and load a checkpoint with the LOCAL_STATE_DICT type, and I am also able to save a model checkpoint with FULL_STATE_DICT. However, when loading the FULL_STATE_DICT checkpoint I hit the issue below.
Versions: pytorch==2.0.0+cu117, torchx-nightly>=2023.3.15, torchsnapshot==0.1.0
Host details: the training below was run on a single node with NPROC_PER_NODE=8.
Code:
Model training code:
import functools

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block


def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    torch.cuda.set_device(local_rank())

    # load_model() and local_rank() are helpers defined elsewhere (sketched below).
    model = load_model("t5-small")
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # <------- training loop ------->
    # <------- save_checkpoint ------->

The state dict type is set to StateDictType.FULL_STATE_DICT (referenced below as self.stateDictType).
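load_model() and local_rank() are not shown in the snippets. A minimal sketch of what they might look like, assuming the Hugging Face transformers T5 class and the LOCAL_RANK environment variable exported by the launcher (both are assumptions, not taken from the original code):

import os

from transformers import T5ForConditionalGeneration


def local_rank() -> int:
    # Hypothetical helper: elastic launchers (torchrun, torchx dist.ddp)
    # export LOCAL_RANK for each worker process on the node.
    return int(os.environ["LOCAL_RANK"])


def load_model(name: str) -> T5ForConditionalGeneration:
    # Hypothetical helper: load the pretrained Hugging Face checkpoint.
    return T5ForConditionalGeneration.from_pretrained(name)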
Related saving/loading code:
from torchsnapshot import Snapshot


def save_checkpoint() -> None:
    # Simplified from the checkpointer class; checkpoint.model is the FSDP-wrapped model.
    # Take a snapshot under the configured state dict type (FULL_STATE_DICT).
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot.take(path=str(save_dir), app_state=app_state)


def load_checkpoint() -> None:
    # Restore an existing snapshot under the same state dict type.
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot(path=str(load_dir)).restore(app_state=app_state)
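app_state, save_dir, and load_dir are also not defined in the snippets above. torchsnapshot expects app_state to be a flat dict of stateful objects (modules, optimizers, ...). A minimal sketch of how it is typically assembled, where the optimizer choice and the directory paths are assumptions:

from pathlib import Path

from torch.optim import AdamW

# Hypothetical application state: torchsnapshot saves/restores every entry
# in this dict through its state_dict()/load_state_dict() protocol.
optimizer = AdamW(fsdp_model.parameters(), lr=1e-4)  # assumed optimizer
app_state = {"model": fsdp_model, "optim": optimizer}

# Assumed checkpoint locations.
save_dir = Path("/tmp/t5_snapshot")
load_dir = save_dir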
Error stack trace: https://pastebin.com/ih9qSbwR
.snapshot_metadata for the model on the local rank: https://pastebin.com/t6grkKyX
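For reference (not something the snippets above use): with FULL_STATE_DICT, FSDP.state_dict_type also accepts an explicit FullStateDictConfig that controls CPU offload and whether the full state dict is gathered on rank 0 only. A minimal sketch of the load path under that configuration, assuming the same model and app_state as above; this is illustrative, not a confirmed fix for the error:

from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torchsnapshot import Snapshot


def load_checkpoint_full_state(model, app_state, load_dir) -> None:
    # Every rank materializes the full (unsharded) state dict on CPU before
    # torchsnapshot restores into it.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        Snapshot(path=str(load_dir)).restore(app_state=app_state)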
Does anyone know how to resolve this? Thanks!