
Start on SupCon loss #1877

Open KylevdLangemheen opened this pull request 4 months ago • 4 comments

As per #1554, this PR starts on an implementation of SupCon loss.

The officially referenced PyTorch implementation does not yet support multi-GPU training, but the official TensorFlow implementation does.

Currently implemented is support for all three contrast modes under the definition of $\mathcal{L}^{sup}_{out}$ (equation 2 in https://arxiv.org/abs/2004.11362). There is no support yet for capping the number of positives used.
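For reference, equation 2 of the paper defines this variant as

$$\mathcal{L}^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)},$$

where $P(i)$ is the set of positives sharing a label with anchor $i$, $A(i)$ is the set of all other samples in the batch, and $\tau$ is the temperature.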

Also implemented are two very basic tests: one that just runs the loss with some random features and labels, and one that compares the output of this implementation to the existing NTXentLoss when labels is None. More tests are definitely needed, and the implementation is not final (and likely still has some bugs).
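As a rough illustration, the NTXentLoss comparison test looks conceptually like this (SupConLoss, its import path, and its call signature are assumptions based on lightly/loss/supcon_loss.py and may differ from the actual PR code):

import torch

from lightly.loss import NTXentLoss
from lightly.loss.supcon_loss import SupConLoss  # assumed import path

def test_matches_ntxent_without_labels() -> None:
    torch.manual_seed(0)
    z0 = torch.randn(8, 32)
    z1 = torch.randn(8, 32)

    supcon = SupConLoss(temperature=0.5)  # assumed constructor argument
    ntxent = NTXentLoss(temperature=0.5)

    # Without labels, each sample is positive only with its own augmented
    # view, so SupCon should reduce to the NT-Xent loss.
    assert torch.allclose(supcon(z0, z1), ntxent(z0, z1), atol=1e-5)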

Note: this method could be expanded with an altered version of the memory bank that also stores labels.
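Purely as a hypothetical sketch of that idea (none of these names exist in lightly; the batch is assumed to be at most the bank size):

import torch

class LabeledMemoryBank:
    """Hypothetical ring buffer storing features together with their labels."""

    def __init__(self, size: int, dim: int) -> None:
        self.features = torch.zeros(size, dim)
        self.labels = torch.full((size,), -1, dtype=torch.long)  # -1 = empty slot
        self.ptr = 0

    @torch.no_grad()
    def update(self, features: torch.Tensor, labels: torch.Tensor) -> None:
        # Overwrite the oldest entries in circular fashion.
        size = self.features.shape[0]
        idx = torch.arange(self.ptr, self.ptr + features.shape[0]) % size
        self.features[idx] = features.detach()
        self.labels[idx] = labels
        self.ptr = int(idx[-1] + 1) % size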

KylevdLangemheen · Aug 07 '25 14:08

Hi @KylevdLangemheen, thank you for your contribution! Will have a look and give you feedback on how to proceed.

yutong-xiang-97 · Aug 13 '25 07:08

Codecov Report

:x: Patch coverage is 91.83673% with 4 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 86.14%. Comparing base (ee30cd4) to head (f7d94fb).

Files with missing lines      Patch %   Lines
lightly/loss/supcon_loss.py   91.66%    4 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
+ Coverage   86.10%   86.14%   +0.03%     
==========================================
  Files         168      169       +1     
  Lines        6979     7028      +49     
==========================================
+ Hits         6009     6054      +45     
- Misses        970      974       +4     

:umbrella: View full report in Codecov by Sentry.

codecov[bot] · Aug 13 '25 07:08

Hi! I have redone the loss from scratch, basing it on your implementation! Is it available online so that I can reference it? I have also included it in the tests.

Let me know if this way of handling the distributed case is correct. It's slightly different from how it's done in NTXentLoss, but conceptually this made sense to me. Is there a way to test it even if you only have a single GPU, or do I need to spin up a multi-GPU instance :thinking:
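Conceptually, what I do is roughly this (a simplified sketch, not the exact code from this PR; a real implementation may need a gradient-preserving gather such as the one in lightly.utils.dist, since plain all_gather does not propagate gradients):

import torch
import torch.distributed as dist

def gather_features_and_labels(features, labels):
    # Sketch: collect features and labels from every rank so the loss can
    # contrast each sample against the full global batch.
    if not (dist.is_available() and dist.is_initialized()):
        return features, labels
    world_size = dist.get_world_size()
    feature_list = [torch.zeros_like(features) for _ in range(world_size)]
    label_list = [torch.zeros_like(labels) for _ in range(world_size)]
    dist.all_gather(feature_list, features)
    dist.all_gather(label_list, labels)
    return torch.cat(feature_list), torch.cat(label_list)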

P.S. I added temperature rescaling as an optional parameter in order to compare it to the NTXentLoss. I can also remove it altogether.

KylevdLangemheen · Aug 25 '25 17:08

Great, thanks Kyle!

Either I or @yutong-xiang-97 will have a look very soon. About the distributed testing: we don't do it in the unit tests, but you can test it locally without any GPUs in a multi-process CPU setting. Example code below (the important part is to use the "gloo" backend instead of "nccl", which is for CUDA GPUs).

# dist_train_cpu.py
import os
from argparse import ArgumentParser
import contextlib

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

MASTER_ADDR = "localhost"
MASTER_PORT = "12355"

@contextlib.contextmanager
def setup_dist(rank: int, world_size: int):
    try:
        os.environ['MASTER_ADDR'] = MASTER_ADDR
        os.environ['MASTER_PORT'] = MASTER_PORT
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        yield
    finally:
        dist.destroy_process_group()

def train_dist(rank: int, world_size: int) -> None:
    # Setup the process group.
    with setup_dist(rank, world_size):
        # insert a test here
        pass

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--world-size", type=int, required=True)
    args = parser.parse_args()

    mp.spawn(
        train_dist,
        args=(args.world_size,),  # note the comma: args must be a tuple
        nprocs=args.world_size,
    )
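Run it with, for example:

python dist_train_cpu.py --world-size 4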

liopeer · Aug 25 '25 18:08