simclr

Is the trained projection head available?

Open • lkshrsch opened this issue 2 years ago • 7 comments

I am interested in downloading a pre-trained SimCLR model with the projection head, so that I can retrieve the latent features z on which the contrastive loss was applied. Are this layer and its pre-trained weights available somewhere?

lkshrsch, Mar 30 '22 19:03

Yes, the projection head weights should also be included in gs://simclr-checkpoints/simclrv2/pretrained/...

chentingpc, Apr 02 '22 23:04

In the GitHub README, that link is listed under the description

"Pretrained SimCLRv2 models (with linear eval head):"

I assumed "with linear eval head" refers to the ImageNet classification layer. However, when I download the model r50_1x_sk0 from

https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/pretrained/r50_1x_sk0?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false

the model output has dimension 2048, which could be either the output of the ResNet-50 or the output of the projection head.

So, to confirm: are these the projection-head features z = g(h) (SimCLR paper, Figure 2), the ResNet-50 features h = f(x) (SimCLR paper, Figure 2), or the output of the linear evaluation head for classification (as described in the GitHub README), which should be logits of dimension 1000 for ImageNet?
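For reference, here are the three candidates written in the SimCLR paper's notation. Several details are assumptions rather than facts confirmed for this checkpoint. The projection head is shown in the v1 paper form, where σ is a ReLU; SimCLRv2 uses a deeper head. The 128-d output width is the paper's default. W_sup and b_sup are labels introduced here for the linear evaluation head. With those caveats, for a ResNet-50 (1x) the shapes are:

```latex
\begin{aligned}
h    &= f(x) = \mathrm{ResNet}(x) \in \mathbb{R}^{2048} \\
z    &= g(h) = W^{(2)}\,\sigma\big(W^{(1)} h\big) \in \mathbb{R}^{128} \\
\ell &= W_{\mathrm{sup}}\, h + b_{\mathrm{sup}} \in \mathbb{R}^{1000}
\end{aligned}
```

Under these assumptions, a 2048-d output is consistent with h, not with z or with the 1000-way logits.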

Thanks!

lkshrsch, Apr 04 '22 16:04

The weights of both the projection head and the supervised linear head are included in the checkpoints. I suppose you're using the hub module? If so, you can choose the output by selecting one of the outputs listed by module.get_output_info_dict(); I've listed them below. Note that the projection head's output is not among them, so to get it you may need to run the TF code with the checkpoint loaded and build a new graph.
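As an alternative to rebuilding the TF graph, you could apply the projection head yourself to the final_avg_pool features. This requires first pulling the projection-head variables out of the checkpoint, for example with tf.train.list_variables and tf.train.load_variable. The actual variable names are not given in this thread and must be looked up. The numpy sketch below rests on several assumptions. It uses a hypothetical per-layer dict layout and a SimCLRv2-style MLP with ReLU on the hidden layers only. Batch norm runs in inference mode, and the eps=1e-5 default is also assumed. The demo uses random stand-in weights, not the pretrained ones.

```python
import numpy as np


def batch_norm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-5):
    """Inference-mode batch norm: normalise with the stored moving statistics."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta


def projection_head(h, layers, eps=1e-5):
    """Apply an (assumed) SimCLRv2-style projection MLP g(.) to features h.

    h:      (batch, in_dim) array, e.g. the 'final_avg_pool' output.
    layers: list of dicts, one per dense layer (hypothetical layout), with
            'kernel' (in_dim, out_dim), optional 'bias', batch-norm params
            'gamma', 'moving_mean', 'moving_variance', and optional 'beta'.
    """
    x = h
    for i, layer in enumerate(layers):
        x = x @ layer["kernel"] + layer.get("bias", 0.0)
        x = batch_norm_inference(
            x,
            layer["gamma"],
            layer.get("beta", 0.0),
            layer["moving_mean"],
            layer["moving_variance"],
            eps,
        )
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)  # ReLU on hidden layers; none on the output z
    return x


def random_layer(rng, in_dim, out_dim):
    """Random stand-in weights (NOT the pretrained ones), for shape checks only."""
    return {
        "kernel": rng.normal(scale=in_dim ** -0.5, size=(in_dim, out_dim)),
        "gamma": np.ones(out_dim),
        "moving_mean": np.zeros(out_dim),
        "moving_variance": np.ones(out_dim),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Assumed layer widths: 2048 -> 2048 -> 2048 -> 128.
    dims = [2048, 2048, 2048, 128]
    layers = [random_layer(rng, a, b) for a, b in zip(dims[:-1], dims[1:])]
    h = rng.normal(size=(4, 2048))  # stand-in for final_avg_pool outputs
    z = projection_head(h, layers)
    print(z.shape)  # (4, 128)
```

With real weights, you would fill each layer dict from the checkpoint and pass in the final_avg_pool tensor converted to numpy.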

{'block_group1': <hub.ParsedTensorInfo shape=(None, None, None, 256) dtype=float32 is_sparse=False>,
 'block_group2': <hub.ParsedTensorInfo shape=(None, None, None, 512) dtype=float32 is_sparse=False>,
 'block_group3': <hub.ParsedTensorInfo shape=(None, None, None, 1024) dtype=float32 is_sparse=False>,
 'block_group4': <hub.ParsedTensorInfo shape=(None, None, None, 2048) dtype=float32 is_sparse=False>,
 'default': <hub.ParsedTensorInfo shape=(None, 2048) dtype=float32 is_sparse=False>,
 'final_avg_pool': <hub.ParsedTensorInfo shape=(None, 2048) dtype=float32 is_sparse=False>,
 'initial_conv': <hub.ParsedTensorInfo shape=(None, None, None, 64) dtype=float32 is_sparse=False>,
 'initial_max_pool': <hub.ParsedTensorInfo shape=(None, None, None, 64) dtype=float32 is_sparse=False>,
 'logits_sup': <hub.ParsedTensorInfo shape=(None, 1000) dtype=float32 is_sparse=False>}
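To make that listing concrete, here is a small pure-Python sketch that filters the reported output shapes for per-image vectors. Spatial feature maps are excluded. It only encodes the shapes printed above. The 128 used in the last check is the assumed projection width from the SimCLR paper's default, not something this module reports.

```python
# Output shapes copied from the get_output_info_dict() listing in this thread.
OUTPUT_SHAPES = {
    "block_group1": (None, None, None, 256),
    "block_group2": (None, None, None, 512),
    "block_group3": (None, None, None, 1024),
    "block_group4": (None, None, None, 2048),
    "default": (None, 2048),
    "final_avg_pool": (None, 2048),
    "initial_conv": (None, None, None, 64),
    "initial_max_pool": (None, None, None, 64),
    "logits_sup": (None, 1000),
}


def vector_outputs(shapes, dim):
    """Keys whose output is a (batch, dim) vector rather than a feature map."""
    return sorted(k for k, s in shapes.items() if len(s) == 2 and s[1] == dim)


if __name__ == "__main__":
    print(vector_outputs(OUTPUT_SHAPES, 2048))  # ['default', 'final_avg_pool']
    print(vector_outputs(OUTPUT_SHAPES, 1000))  # ['logits_sup']
    print(vector_outputs(OUTPUT_SHAPES, 128))   # [] (no projection-head output)
```

Only the 2048-d pooled outputs and the 1000-way logits_sup are exported. This matches the note that the projection head's output is not included.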

chentingpc, Apr 04 '22 19:04

I'm struggling to get the projection-head representations and am still not sure what to do based on the previous comments in this thread. Does anyone have a minimal working example of loading the pre-trained model, pushing an input through, and getting the representation from the projection head?

ilia10000, Apr 22 '22 03:04

Thanks for the great repo! I wanted to follow up on whether this issue has been resolved.

I'm also trying to access the projection representations. Specifically, I'd like to be able to pass in an image and get out just the representation (prior to the class-level logits). What layer should I use for this?

If I load a model as follows:

import tensorflow as tf
saved_model_path = 'gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/'
saved_model = tf.saved_model.load(saved_model_path)

These are the keys available when running inference on a new image:

saved_model(image, trainable=False).keys()

dict_keys(['logits_sup', 'block_group3', 'block_group4', 'final_avg_pool', 'block_group2', 'block_group1', 'initial_max_pool', 'initial_conv'])
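The call above needs `image` to be a float batch. Below is a minimal numpy sketch for preparing one from a uint8 photo. It assumes the model expects float32 pixels in [0, 1] with shape (batch, 224, 224, 3), as I understand the repo's data pipeline to produce via tf.image.convert_image_dtype. Please verify this against the repo. The centered square crop and nearest-neighbour resize are simplifications; the repo's own evaluation preprocessing differs in detail.

```python
import numpy as np


def to_model_input(img, size=224):
    """Turn one uint8 HxWx3 image into a float32 batch of shape (1, size, size, 3).

    Assumption: the model expects pixel values in [0, 1] (no mean/std
    normalisation). Centre crop + nearest-neighbour resize are simplifications.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[-1] != 3:
        raise ValueError(f"expected an HxWx3 image, got shape {img.shape}")
    h, w, _ = img.shape
    # Largest centred square crop.
    s = min(h, w)
    top, left = (h - s) // 2, (w - s) // 2
    img = img[top:top + s, left:left + s]
    # Nearest-neighbour resize: pick `size` evenly spaced rows and columns.
    idx = np.arange(size) * s // size
    img = img[idx][:, idx]
    # uint8 [0, 255] -> float32 [0, 1], then add the batch dimension.
    return (img.astype(np.float32) / 255.0)[np.newaxis]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    photo = rng.integers(0, 256, size=(375, 500, 3)).astype(np.uint8)
    x = to_model_input(photo)
    print(x.shape, x.dtype)  # (1, 224, 224, 3) float32
    # Untested with the real model: outputs = saved_model(x, trainable=False)
```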

Which of these is the key associated with the representation? final_avg_pool?

Thank you for any insight @chentingpc or others!

collinskatie, Apr 03 '23 23:04

Hi, final_avg_pool is the output of the ResNet, which is used for linear probing. Hope that helps.
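A linear probe trains only a linear classifier on frozen features such as final_avg_pool. The numpy sketch below uses plain softmax regression with full-batch gradient descent. It is not the repo's own linear-evaluation training code, and the learning rate and step count are illustrative. The demo uses synthetic 16-d features as stand-ins for the real 2048-d ones.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.1, steps=300):
    """Fit softmax regression on frozen features with full-batch gradient descent."""
    n, d = features.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = features @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad_logits = (probs - onehot) / n  # gradient of mean cross-entropy
        w -= lr * (features.T @ grad_logits)
        b -= lr * grad_logits.sum(axis=0)
    return w, b


def predict(features, w, b):
    """Predicted class index for each row of features."""
    return np.argmax(features @ w + b, axis=1)


if __name__ == "__main__":
    # Synthetic stand-in for final_avg_pool features (real ones are 2048-d).
    rng = np.random.default_rng(0)
    centers = rng.normal(size=(3, 16))
    labels = rng.integers(0, 3, size=300)
    feats = centers[labels] + 0.3 * rng.normal(size=(300, 16))
    w, b = train_linear_probe(feats, labels, num_classes=3)
    print("train accuracy:", (predict(feats, w, b) == labels).mean())
```

Keeping the backbone frozen and fitting only `w` and `b` is what distinguishes linear probing from fine-tuning. The pretrained logits_sup output is the checkpoint's own trained linear head of this kind for ImageNet.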

chentingpc, Apr 04 '23 19:04

Thank you @chentingpc !! That's great to know!

collinskatie, Apr 05 '23 10:04