[FEATURE] Unified `embed` Interfaces for Vision Transformer Models for MIM Pretrain Research
This feature request is related to challenges in Masked Image Modeling (MIM) pre-training using vision transformer models in timm. Currently, embedding and feature extraction are tightly coupled within forward_features, making it difficult to inject mask operations after initial embedding and positional encoding but before the transformer stages, which is a common MIM requirement. Researchers need to access embedded tokens for masking before passing them through subsequent transformer layers.
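For context, the current flow in a typical timm ViT looks roughly like the simplified paraphrase below (method names taken from recent timm versions; details such as gradient checkpointing are omitted), which leaves no clean hook point between positional encoding and the blocks:

# Simplified paraphrase of timm's VisionTransformer.forward_features (approximate)
def forward_features(self, x):
    x = self.patch_embed(x)   # patch embedding
    x = self._pos_embed(x)    # class/register tokens + positional encoding
    x = self.patch_drop(x)
    x = self.norm_pre(x)
    x = self.blocks(x)        # MIM masking would need to happen before this
    x = self.norm(x)
    return x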
Describe the solution you'd like
I propose refactoring all vision transformer models (e.g., vit, swin_transformer, etc.) to uniformly expose two distinct interfaces:
- `embed(self, x)`: This method should take the input tensor x (e.g., image patches) and return the embedded vectors, including positional encodings if applicable. The output should be ready for the transformer encoder stages.
- `forward_stages(self, x)`: This method should take the output from embed(x) (i.e., the embedded and position-encoded tokens) and pass it through the transformer encoder layers.
This separation would make the existing forward_features(x) effectively equivalent to forward_stages(embed(x)). This allows researchers to easily perform mask operations on the embedded tokens returned by embed(x) before passing them to forward_stages(x), enabling flexible MIM pre-training experiments.
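To make the intended usage concrete, here is a minimal sketch of a MIM step under the proposed split; `embed`/`forward_stages` do not exist in timm today, and `masking_fn` is a placeholder for any user-defined masking strategy:

def mim_forward(model, images, masking_fn):
    # Proposed (not yet existing) interface: tokens come back embedded and
    # position-encoded, ready for the encoder stages.
    tokens = model.embed(images)
    # Inject any MIM masking strategy here, on the token sequence.
    tokens, mask = masking_fn(tokens)
    # Run the (partially masked) tokens through the transformer blocks + norm.
    features = model.forward_stages(tokens)
    return features, mask

# For non-MIM users nothing changes:
#   model.forward_features(x) stays equivalent to model.forward_stages(model.embed(x))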
Describe alternatives you've considered
I have considered alternative solutions, such as adding a mask parameter directly to the forward_features and forward methods, similar to vision_transformer's implementation.
https://github.com/huggingface/pytorch-image-models/blob/a7c5368ba0c8713dc1c9a98cc83bf46ddd02b0a0/timm/models/vision_transformer.py#L933-L935
While seemingly straightforward, this approach presents drawbacks:
- Function Signature: It still alters the method's interface, even if default parameter values mitigate direct breakage.
- Mask Value Flexibility: More importantly, it would limit flexibility in how masked-out positions are handled, restricting whether masked token values can be learned (e.g., a learnable mask token) or simply set to zero. Separating `embed` and `forward_stages` provides full control.
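For illustration, once the embedded tokens are exposed the caller fully controls what a masked position becomes; a sketch assuming the proposed `embed`/`forward_stages` methods (the helper below is hypothetical, not timm API):

import torch
import torch.nn as nn

def mask_with_token(model, images, mask_token, mask_ratio=0.75):
    """Replace a random subset of embedded tokens before the encoder stages.

    `mask_token` can be a learnable nn.Parameter of shape (1, 1, D) or a plain
    zero tensor -- the choice lives entirely in user code, not in the model.
    """
    tokens = model.embed(images)                                   # (B, N, D)
    keep = torch.rand(tokens.shape[:2], device=tokens.device) >= mask_ratio
    tokens = torch.where(keep.unsqueeze(-1), tokens, mask_token.to(tokens.dtype))
    return model.forward_stages(tokens), keep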
Additional context
If this refactoring aligns with the library's design, I would gladly contribute a Pull Request to implement these changes across relevant vision transformer models.
@ryan-minato it's not possible to remove forward_features as an API at this point.
For research, this is Python and very flexible; I don't see why some minor patching can't achieve the goal without requiring a re-write of the models. The code in forward_features isn't that complicated or extensive, so you could patch that method and inject others to achieve the desired goal of separating the embed and block-forward logic into two parts such that masking can be inserted... then you can still use all other aspects of the library and just patch the few vit & related models you might want to use.
Slightly closer to your goal, the NaFlexVit impl separates embeds into their own module, but still iterates over blocks in forward_features/forward_intermediates... one less patching step
https://github.com/huggingface/pytorch-image-models/blob/a7c5368ba0c8713dc1c9a98cc83bf46ddd02b0a0/timm/models/naflexvit.py#L1309-L1338
I fired off a request to Claude to provide an example. The patching aspects looked reasonable (not sure if all of its MIM ideas made sense), but the patching is what I wanted to illustrate as a relatively clean solution to adapt native timm models...
import torch
import torch.nn as nn
from functools import wraps
import types


def patch_vit_for_mim(model):
    """
    Patch a timm vision transformer model to add embed() and forward_stages() methods
    for Masked Image Modeling (MIM) pretraining.

    This works with most timm vision transformer variants including:
    - vit_* (Vision Transformer)
    - deit_* (DeiT)
    - beit_* (BeiT)
    - eva_* (EVA)

    Args:
        model: A timm vision transformer model instance

    Returns:
        The same model with added embed() and forward_stages() methods
    """
    # Store the original forward_features method
    original_forward_features = model.forward_features

    def embed(self, x):
        """Extract embeddings including patch embedding and positional encoding."""
        # Patch embedding
        x = self.patch_embed(x)
        # Add class token if it exists
        if hasattr(self, 'cls_token') and self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        # Add positional embedding
        if hasattr(self, 'pos_embed'):
            x = x + self.pos_embed
        elif hasattr(self, 'pos_embed_fn'):
            x = self.pos_embed_fn(x)
        # Apply pos_drop if it exists
        if hasattr(self, 'pos_drop'):
            x = self.pos_drop(x)
        return x

    def forward_stages(self, x):
        """Forward through transformer blocks and norm."""
        # Apply patch dropout if it exists
        if hasattr(self, 'patch_drop') and self.patch_drop is not None:
            x = self.patch_drop(x)
        # Forward through transformer blocks
        if hasattr(self, 'blocks'):
            x = self.blocks(x)
        elif hasattr(self, 'layers'):
            x = self.layers(x)
        elif hasattr(self, 'encoder'):
            x = self.encoder(x)
        # Apply norm
        if hasattr(self, 'norm'):
            x = self.norm(x)
        elif hasattr(self, 'ln_post'):
            x = self.ln_post(x)
        return x

    def new_forward_features(self, x):
        """Reimplemented forward_features using embed and forward_stages."""
        x = self.embed(x)
        x = self.forward_stages(x)
        return x

    # Bind the methods to the model instance
    model.embed = types.MethodType(embed, model)
    model.forward_stages = types.MethodType(forward_stages, model)
    model.forward_features = types.MethodType(new_forward_features, model)

    return model
def patch_swin_for_mim(model):
    """
    Patch a timm Swin Transformer model for MIM pretraining.

    Swin has a different architecture with hierarchical stages.
    """
    def embed(self, x):
        """Extract initial embeddings for Swin."""
        x = self.patch_embed(x)
        if hasattr(self, 'pos_drop'):
            x = self.pos_drop(x)
        return x

    def forward_stages(self, x):
        """Forward through all Swin stages."""
        # Forward through all layers (stages)
        x = self.layers(x)
        # Apply norm
        x = self.norm(x)
        # Apply avgpool if needed
        if hasattr(self, 'avgpool'):
            x = self.avgpool(x.transpose(1, 2))  # B N C -> B C N
            x = torch.flatten(x, 1)
        return x

    # Bind methods
    model.embed = types.MethodType(embed, model)
    model.forward_stages = types.MethodType(forward_stages, model)

    # Update forward_features to use the new methods
    original_forward_features = model.forward_features

    def new_forward_features(self, x):
        x = self.embed(x)
        x = self.forward_stages(x)
        return x

    model.forward_features = types.MethodType(new_forward_features, model)

    return model
# Unified patching function
def patch_model_for_mim(model):
    """
    Automatically patch any supported timm model for MIM pretraining.

    Args:
        model: A timm model instance

    Returns:
        Patched model with embed() and forward_stages() methods
    """
    model_name = model.__class__.__name__.lower()
    if 'swin' in model_name:
        return patch_swin_for_mim(model)
    else:
        # Default to ViT-style patching for most transformers
        return patch_vit_for_mim(model)
# Example usage for MIM pretraining
class MaskedImageModelingWrapper(nn.Module):
    """
    Example wrapper for MIM pretraining using the patched model.
    """
    def __init__(self, backbone, mask_ratio=0.75, mask_token_init='zero'):
        super().__init__()
        self.backbone = patch_model_for_mim(backbone)
        self.mask_ratio = mask_ratio

        # Get embedding dimension
        if hasattr(backbone, 'embed_dim'):
            embed_dim = backbone.embed_dim
        elif hasattr(backbone, 'num_features'):
            embed_dim = backbone.num_features
        else:
            # Fall back to a common default if it cannot be inferred
            embed_dim = 768

        # Initialize mask token
        if mask_token_init == 'learnable':
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.normal_(self.mask_token, std=0.02)
        else:
            self.register_buffer('mask_token', torch.zeros(1, 1, embed_dim))

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by shuffling.

        Args:
            x: [N, L, D] tensor (batch, length, dim)
            mask_ratio: Float between 0 and 1

        Returns:
            x_masked: masked tokens
            mask: binary mask (0 is kept, 1 is masked)
            ids_restore: indices to restore original order
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        # Generate random noise
        noise = torch.rand(N, L, device=x.device)

        # Sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Generate binary mask: 0 is kept, 1 is masked
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x, return_mask=True):
        # Get embeddings
        x_embed = self.backbone.embed(x)

        # Apply masking (skip CLS token if present)
        if hasattr(self.backbone, 'cls_token') and self.backbone.cls_token is not None:
            cls_tokens = x_embed[:, :1, :]
            x_patches = x_embed[:, 1:, :]
        else:
            cls_tokens = None
            x_patches = x_embed

        # Random masking
        x_masked, mask, ids_restore = self.random_masking(x_patches, self.mask_ratio)

        # Add mask tokens for masked positions
        N, L_kept, D = x_masked.shape
        L_original = x_patches.shape[1]
        mask_tokens = self.mask_token.repeat(N, L_original - L_kept, 1)

        # Combine kept and mask tokens, then restore the original order
        x_full = torch.cat([x_masked, mask_tokens], dim=1)
        x_full = torch.gather(x_full, dim=1,
                              index=ids_restore.unsqueeze(-1).repeat(1, 1, D))

        # Add back CLS token if needed
        if cls_tokens is not None:
            x_full = torch.cat([cls_tokens, x_full], dim=1)

        # Forward through transformer stages
        x_output = self.backbone.forward_stages(x_full)

        if return_mask:
            return x_output, mask, ids_restore
        return x_output
# Example usage
if __name__ == "__main__":
    import timm

    # Load a pretrained ViT model
    vit = timm.create_model('vit_base_patch16_224', pretrained=True)

    # Create MIM wrapper
    mim_model = MaskedImageModelingWrapper(
        vit,
        mask_ratio=0.75,
        mask_token_init='learnable'
    )

    # Example forward pass
    dummy_input = torch.randn(2, 3, 224, 224)
    output, mask, ids_restore = mim_model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Masked patches: {(mask == 1).sum(1).float().mean():.0f} out of {mask.shape[1]}")

    # You can also use the patched methods directly
    patched_vit = patch_model_for_mim(vit)
    embeddings = patched_vit.embed(dummy_input)
    print(f"\nEmbeddings shape: {embeddings.shape}")

    # Apply your own masking logic here
    # ... custom masking ...

    # Then forward through stages
    output = patched_vit.forward_stages(embeddings)
    print(f"Stages output shape: {output.shape}")
All that said, I'm open to adding additional API to cover common use cases (as long as it doesn't break compat with the existing API)... but I'd need to see some traction and universal interest/use cases before doing something with project-wide scope like that. So the patching is a good middle ground; possibly add a patch helper and see if there is other interest before changing significant models...
I have no intention of removing forward_features.
The changes I want to make, which involve refactoring forward_features into smaller, more focused methods like embed and forward_stages, should not cause any breaking changes. forward_features will continue to function exactly as before, simply by calling these new, smaller methods internally:
def embed(self, x): ...

def forward_stages(self, x): ...

def forward_features(self, x):
    x = self.embed(x)
    x = self.forward_stages(x)
    return x
You're right that a wrapper class can address this, and that's precisely what I've done in my own research. I've been examining each class that needs modification, separating the embed and other processing steps, and then wrapping them with an nn.Module to manually call the sub-modules' layers.
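For reference, a condensed sketch of such a wrapper for a plain timm VisionTransformer (attribute names are assumptions based on the current timm implementation and may need adjusting per model family):

import torch.nn as nn

class SplitViT(nn.Module):
    """Illustrative wrapper exposing embed()/forward_stages() around a timm ViT."""

    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def embed(self, x):
        x = self.vit.patch_embed(x)   # patchify + projection
        x = self.vit._pos_embed(x)    # cls/reg tokens + positional encoding
        return self.vit.norm_pre(x)

    def forward_stages(self, x):
        x = self.vit.blocks(x)        # transformer encoder layers
        return self.vit.norm(x)

    def forward_features(self, x):
        return self.forward_stages(self.embed(x))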