detcon-pytorch
Nonlocal positives (matching masks across batch items)
Thank you for putting together this repo! I'm working on a domain-specific use of DetCon right now and your code has helped a huge amount.
I'm interested in adding an option for "nonlocal positives" when label IDs are consistent across images, i.e., treating masks with matching class IDs in different batch items as positives.
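To make "matching class IDs between different batch items" concrete, here is a minimal sketch of such a mask. It assumes per-mask class IDs stored as integer tensors of shape (B, Ca) and (B, Cb) whose values mean the same thing in every image. `cross_batch_same_obj` is a hypothetical helper written for illustration, not a function in detcon-pytorch.

```python
import torch


def cross_batch_same_obj(ind_0: torch.Tensor, ind_1: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: returns a float mask of shape (B, Ca, B, Cb).

    Entry [b, i, u, j] is 1.0 when mask i of item b and mask j of item u carry
    the same ID. This only makes sense when IDs are consistent across images
    (e.g. semantic class IDs), not per-image instance IDs.
    """
    # (B, Ca, 1, 1) == (1, 1, B, Cb) broadcasts to (B, Ca, B, Cb)
    return (ind_0[:, :, None, None] == ind_1[None, None, :, :]).float()


# Two batch items with two masks each; class 7 appears in both items.
pind1 = torch.tensor([[3, 7], [7, 5]])
tind2 = torch.tensor([[3, 7], [7, 5]])

mask = cross_batch_same_obj(pind1, tind2)
print(mask.shape)        # torch.Size([2, 2, 2, 2])
print(mask[0, 1, 1, 0])  # tensor(1.): class 7 in item 0 matches class 7 in item 1
```

Because both index tensors are broadcast against each other's batch axis, every (item, mask) pair is compared with every other (item, mask) pair, which is what distinguishes this from a per-item comparison.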
I'm trying to figure out exactly which step in `DetConBLoss()` would need to be modified. Am I right in thinking that it would simply require setting `labels_aa = same_obj_aa` (and likewise for `labels_ab`, `labels_ba` and `labels_bb`) rather than masking these with the `labels` variable? For example:
```python
if nonlocal_positives:
    # Intended: positives wherever mask IDs match, without restricting to the same batch item
    labels_aa = same_obj_aa
    labels_ab = same_obj_ab
    labels_ba = same_obj_ba
    labels_bb = same_obj_bb
else:
    # Problem line for using corresponding labels between views:
    # perhaps this masking only needs to be applied when nonlocal_positives is False?
    labels_aa = labels * same_obj_aa  # (B, Ca, B, Ca) masked to the same batch item
    labels_ab = labels * same_obj_ab  # (B, Ca, B, Cb) masked to the same batch item
    labels_ba = labels * same_obj_ba  # (B, Cb, B, Ca) masked to the same batch item
    labels_bb = labels * same_obj_bb  # (B, Cb, B, Cb) masked to the same batch item
```
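Before dropping the `labels` factor, it may be worth checking the shape of `same_obj_aa` itself. The sketch below assumes it is built as in DeepMind's reference DetCon loss: IDs are compared only within each item, leaving a singleton batch axis (shape (B, Ca, 1, Ca)). `within_item_same_obj` is a hypothetical stand-in for that step, not the repo's code. Under that assumption, using the tensor without `labels` broadcasts each item's own comparison across the whole batch instead of comparing IDs between items.

```python
import torch


def within_item_same_obj(ind_0: torch.Tensor, ind_1: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a within-item comparison: shape (B, Ca, 1, Cb).

    Entry [b, i, 0, j] compares mask i of item b with mask j of the SAME item b;
    axis 2 has size 1 and only broadcasts, it never indexes another item.
    """
    batch_size, num_a = ind_0.shape
    num_b = ind_1.shape[1]
    same = ind_0.reshape(batch_size, num_a, 1) == ind_1.reshape(batch_size, 1, num_b)
    return same.unsqueeze(2).float()


pind1 = torch.tensor([[3, 7], [7, 5]])  # class IDs for item 0 and item 1
tind1 = torch.tensor([[3, 7], [7, 5]])
batch_size = pind1.shape[0]

same_obj_aa = within_item_same_obj(pind1, tind1)        # (B, Ca, 1, Ca)
labels = torch.eye(batch_size)[:, None, :, None]        # (B, 1, B, 1) one-hot item IDs

masked = labels * same_obj_aa                           # (B, Ca, B, Ca), same-item blocks only
unmasked = same_obj_aa.expand(-1, -1, batch_size, -1)   # what dropping `labels` would give

# Class 7 is in item 0 (mask 1) and item 1 (mask 0), yet no match is reported:
# the entry actually compares item 0 with itself (IDs 7 vs 3).
print(unmasked[0, 1, 1, 0])  # tensor(0.)
# Item 0 mask 0 (ID 3) vs item 1 mask 0 (ID 7) is wrongly marked as a match.
print(unmasked[0, 0, 1, 0])  # tensor(1.)
```

If that assumption holds for this repo, a nonlocal variant would likely need a mask that compares IDs between items, for example by broadcasting the first index tensor as (B, Ca, 1, 1) against the second as (1, 1, B, Cb).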