
Does DistributedWrapper support two-stream input?


Hi, how does the DistributedWrapper support two-stream input? I see that the original DistributedWrapper implementation's forward only takes embeddings and labels (`def forward(self, embeddings, labels)`), but I want something like `self.miner(query_embed, labels, doc_embed, labels.clone())` after wrapping the miner/loss (e.g. `self.miner = pml_dist.DistributedMinerWrapper(self.miner)`).
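For context, a minimal, self-contained sketch of the setup being described. Names like `query_embed` and `doc_embed` are placeholders, and the final two-stream call is the behavior being requested here, not something the wrapper supported at the time:

```python
import torch
from pytorch_metric_learning import miners
from pytorch_metric_learning.utils import distributed as pml_dist

# Wrap an ordinary miner so embeddings are gathered across processes.
# (Assumes torch.distributed has already been initialized, as the wrapper requires.)
miner = pml_dist.DistributedMinerWrapper(miners.MultiSimilarityMiner())

query_embed = torch.randn(32, 128)  # embeddings from the query encoder
doc_embed = torch.randn(32, 128)    # embeddings from the document encoder
labels = torch.arange(32)           # each query is paired with one document

# Desired two-stream call: mine pairs between query and document embeddings.
pairs = miner(query_embed, labels, doc_embed, labels.clone())
```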

NoTody · Jul 20, 2022

It doesn't support `ref_emb` and `ref_labels`. I'll have to think about how to add that functionality.
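For reference, the non-distributed miners and losses already take a second stream through `ref_emb`/`ref_labels`, which is the functionality the distributed wrappers were missing. A rough single-process sketch, assuming the usual base-class signatures:

```python
import torch
from pytorch_metric_learning import losses, miners

miner = miners.MultiSimilarityMiner()
loss_func = losses.ContrastiveLoss()

query_embed = torch.randn(32, 128)
doc_embed = torch.randn(32, 128)
labels = torch.arange(32)

# Anchors come from query_embed; positives/negatives come from doc_embed.
pairs = miner(query_embed, labels, ref_emb=doc_embed, ref_labels=labels)
loss = loss_func(query_embed, labels, pairs, ref_emb=doc_embed, ref_labels=labels)
```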

KevinMusgrave · Jul 21, 2022

I have attempted to implement this functionality and it seems to work correctly. Can I send a pull request for this?

NoTody · Jul 21, 2022

Yes, please do.

KevinMusgrave · Jul 21, 2022

Added in v1.6.0
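For anyone finding this later, a minimal sketch of the wrapped two-stream call that v1.6.0 is said to enable, assuming the distributed wrappers mirror the `(embeddings, labels, indices_tuple, ref_emb, ref_labels)` signature of the underlying miners and losses, and that the DDP process group is already initialized:

```python
import torch
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

# One wrapper per process, created after torch.distributed.init_process_group().
miner = pml_dist.DistributedMinerWrapper(miners.MultiSimilarityMiner())
loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

query_embed = torch.randn(32, 128)  # this process's query embeddings
doc_embed = torch.randn(32, 128)    # this process's document embeddings
labels = torch.arange(32)

# Both streams are gathered across processes before mining / loss computation.
pairs = miner(query_embed, labels, ref_emb=doc_embed, ref_labels=labels)
loss = loss_func(query_embed, labels, pairs, ref_emb=doc_embed, ref_labels=labels)
```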

KevinMusgrave · Sep 03, 2022