
[FEATURE] Unified `embed` Interfaces for Vision Transformer Models for MIM Pretrain Research

Open ryan-minato opened this issue 5 months ago • 4 comments

This feature request is related to challenges in Masked Image Modeling (MIM) pre-training using vision transformer models in timm. Currently, embedding and feature extraction are tightly coupled within forward_features, making it difficult to inject mask operations after initial embedding and positional encoding but before the transformer stages, which is a common MIM requirement. Researchers need to access embedded tokens for masking before passing them through subsequent transformer layers.

Describe the solution you'd like

I propose refactoring all vision transformer models (e.g., vit, swin_transformer, etc.) to uniformly expose two distinct interfaces:

  • embed(self, x): This method should take the input tensor x (e.g., image patches) and return the embedded vectors, including positional encodings if applicable. The output should be ready for the transformer encoder stages.
  • forward_stages(self, x): This method should take the output from embed(x) (i.e., the embedded and position-encoded tokens) and pass them through the transformer encoder layers.

This separation would make the existing forward_features(x) effectively equivalent to forward_stages(embed(x)). This allows researchers to easily perform mask operations on the embedded tokens returned by embed(x) before passing them to forward_stages(x), enabling flexible MIM pre-training experiments.
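
For illustration, intended usage would look roughly like the sketch below (embed and forward_stages are the proposed methods and do not exist in timm today; my_masking_fn stands in for whatever masking strategy a researcher wants to apply):

import timm
import torch

# Sketch of the proposed API; `my_masking_fn` is a placeholder.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
images = torch.randn(2, 3, 224, 224)

tokens = model.embed(images)              # patch embed + cls token + pos embed
tokens = my_masking_fn(tokens)            # mask/replace embedded tokens here
features = model.forward_stages(tokens)   # transformer blocks + final norm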

Describe alternatives you've considered

I have considered alternative solutions, such as adding a mask parameter directly to the forward_features and forward methods, similar to vision_transformer's implementation.

https://github.com/huggingface/pytorch-image-models/blob/a7c5368ba0c8713dc1c9a98cc83bf46ddd02b0a0/timm/models/vision_transformer.py#L933-L935

While seemingly straightforward, this approach presents drawbacks:

  • Function Signature: It still alters the method's interface, even if default parameter values mitigate direct breakage.
  • Mask Value Flexibility: More importantly, it limits how masked-out positions can be handled, restricting whether masked token values are learned (e.g., a learnable mask token) or simply set to zero. Separating embed and forward_stages gives the caller full control, as in the sketch below.
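
As a concrete illustration of the second point, the separated interface lets the caller decide what goes into the masked positions (sketch only, reusing model/images from the snippet above and assuming a boolean mask of shape [B, N]):

import torch
import torch.nn as nn

# Learnable mask token option (proposed API; `mask` is an assumed [B, N] bool tensor).
mask_token = nn.Parameter(torch.zeros(1, 1, model.embed_dim))

tokens = model.embed(images)                                    # [B, N, D]
tokens = torch.where(mask.unsqueeze(-1), mask_token.expand_as(tokens), tokens)
# ...or simply zero out the masked positions instead:
# tokens = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
features = model.forward_stages(tokens)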

Additional context

If this refactoring aligns with the library's design, I would gladly contribute a Pull Request to implement these changes across relevant vision transformer models.

ryan-minato · Jul 09 '25 05:07

@ryan-minato it's not possible to remove forward_features as an API at this point.

For research, this is Python and very flexible. I don't see why some minor patching can't achieve the goal without requiring a rewrite of the models. The code in forward_features isn't that complicated or extensive, so you could patch that method and inject a couple of others to split the embed and block-forward logic into two parts so that masking can be inserted... then you can still use all other aspects of the library and only patch the few vit & related models you want to use.

Slightly closer to your goal, the NaFlexVit impl separates the embeds into their own module, but still iterates over the blocks in forward_features/forward_intermediates... so that's one less patching step.

https://github.com/huggingface/pytorch-image-models/blob/a7c5368ba0c8713dc1c9a98cc83bf46ddd02b0a0/timm/models/naflexvit.py#L1309-L1338

rwightman · Jul 09 '25 14:07

I fired off the request to Claude to provide an example. The patching aspects look reasonable; I'm not sure all of its MIM ideas make sense, but the patching is what I wanted to illustrate as a relatively clean way to adapt native timm models...

import torch
import torch.nn as nn
import types


def patch_vit_for_mim(model):
    """
    Patch a timm vision transformer model to add embed() and forward_stages() methods
    for Masked Image Modeling (MIM) pretraining.
    
    This works with most timm vision transformer variants including:
    - vit_* (Vision Transformer)
    - deit_* (DeiT)
    - beit_* (BeiT)
    - eva_* (EVA)
    
    Args:
        model: A timm vision transformer model instance
    
    Returns:
        The same model with added embed() and forward_stages() methods
    """
    
    def embed(self, x):
        """Extract embeddings including patch embedding and positional encoding."""
        # Patch embedding
        x = self.patch_embed(x)
        
        # Add class token if it exists
        if hasattr(self, 'cls_token') and self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        
        # Add positional embedding
        if getattr(self, 'pos_embed', None) is not None:
            x = x + self.pos_embed
        elif hasattr(self, 'pos_embed_fn'):
            x = self.pos_embed_fn(x)
        
        # Apply pos_drop if it exists
        if hasattr(self, 'pos_drop'):
            x = self.pos_drop(x)
            
        return x
    
    def forward_stages(self, x):
        """Forward through transformer blocks and norm."""
        # Apply patch dropout if it exists
        if hasattr(self, 'patch_drop') and self.patch_drop is not None:
            x = self.patch_drop(x)

        # Apply pre-norm if present (nn.Identity by default in timm ViT)
        if hasattr(self, 'norm_pre'):
            x = self.norm_pre(x)
        
        # Forward through transformer blocks
        if hasattr(self, 'blocks'):
            x = self.blocks(x)
        elif hasattr(self, 'layers'):
            x = self.layers(x)
        elif hasattr(self, 'encoder'):
            x = self.encoder(x)
        
        # Apply norm
        if hasattr(self, 'norm'):
            x = self.norm(x)
        elif hasattr(self, 'ln_post'):
            x = self.ln_post(x)
            
        return x
    
    def new_forward_features(self, x):
        """Reimplemented forward_features using embed and forward_stages."""
        x = self.embed(x)
        x = self.forward_stages(x)
        return x
    
    # Bind the methods to the model instance
    model.embed = types.MethodType(embed, model)
    model.forward_stages = types.MethodType(forward_stages, model)
    model.forward_features = types.MethodType(new_forward_features, model)
    
    return model


def patch_swin_for_mim(model):
    """
    Patch a timm Swin Transformer model for MIM pretraining.
    
    Swin has a different architecture with hierarchical stages.
    """
    
    def embed(self, x):
        """Extract initial embeddings for Swin."""
        x = self.patch_embed(x)
        if hasattr(self, 'pos_drop'):
            x = self.pos_drop(x)
        return x
    
    def forward_stages(self, x):
        """Forward through all Swin stages."""
        # Forward through all layers (stages)
        x = self.layers(x)
        
        # Apply norm
        x = self.norm(x)
        
        # Apply avgpool if needed
        if hasattr(self, 'avgpool'):
            x = self.avgpool(x.transpose(1, 2))  # B N C -> B C N
            x = torch.flatten(x, 1)
        
        return x
    
    # Bind methods
    model.embed = types.MethodType(embed, model)
    model.forward_stages = types.MethodType(forward_stages, model)
    
    # Update forward_features to use the new methods
    def new_forward_features(self, x):
        x = self.embed(x)
        x = self.forward_stages(x)
        return x
    
    model.forward_features = types.MethodType(new_forward_features, model)
    
    return model


# Unified patching function
def patch_model_for_mim(model):
    """
    Automatically patch any supported timm model for MIM pretraining.
    
    Args:
        model: A timm model instance
        
    Returns:
        Patched model with embed() and forward_stages() methods
    """
    model_name = model.__class__.__name__.lower()
    
    if 'swin' in model_name:
        return patch_swin_for_mim(model)
    else:
        # Default to ViT-style patching for most transformers
        return patch_vit_for_mim(model)


# Example usage for MIM pretraining
class MaskedImageModelingWrapper(nn.Module):
    """
    Example wrapper for MIM pretraining using the patched model.
    """
    
    def __init__(self, backbone, mask_ratio=0.75, mask_token_init='zero'):
        super().__init__()
        self.backbone = patch_model_for_mim(backbone)
        self.mask_ratio = mask_ratio
        
        # Get embedding dimension
        if hasattr(backbone, 'embed_dim'):
            embed_dim = backbone.embed_dim
        elif hasattr(backbone, 'num_features'):
            embed_dim = backbone.num_features
        else:
            # Try to infer from first block
            embed_dim = 768  # default
        
        # Initialize mask token
        if mask_token_init == 'learnable':
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.normal_(self.mask_token, std=0.02)
        else:
            self.register_buffer('mask_token', torch.zeros(1, 1, embed_dim))
    
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by shuffling.
        
        Args:
            x: [N, L, D] tensor (batch, length, dim)
            mask_ratio: Float between 0 and 1
            
        Returns:
            x_masked: the kept (visible) tokens
            mask: binary mask (0 is kept, 1 is masked), in the original token order
            ids_restore: indices to restore original order
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))
        
        # Generate random noise
        noise = torch.rand(N, L, device=x.device)
        
        # Sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # Keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        
        # Generate binary mask: 0 is kept, 1 is masked (MAE convention)
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward(self, x, return_mask=True):
        # Get embeddings
        x_embed = self.backbone.embed(x)
        
        # Apply masking (skip CLS token if present)
        if hasattr(self.backbone, 'cls_token') and self.backbone.cls_token is not None:
            cls_tokens = x_embed[:, :1, :]
            x_patches = x_embed[:, 1:, :]
        else:
            cls_tokens = None
            x_patches = x_embed
        
        # Random masking
        x_masked, mask, ids_restore = self.random_masking(x_patches, self.mask_ratio)
        
        # Add mask tokens for masked positions
        N, L_kept, D = x_masked.shape
        L_original = x_patches.shape[1]
        mask_tokens = self.mask_token.repeat(N, L_original - L_kept, 1)
        
        # Combine kept and mask tokens
        x_full = torch.cat([x_masked, mask_tokens], dim=1)
        x_full = torch.gather(x_full, dim=1, 
                            index=ids_restore.unsqueeze(-1).repeat(1, 1, D))
        
        # Add back CLS token if needed
        if cls_tokens is not None:
            x_full = torch.cat([cls_tokens, x_full], dim=1)
        
        # Forward through transformer stages
        x_output = self.backbone.forward_stages(x_full)
        
        if return_mask:
            return x_output, mask, ids_restore
        return x_output


# Example usage
if __name__ == "__main__":
    import timm
    
    # Load a pretrained ViT model
    vit = timm.create_model('vit_base_patch16_224', pretrained=True)
    
    # Create MIM wrapper
    mim_model = MaskedImageModelingWrapper(
        vit, 
        mask_ratio=0.75,
        mask_token_init='learnable'
    )
    
    # Example forward pass
    dummy_input = torch.randn(2, 3, 224, 224)
    output, mask, ids_restore = mim_model(dummy_input)
    
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Masked patches: {(mask == 0).sum(1).float().mean():.0f} out of {mask.shape[1]}")
    
    # You can also use the patched methods directly
    patched_vit = patch_model_for_mim(vit)
    embeddings = patched_vit.embed(dummy_input)
    print(f"\nEmbeddings shape: {embeddings.shape}")
    
    # Apply your own masking logic here
    # ... custom masking ...
    
    # Then forward through stages
    output = patched_vit.forward_stages(embeddings)
    print(f"Stages output shape: {output.shape}")

rwightman · Jul 09 '25 15:07

All that said, I'm open to adding additional API to cover common use cases (as long as it doesn't break compat with the existing API)... but I'd need to see some traction and fairly universal interest/use cases before doing something with project-wide scope like that. So the patching is a good middle ground; possibly add a patch helper and see if there is other interest before changing a significant number of models...

rwightman · Jul 09 '25 16:07

I have no intention of removing forward_features.

The changes I want to make, which involve refactoring forward_features into smaller, more focused methods like embed and forward_stages, should not cause any breaking changes. forward_features will continue to function exactly as before, simply by calling these new, smaller methods internally:

def embed(self, x): ...
def forward_stages(self, x): ...
def forward_features(self, x):
    x = self.embed(x)
    x = self.forward_stages(x)
    return x

You're right that a wrapper class can address this, and that's precisely what I've done in my own research: I've been going through each class that needs modification, separating the embed step from the rest of the processing, and wrapping the model in an nn.Module that calls its sub-modules' layers manually.
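
A minimal version of that wrapper looks roughly like this (a simplified sketch, not my exact code; it assumes a plain timm ViT exposing patch_embed, cls_token, pos_embed, pos_drop, blocks, and norm, and the SplitViT name is just illustrative):

import torch
import torch.nn as nn


class SplitViT(nn.Module):
    """Sketch: wrap a timm ViT and call its sub-modules directly."""

    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def embed(self, x):
        x = self.vit.patch_embed(x)
        if self.vit.cls_token is not None:
            x = torch.cat([self.vit.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
        if self.vit.pos_embed is not None:
            x = x + self.vit.pos_embed
        return self.vit.pos_drop(x)

    def forward_stages(self, x):
        x = self.vit.blocks(x)
        return self.vit.norm(x)

    def forward(self, x):
        # Equivalent to the original forward_features
        return self.forward_stages(self.embed(x))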

ryan-minato · Jul 10 '25 00:07