pytorch-image-models
Naflex performance
Hi!
Thank you for the excellent library and the recent addition of NaFlex transformers. I've been learning a lot from exploring it! While experimenting with the code, I identified a few points I'd like to discuss regarding the two key components of NaFlex compared to standard transformers: 1) on-the-fly patch embedding interpolation and 2) on-the-fly positional embedding interpolation.
Patch Embeddings
The interpolation of patch embeddings uses a somewhat more involved method that computes a pseudo-inverse matrix. Since it operates on a single weight tensor of shape (C_out, C_in, P_h, P_w) (all patches within a batch share the same size), its performance impact is relatively minor. However, I noticed a small optimization that could improve efficiency. The current implementation uses torch.vmap to vectorize calculations across channels:
def _apply_resampling(
        patch_embed: torch.Tensor,
        pinv_matrix: torch.Tensor,
        new_size_tuple: Tuple[int, int],
        orig_dtype: torch.dtype,
        intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
) -> torch.Tensor:
    """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
    try:
        from torch import vmap
    except ImportError:
        from functorch import vmap

    def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
        kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
        resampled_kernel_flat = pinv_matrix @ kernel_flat
        return resampled_kernel_flat.reshape(new_size_tuple)

    resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
    patch_embed_float = patch_embed.to(intermediate_dtype)
    resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
    return resampled_patch_embed.to(orig_dtype)
I was unfamiliar with vmap (which seems more common in JAX-style workflows), and its use here was initially unclear. While it performs well, the operation is essentially a batched matrix multiplication, which can be simplified using the standard @ operator. This alternative approach reduces code complexity and slightly improves performance:
def _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype, intermediate_dtype):
    c_out, c_in, *_ = patch_embed.shape
    patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
    pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
    # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
    resampled_patch_embed = patch_embed @ pinv_matrix.T
    resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
    return resampled_patch_embed
In my tests, resampling a (768, 3, 32, 32) embedding to (768, 3, 16, 16) took 18 ms (forward) and 26 ms (forward + backward) with the vmap implementation. The batched matrix multiplication approach reduced this to 11 ms (forward) and 13 ms (forward + backward), roughly doubling the speed for the forward + backward pass while using less code.
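To sanity-check the equivalence claimed above, here is a minimal sketch that compares the two formulations on small tensors. The function names `resample_vmap` and `resample_matmul` are made up for this illustration, the dtype casts are omitted, and a random matrix stands in for the real pseudo-inverse, so it only checks the algebra, not timm's actual resampling.

```python
import torch

def resample_vmap(patch_embed, pinv_matrix, new_size):
    # Per-kernel formulation: flatten one (P_h, P_w) kernel and multiply.
    def kernel_fn(kernel):
        return (pinv_matrix @ kernel.reshape(-1)).reshape(new_size)
    # Two nested vmaps vectorize over C_out and C_in.
    return torch.vmap(torch.vmap(kernel_fn))(patch_embed)

def resample_matmul(patch_embed, pinv_matrix, new_size):
    # Same computation as one batched matrix multiplication.
    c_out, c_in = patch_embed.shape[:2]
    flat = patch_embed.reshape(c_out, c_in, -1)
    return (flat @ pinv_matrix.T).reshape(c_out, c_in, *new_size)

torch.manual_seed(0)
old_size, new_size = (8, 8), (4, 4)
patch_embed = torch.randn(6, 3, *old_size)
# Random stand-in for the pseudo-inverse, shape (P_new * P_new, P_old * P_old).
pinv_matrix = torch.randn(new_size[0] * new_size[1], old_size[0] * old_size[1])

out_vmap = resample_vmap(patch_embed, pinv_matrix, new_size)
out_matmul = resample_matmul(patch_embed, pinv_matrix, new_size)
print((out_vmap - out_matmul).abs().max().item())
```

The printed difference should be at the level of float32 rounding noise.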
Positional Embeddings
Unlike patch embeddings, where all images in a batch share the same patch size, positional embeddings must account for varying aspect ratios and dimensions across images in a batch. This requires calling F.interpolate individually for each sample, currently implemented in a Python for-loop, which significantly impacts performance. While there's an optimization to group similarly sized images and interpolate once per group, this only helps when many images share the same size. In the original JAX implementation, vectorization with vmap improved efficiency, but PyTorch's F.interpolate is incompatible with torch.vmap, which appears to be a current PyTorch limitation.
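For readers unfamiliar with the grouping optimization, the idea can be sketched roughly as follows. This is an illustration under assumptions, not timm's actual code: the helper name `interpolate_pos_embed_grouped` is hypothetical, and the positional embedding is assumed to be stored as a (1, C, H, W) tensor.

```python
from collections import defaultdict

import torch
import torch.nn.functional as F

def interpolate_pos_embed_grouped(pos_embed, sizes):
    """Resize pos_embed (1, C, H, W) once per distinct (h, w) in sizes."""
    groups = defaultdict(list)
    for idx, size in enumerate(sizes):
        groups[size].append(idx)

    out = [None] * len(sizes)
    for (h, w), idxs in groups.items():
        # One F.interpolate call per unique size instead of one per sample.
        resized = F.interpolate(
            pos_embed, size=(h, w), mode='bilinear',
            align_corners=False, antialias=True,
        )
        tokens = resized.flatten(2).transpose(1, 2)[0]  # (h * w, C)
        for i in idxs:
            out[i] = tokens
    return out, len(groups)

torch.manual_seed(0)
pos_embed = torch.randn(1, 16, 32, 32)
sizes = [(4, 6), (4, 6), (8, 8), (3, 5)]
embeds, num_groups = interpolate_pos_embed_grouped(pos_embed, sizes)
print(num_groups, [tuple(e.shape) for e in embeds])
```

With four samples but only three distinct sizes, the loop runs three times. When nearly every sample has its own size, the number of iterations approaches the batch size, which is the slow case described above.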
I tested performance with the following code:
import einx
import torch
from timm.models import NaFlexEmbeds
from torch.nn.utils.rnn import pad_sequence

def get_coords(h, w):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    coord = einx.rearrange('i j, i j -> (i j) (1 + 1)', y, x)
    return coord

# Generate a batch of 512 images with random heights and widths
# (4 to 31 patches per side; torch.randint's upper bound is exclusive)
device = 'cuda'
B = 512
sizes = torch.randint(4, 32, (B, 2))
max_seq_len = sizes.prod(-1).amax()
coords = pad_sequence([get_coords(h, w) for h, w in sizes], batch_first=True).to(device=device)
x = torch.randn(B, coords.shape[1], 16 * 16 * 3).to(device=device)
patch_emb = NaFlexEmbeds(embed_dim=768, proj_type='linear', pos_embed_grid_size=(32, 32)).cuda()

# Run in a separate Jupyter cell (%%time is a cell magic)
%%time
patch_emb(x, coords).sum().backward()
torch.cuda.synchronize()
| Batch Size | Forward (ms) | Forward + Backward (ms) |
|---|---|---|
| 256 | 240 | 1940 |
| 512 | 370 | 7230 |
While the forward pass is reasonably fast, the backward pass is notably slow and scales super-linearly with batch size for some reason.
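Outside a notebook, a measurement like the one above could be reproduced with a small helper along these lines. This is a generic sketch, not the script used for the table: `time_ms` is a hypothetical helper, and the tiny linear layer is only a placeholder workload.

```python
import time

import torch

def time_ms(fn, warmup=2, iters=5):
    """Average wall-clock time of fn() in milliseconds."""
    # CUDA kernels run asynchronously, so synchronize before reading the clock.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) * 1000 / iters

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(64, 64).to(device)
inputs = torch.randn(32, 64, device=device)

def forward_backward():
    model.zero_grad()
    model(inputs).sum().backward()

elapsed_ms = time_ms(forward_backward)
print(f'forward + backward: {elapsed_ms:.3f} ms')
```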
One idea is to use F.grid_sample instead of F.interpolate: it supports resampling different images in a batch to different shapes in a vectorized manner. The results will differ slightly, though, since grid_sample has no equivalent of antialias=True.
An example implementation using it might look like this:
import einx
import torch
import torch.nn.functional as F  # needed for affine_grid / grid_sample

def apply_naflex_pos_emb_grid_sample(pos_emb, x, coords, mode='bilinear', align_corners=False, padding_mode='zeros'):
    B = x.shape[0]
    C = pos_emb.shape[-1]
    # Broadcast the learned (h, w, c) embedding grid to one copy per sample
    pos_emb = einx.rearrange('h w c -> b c h w', pos_emb, b=B)
    shapes = coords.amax(1) + 1  # per-sample (h, w) in patches
    grid_size = shapes.amax(0)   # largest (h, w) in the batch
    # Per-sample affine transform mapping the top-left (h, w) region of the
    # output grid onto the full embedding grid
    theta = torch.zeros(B, 2, 3, dtype=pos_emb.dtype, device=pos_emb.device)
    theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
    theta[:, 0, 2] = theta[:, 0, 0] - 1           # translate x
    theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
    theta[:, 1, 2] = theta[:, 1, 1] - 1           # translate y
    grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=align_corners)
    pos_emb = F.grid_sample(pos_emb, grid, mode=mode, align_corners=align_corners, padding_mode=padding_mode)
    # Gather the embedding for each token from its (y, x) patch coordinate
    bi = einx.rearrange('b -> b n', torch.arange(B, device=pos_emb.device), n=coords.shape[1])
    x = x + pos_emb[bi, :, coords[..., 0], coords[..., 1]]
    return x
This approach eliminates the Python for-loop, reducing the forward + backward pass for a batch size of 512 to just 150 ms. I haven't tested its impact on training performance compared to F.interpolate with antialias=True vs antialias=False, but for bilinear interpolation, F.grid_sample produces numerically similar results to F.interpolate with antialias=False.
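The numerical-similarity claim can be checked on a small example. The sketch below reuses the theta construction from the function above on a two-sample batch and compares the top-left (h, w) crop of each grid_sample output against F.interpolate with antialias=False. The tensor sizes are arbitrary choices for illustration, and the target shapes are kept smaller than the source grid so that zero padding never enters the interpolation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 8
pos_emb = torch.randn(1, C, 16, 16)         # learned grid, (1, C, H, W)
shapes = torch.tensor([[5, 7], [8, 3]])     # per-sample (h, w)
B = shapes.shape[0]
H_max, W_max = shapes.amax(0).tolist()

# Same scale/translate scheme as in the function above.
theta = torch.zeros(B, 2, 3)
theta[:, 0, 0] = W_max / shapes[:, 1]
theta[:, 0, 2] = theta[:, 0, 0] - 1
theta[:, 1, 1] = H_max / shapes[:, 0]
theta[:, 1, 2] = theta[:, 1, 1] - 1

grid = F.affine_grid(theta, (B, C, H_max, W_max), align_corners=False)
sampled = F.grid_sample(
    pos_emb.expand(B, -1, -1, -1), grid,
    mode='bilinear', padding_mode='zeros', align_corners=False,
)

max_diffs = []
for b, (h, w) in enumerate(shapes.tolist()):
    ref = F.interpolate(pos_emb, size=(h, w), mode='bilinear',
                        align_corners=False, antialias=False)
    max_diffs.append((sampled[b, :, :h, :w] - ref[0]).abs().max().item())
print(max_diffs)
```

Positions outside each sample's (h, w) crop fall outside the source grid and are filled by the padding mode; they are never gathered because the patch coordinates only index the valid region.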
@stas-sl thanks, I'll try taking out the vmap in patch interpolation with your approach and replace if it matches.
For the position embeddings, you are correct, this is not particularly fast, and there isn't a great approach to this with what PyTorch has given us. I did investigate using grid sample, but the results are sufficiently different to make it unacceptable (in my opinion) for use with the SigLIP-2 encoder weights (which was a benchmark during dev)... I don't know why they never added anti-aliasing and aligned it more closely with the output of interpolate...
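Part of the gap comes from antialiasing: grid_sample has no antialias option, so when downsampling it can at best match F.interpolate with antialias=False. The rough sketch below shows how far apart the two interpolate settings are. It uses a random tensor in place of real positional embeddings, so the magnitude is illustrative only and says nothing about the effect on SigLIP-2 accuracy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pos_emb = torch.randn(1, 8, 32, 32)  # random stand-in for a learned grid

# Downsample 32x32 -> 8x8 with and without antialiasing.
aa = F.interpolate(pos_emb, size=(8, 8), mode='bilinear',
                   align_corners=False, antialias=True)
no_aa = F.interpolate(pos_emb, size=(8, 8), mode='bilinear',
                      align_corners=False, antialias=False)

down_diff = (aa - no_aa).abs().max().item()
print(f'max abs difference when downsampling: {down_diff:.4f}')
```

Random noise is close to a worst case because it has no spatial smoothness; learned positional embeddings are usually smoother, so the real gap may be smaller.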
However, given that it is a noteworthy performance difference, I could add it as an option for training from scratch or a compromise when fine-tuning.
patch embed is updated and tested in #2518 ... haven't had a chance to work on the pos embed yet
@stas-sl I've got an impl of both in the PR now. The patch embed resizing is a win on simplicity and performance. In practical terms, whether grid_sample helps would appear to come down to the variation in your dataset. With basic ImageNet trials (which isn't extremely diverse in aspect ratio), the original impl is faster for training, as there ends up being quite a bit of re-use of the cached interpolations.
Hi @rwightman, did you see my comment on the PR? Maybe I'm missing something, but I believe there's no need to construct coords inside _apply_learned_naflex_pos_embed_grid_sample, since you could directly use patch_coord from the dataloader - which should be a bit more efficient.
That said, I completely understand that your original approach might be faster if the dataset doesn’t have much variability in image shapes. In my case, however, I’m working with a fairly diverse set of image shapes and aspect ratios, especially with max_ratio=1 (which is actually why I’m exploring NaflexViT). In this setup, the grid_sample approach seems to be faster.
@stas-sl hah, right that was a bit silly, I just focused on the fn w/ set interface but was no need to do that
Okay, fixed yesterday's quick tunnel-vision impl, altered the interface, and now leverage patch_coords as I should have
Looks like I also messed up yesterday during my PR review - I added comments on some lines and assumed that was enough, but I just realized they probably weren’t visible to you since I didn’t press "Start review"... my bad 😳
I make that exact mistake all the time, I blame GitHub ;)