pytorch.github.io
Wrong Code in the FSDP blog
📚 Documentation
In the blog introducing the FSDP API, the example reads:
fsdp_model = FullyShardedDataParallel(
model(),
fsdp_auto_wrap_policy=default_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
)
It should be model instead of model() inside FullyShardedDataParallel, so it should read:
fsdp_model = FullyShardedDataParallel(
model,
fsdp_auto_wrap_policy=default_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
)
It looks a little confusing and maybe could be written more clearly, but I think that's actually correct. If it was just model, it would be trying to FSDP-wrap the DDP model. By using model(), it's FSDP-wrapping a new model instance.
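For reference, a minimal sketch of the pattern under discussion, assuming model is a small nn.Module class (so model() constructs a fresh, unwrapped instance). The class definition and layer sizes are illustrative, not taken from the blog, and the fsdp_auto_wrap_policy / default_auto_wrap_policy names are the ones the blog snippet uses (later PyTorch releases renamed them to auto_wrap_policy / size_based_auto_wrap_policy).

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

# Illustrative module class; because `model` is a class here, calling
# model() builds a new, unwrapped instance to hand to FSDP.
class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 4)
        self.layer2 = nn.Linear(4, 16)
        self.layer3 = nn.Linear(16, 4)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

# A process group must already be initialized (e.g. via
# torch.distributed.init_process_group under torchrun) before this call.
fsdp_model = FullyShardedDataParallel(
    model(),
    fsdp_auto_wrap_policy=default_auto_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
)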