
Update the projection head (normalization and size).

mattersoflight opened this issue

TL;DR: the current projection head doesn't do what it is supposed to do. It should include batch norm, and the size of the features and projections may be reduced further. In previous implementations of contrastive learning models (dynacontrast), we used batch norm in the projection head after each MLP layer:

```
(projection): Sequential(
    (fc1): Linear(in_features=2048, out_features=2048, bias=False)
    (bn1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=2048, out_features=128, bias=False)
    (bn2): BatchNorm1dNoBias(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
```
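For reference, a minimal PyTorch sketch of a batch-normalized, non-linear projection head along these lines (the dimensions and the trick of freezing the final norm's bias, mirroring `BatchNorm1dNoBias` above, are assumptions, not the final design):

```python
import torch.nn as nn


def make_projection_head(in_dim: int = 2048, hidden_dim: int = 2048, out_dim: int = 128) -> nn.Sequential:
    """Sketch of a SimCLR-style projection head: Linear -> BN -> ReLU -> Linear -> BN."""
    final_bn = nn.BatchNorm1d(out_dim)
    # Some SimCLR implementations disable the learnable shift of the last norm
    # (the "BatchNorm1dNoBias" in the dynacontrast snippet above); freezing the
    # bias keeps it at its zero initialization.
    final_bn.bias.requires_grad = False
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim, bias=False),
        final_bn,
    )
```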

This paper also recommends using a non-linear projection head with batch norm. Different projection heads are evaluated by comparing the rank of the features (the number of independent features) before and after the projection head.

As expected, rank(projections) << rank(features).

Our current model's behavior is the opposite: rank(projections) > rank(features), as seen from examining the principal components of each.

[figure: principal components of the features and the projections]

This appears to be a consequence of the projections being clipped, likely because ReLU is used without normalization.

```python
plt.plot(np.mean(embedding_dataset["projections"].values, axis=1))
plt.plot(np.std(embedding_dataset["projections"].values, axis=1))
```

[figures: per-sample mean and standard deviation of the projections]
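To make the rank comparison concrete, here is a rough sketch of estimating the effective rank of each space from its singular values. The `"features"` key and the `(n_samples, n_dimensions)` layout of `embedding_dataset` are assumptions based on the plots above:

```python
import numpy as np


def effective_rank(x: np.ndarray, tol: float = 1e-3) -> int:
    """Count singular values above a small fraction of the largest one."""
    # Center the samples so the singular values reflect variance directions.
    x = x - x.mean(axis=0, keepdims=True)
    s = np.linalg.svd(x, compute_uv=False)
    return int(np.sum(s > tol * s[0]))


# Assumed layout: (n_samples, n_dimensions); "features" is a hypothetical key.
print("rank(features):   ", effective_rank(embedding_dataset["features"].values))
print("rank(projections):", effective_rank(embedding_dataset["projections"].values))
```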

mattersoflight (Aug 17 '24)

@mattersoflight In SimCLR and others using InfoNCE-style losses, there is an implicit L2-normalization of $z$ happening in the loss function, since they use cosine similarity as the distance function. The triplet margin loss uses L2 distance (which, for unit vectors, is fully determined by the cosine similarity and vice versa). I still need to find a reference implementation, but the planned removal of L2 normalization might not be needed.
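For unit vectors, the squared Euclidean distance and the cosine similarity determine each other:

$$\lVert z_1 - z_2 \rVert^2 = \lVert z_1 \rVert^2 + \lVert z_2 \rVert^2 - 2\, z_1^\top z_2 = 2 - 2\cos(z_1, z_2) \quad \text{when } \lVert z_1 \rVert = \lVert z_2 \rVert = 1.$$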

ziw-liu (Aug 19 '24)

@ziw-liu L2 normalization is indeed equivalent to converting feature vectors into unit vectors. It also makes sense that the loss is computed between unit vectors (whether via cosine similarity or Euclidean distance), given the SimCLR paper.

I agree that the L2 normalization of projections doesn't need to be removed. There is NO NEED to implement it as an argument.

Note that [TripletMarginLoss](https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html) provided by `torch.nn` doesn't normalize its inputs by default.
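A minimal sketch of keeping the explicit normalization before the loss; the tensor names and shapes here are illustrative, not VisCy's API:

```python
import torch
import torch.nn.functional as F

triplet_loss = torch.nn.TripletMarginLoss()

# Illustrative anchor/positive/negative projections of shape (batch, 128).
anchor, positive, negative = (torch.randn(16, 128) for _ in range(3))

# L2-normalize so the L2 distance inside the loss is equivalent to cosine distance.
anchor, positive, negative = (F.normalize(x, dim=1) for x in (anchor, positive, negative))

loss = triplet_loss(anchor, positive, negative)
```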

mattersoflight (Aug 19 '24)

@ziw-liu the paper that compares different types of projection heads uses the InfoNCE loss (which reads the same as the NT-Xent loss in SimCLR). It may be that using NT-Xent promotes higher-rank embeddings.
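For reference, a minimal sketch of NT-Xent over a batch of positive pairs; the temperature value and tensor shapes are illustrative:

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent (SimCLR) loss for a batch of positive pairs z1[i] <-> z2[i]."""
    # Cosine similarity via dot products of L2-normalized projections.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, D)
    sim = z @ z.T / temperature                          # (2N, 2N)
    # Exclude self-similarity on the diagonal.
    mask = torch.eye(sim.shape[0], dtype=torch.bool, device=sim.device)
    sim = sim.masked_fill(mask, float("-inf"))
    n = z1.shape[0]
    # The positive of row i is its other augmented view: i + N (or i - N).
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(sim.device)
    return F.cross_entropy(sim, targets)


# Illustrative shapes: 32 pairs of 128-dimensional projections.
loss = nt_xent(torch.randn(32, 128), torch.randn(32, 128))
```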

mattersoflight (Sep 03 '24)

The model changes were implemented in #145, but #154 can potentially fix the low-rank feature map.

ziw-liu (Sep 09 '24)