[Bug] FSDP FULL_SHARD incorrectly rejects timm models with features_only=True (FeatureListNet) due to overly-strict nn.ModuleDict inheritance check
Describe the bug
When timm.create_model is called with the features_only=True argument, it returns a FeatureListNet module. This module cannot be correctly wrapped by torch.distributed.fsdp.FullyShardedDataParallel when using the FULL_SHARD strategy.
UserWarning: FSDP will not all-gather parameters for containers that do not implement forward: FeatureListNet(
(stem_0): Conv2d(3, 128, kernel_size=(4, 4), ...)
FSDP incorrectly identifies FeatureListNet as a container that does not implement forward, even though it does. With FSDP2's fully_shard this results in a ValueError (FSDP1 only emits the warning above).
To Reproduce
Steps to reproduce the behavior:
1. Minimal reproduction code (test1.py):
import timm
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def main():
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
    # This returns a FeatureListNet, which inherits from nn.ModuleDict
    print(type(encoder))
    model = FSDP(
        encoder,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )  # This will fail
    print(f"Rank {local_rank}: FSDP wrapping SUCCEEDED.")

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f"Rank {dist.get_rank()} FAILED with error: {e}")
    finally:
        if dist.is_initialized():
            rank = dist.get_rank()
            dist.destroy_process_group()
            print(f"Rank {rank} cleaned up process group.")
2. Run command:
torchrun --nproc_per_node=2 test1.py
Expected behavior
The encoder module (FeatureListNet) should be successfully wrapped by FSDP FULL_SHARD without error, as it is a valid nn.Module that implements its own forward method.
Desktop:
- OS: Ubuntu 22.04
- timm: 1.0.22 (timm-1.0.22-py3-none-any.whl)
- PyTorch: 2.6.0+cu124
Additional context
The root cause is a validation check in Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py:
if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    raise ValueError(
        f"fully_shard does not support containers that do not implement forward: {module}"
    )
In the timm library (Lib/site-packages/timm/models/_features.py), the FeatureListNet and FeatureDictNet classes inherit from nn.ModuleDict.
Crucially, both FeatureListNet and FeatureDictNet implement their own forward methods.
The FSDP validation is too strict. It only checks whether the module is an instance of nn.ModuleList or nn.ModuleDict and immediately raises the ValueError, without first checking whether the inheriting class implements its own forward method.
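For illustration, the kind of check this report is arguing for could be sketched like this (a sketch of an alternative, not the actual PyTorch implementation):

import torch.nn as nn

def is_bare_container(module: nn.Module) -> bool:
    """True only for ModuleList/ModuleDict instances that still use the default
    (unimplemented) nn.Module.forward, i.e. that genuinely have no forward of their own."""
    return (
        isinstance(module, (nn.ModuleList, nn.ModuleDict))
        and type(module).forward is nn.Module.forward
    )

# A plain nn.ModuleDict() would be flagged, while timm's FeatureListNet, which
# overrides forward, would not.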
@wenwwww you'd think it'd check for existence of forward() instead of doing an isinstance check :/
Does it work fine if you hack their impl to skip that check?
If everything works fine minus that check, I could just collapse the 'ModuleDict' functionality into the feature wrappers themselves (manipulate ._modules directly), or implement a cut-down version in a BasicModuleDict (which seems really silly).
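Roughly, that "collapse it into the wrappers" idea would look something like this (a minimal sketch, not the actual timm code):

import torch.nn as nn

class FeatureWrapperSketch(nn.Module):
    """Plain nn.Module that registers children by name itself, so it is not an
    nn.ModuleDict subclass and never trips FSDP's container isinstance check."""
    def __init__(self, named_layers):
        super().__init__()
        for name, layer in named_layers:
            self.add_module(name, layer)  # populates self._modules directly

    def forward(self, x):
        features = []
        for layer in self._modules.values():
            x = layer(x)
            features.append(x)
        return features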
My apologies, I previously misunderstood the FSDP versions. I was using Accelerator to control FSDP2, which led to an error, but it turns out the minimal reproducible code I provided above was actually for FSDP1.
FSDP1 Minimal Reproducible Code
import timm
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def main():
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
    # This returns a FeatureListNet, which inherits from nn.ModuleDict
    print(type(encoder))
    model = FSDP(
        encoder,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f"Rank {dist.get_rank()} FAILED with error: {e}")
    finally:
        if dist.is_initialized():
            rank = dist.get_rank()
            dist.destroy_process_group()
            print(f"Rank {rank} cleaned up process group.")
Command
torchrun --nproc_per_node=2 test1.py
Output
UserWarning: FSDP will not all-gather parameters for containers that do not implement forward: FeatureListNet(
(stem_0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
This comes from the check in Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py, but it's actually just a warning: a ModuleDict with a forward method should be processed correctly. Here is a related PyTorch forum discussion: https://discuss.pytorch.org/t/should-forward-be-banned-from-modulelist-with-fsdp/207096
if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    warnings.warn(
        "FSDP will not all-gather parameters for containers that do "
        f"not implement forward: {module}",
        stacklevel=2,
    )
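To double-check that this really is only a warning in FSDP1, the repro above can be extended with a forward pass after the FSDP() call inside main() (untested sketch; the input shape is arbitrary):

# Sketch: despite the UserWarning, the FSDP-wrapped FeatureListNet should still
# run a forward pass and return its list of feature maps.
x = torch.randn(2, 3, 224, 224, device=torch.cuda.current_device())
features = model(x)
print(f"Rank {local_rank}: got {len(features)} feature maps, shapes {[tuple(f.shape) for f in features]}")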
However, FSDP2 has become stricter.
FSDP2 Minimal Reproducible Code
import timm
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
# This returns a FeatureListNet, which inherits from nn.ModuleDict
for module in encoder.modules():
    fully_shard(module)
fully_shard(encoder)
Command
torchrun --nproc_per_node=2 test2.py
Output
File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 187, in fully_shard
[rank0]: raise ValueError(
[rank0]: ValueError: fully_shard does not support containers that do not implement forward: FeatureListNet(
[rank0]: (stem_0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
This is due to the check in Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py:
if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
    raise ValueError(
        f"fully_shard does not support containers that do not implement forward: {module}"
    )
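For reference, one untested workaround sketch that avoids the check entirely is to never pass the FeatureListNet itself to fully_shard: shard its children and a thin plain-nn.Module wrapper instead (EncoderWrapper below is my own illustrative class, not part of timm):

import timm
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

class EncoderWrapper(nn.Module):
    """Thin plain-nn.Module wrapper so fully_shard never sees the nn.ModuleDict subclass."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, x):
        return self.encoder(x)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

encoder = timm.create_model("convnextv2_base", pretrained=False, features_only=True)
wrapped = EncoderWrapper(encoder)

# Shard the stem/stage children of the FeatureListNet (ordinary modules), then the
# wrapper itself; the FeatureListNet is skipped, so the container check never fires.
for child in encoder.children():
    fully_shard(child)
fully_shard(wrapped)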
I tried commenting out these isinstance checks, but it led to some issues. I am using Accelerator here because I'm not familiar enough with FSDP2 at the moment.
Accelerator Configuration
compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_min_num_params: 10000000
  fsdp_offload_params: false
  fsdp_reshard_after_forward: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Code
from contextlib import contextmanager
import torch
import torchvision
from torch.utils.data import Dataset
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.optim import AdamW
import timm

num_epochs = 10
num_vector_for_loss = 128

class TestDataset(Dataset):
    def __init__(self):
        self.data1 = torch.rand((200, 3, 512, 1024))
        self.data2 = torch.randint(0, 18, (200, 512, 1024), dtype=torch.long)

    def __len__(self):
        return len(self.data1)

    def __getitem__(self, idx):
        return self.data1[idx], self.data2[idx]

model = timm.create_model("resnet101", pretrained=False, features_only=True)
for param in model.parameters():
    param.requires_grad = True

optimizer = AdamW([{'params': model.parameters(), 'lr': 5e-4, 'weight_decay': 1e-2}])
train_dataset = TestDataset()
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        if accelerator.is_main_process:
            accelerator.print(f"start")
        optimizer.zero_grad()
        inputs, targets = batch
        outputs = model(inputs)
        if accelerator.is_main_process:
            accelerator.print('model outputs')
        size = outputs.shape[-1] * outputs.shape[-2]
        sample_indices = torch.randint(0, size, (num_vector_for_loss,), device=outputs.device)
        outputs = outputs.view(outputs.shape[0], outputs.shape[1], -1)[:, :, sample_indices]
        targets_temp = targets.view(targets.shape[0], -1)[:, sample_indices]
        losses = torch.nn.functional.cross_entropy(outputs, targets_temp)
        if accelerator.is_main_process:
            accelerator.print(f"Loss: {losses}")
        if accelerator.is_main_process:
            accelerator.print(f"backward")
        accelerator.backward(losses)
        optimizer.step()
        accelerator.wait_for_everyone()

if accelerator.is_main_process:
    accelerator.print("finish")
Output
(base) root@autodl-container-40f843a50e-4b50f4d5:~# accelerate launch --config_file /root/autodl-tmp/default_config.yaml /root/autodl-tmp/test2.py
/root/miniconda3/lib/python3.12/site-packages/accelerate/utils/fsdp_utils.py:707: UserWarning: FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints.
warnings.warn(
start
[rank0]:[E1114 16:03:10.690560438 ProcessGroupNCCL.cpp:1896] [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f2424d1e5e8 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f2424cb34a2 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7f2495d60422 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f2425a8b456 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x70 (0x7f2425a9b6f0 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x782 (0x7f2425a9d282 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f2425a9ee8d in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdbbf4 (0x7f2415dcabf4 in /root/miniconda3/bin/../lib/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f24969c0ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7f2496a51a04 in /usr/lib/x86_64-linux-gnu/libc.so.6)
[rank1]:[E1114 16:03:10.847308672 ProcessGroupNCCL.cpp:1896] [PG ID 0 PG GUID 0(default_pg) Rank 1] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7fe9c4f785e8 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7fe9c4f0d4a2 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7fe9c53d1422 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7fe954c8b456 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x70 (0x7fe954c9b6f0 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x782 (0x7fe954c9d282 in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7fe954c9ee8d in /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdbbf4 (0x7fe944fcabf4 in /root/miniconda3/bin/../lib/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7fe9c5d7cac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7fe9c5e0da04 in /usr/lib/x86_64-linux-gnu/libc.so.6)
W1114 16:03:10.997000 58449 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 58516 closing signal SIGTERM
E1114 16:03:11.212000 58449 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: -6) local_rank: 0 (pid: 58515) of binary: /root/miniconda3/bin/python
Traceback (most recent call last):
File "/root/miniconda3/bin/accelerate", line 8, in <module>
sys.exit(main())
^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/root/miniconda3/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1222, in launch_command
multi_gpu_launcher(args)
File "/root/miniconda3/lib/python3.12/site-packages/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
distrib_run.run(args)
File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/run.py", line 883, in run
elastic_launch(
File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
======================================================
/root/autodl-tmp/test2.py FAILED
------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-11-14_16:03:10
host : autodl-container-40f843a50e-4b50f4d5
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 58515)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 58515
======================================================
I probably still need some time to look into this.
My issue was that I wrote the third code snippet incorrectly. The subsequent CUDA error occurred because my NCCL version was too old, which led to communication problems. The original version was nccl nvidia/linux-64::nccl-2.28.7-h35aabad_0. After upgrading to PyTorch 2.8.0 with a compatible NCCL, and then commenting out the isinstance check, FSDP2 should be able to run correctly.
The YAML file is the same as above.
Test code
import torch
from torch import nn
import timm
from torch.utils.data import Dataset
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn.functional as F

num_epochs = 10
num_vector_for_loss = 128
BATCH_SIZE = 4

class TestDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data1 = torch.rand((200, 3, 512, 1024))
        self.data2 = torch.rand((200, 1, 512, 1024))

    def __len__(self):
        return len(self.data1)

    def __getitem__(self, idx):
        # Return the input image and the target mask
        return self.data1[idx], self.data2[idx]

base_model = timm.create_model("resnet101", features_only=True)

class TR(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Conv2d(2048, 1, kernel_size=1)

    def forward(self, x):
        target_size = x.shape[-2:]
        features = self.encoder(x)
        last_features = features[-1]
        logits = self.head(last_features)
        logits = F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)
        return logits  # Final shape: (B, 1, 512, 1024)

model = TR(base_model)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# --- Training Setup ---
for param in model.parameters():
    param.requires_grad = True

optimizer = AdamW([{'params': model.parameters(), 'lr': 5e-4, 'weight_decay': 1e-2}])
train_dataset = TestDataset()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

accelerator = Accelerator()
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
criterion = nn.SmoothL1Loss()

# --- Training Loop ---
def main():
    model.train()
    for epoch in range(num_epochs):
        for batch in train_loader:
            # Unpack the batch
            inputs, targets = batch
            if accelerator.is_main_process:
                accelerator.print(f"--- Epoch {epoch+1}/{num_epochs} ---")
            optimizer.zero_grad()
            outputs = model(inputs)  # `outputs` is (B, 1, 512, 1024)
            # `targets` is (B, 1, 512, 1024)
            # --- Loss Calculation ---
            loss = criterion(outputs, targets)
            if accelerator.is_main_process:
                accelerator.print(f"Loss: {loss.item()}")
            # Backward pass
            accelerator.backward(loss)
            # Optimizer step
            optimizer.step()
            accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.print("Training finished.")

# --- Run Training ---
main()
accelerator.end_training()
Command
accelerate launch --config_file default_config.yaml test.py
@wenwwww thanks, yeah I picked up the FSDP1 vs FSDP2 concern there.
I created #2610 ... do you want to check if the code on that PR branch works for you? I haven't tested extensively, but basically pulled out the ModuleDict interface and a few other Feature*Net related things into their own base class...
So, I remember why I need to inherit from ModuleDict: the module lookup by key doesn't work in torchscript unless I do that. There's special handling of ModuleDict that AFAIK can't be done at the user level. I can't access _modules when torchscript is used, and the other workarounds fall short too: getattr() fails with non-literal keys, and get_submodule() doesn't work well with torchscript.
Thanks for the explanation. It looks like the best solution is to modify the PyTorch FSDP2 source code, and I noticed you've already reported this to PyTorch.
@wenwwww yup, I reported it, hopefully a small modification can be made there. There's probably a way to monkey patch things to confuse the isinstance check...
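For anyone who wants to experiment before a fix lands upstream, here is one very rough sketch of such a monkey patch. It is purely hypothetical: the private module path is version-dependent (it matches the paths quoted above), and it shadows every isinstance call inside that file, so treat it as a hack rather than a solution.

# Hypothetical sketch only: shadow the builtin isinstance inside the private FSDP2
# module so the container check passes for containers that override forward.
import builtins
import torch.nn as nn
import torch.distributed.fsdp._fully_shard._fully_shard as _fs

_CONTAINER_TYPES = (nn.ModuleList, nn.ModuleDict)

def _patched_isinstance(obj, classinfo):
    # Only special-case the exact container check; defer to the real builtin otherwise.
    if classinfo == _CONTAINER_TYPES and type(obj).forward is not nn.Module.forward:
        return False
    return builtins.isinstance(obj, classinfo)

# A module-level name shadows the builtin for functions defined in _fs (e.g. fully_shard).
_fs.isinstance = _patched_isinstance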