pytorch-metric-learning
Does DistributedWrapper support two-stream input?
Hi, how can DistributedWrapper support two-stream input? In the original implementation, the DistributedWrapper class only accepts embeddings and labels (def forward(self, embeddings, labels)). After wrapping the miner/loss (e.g. self.miner = pml_dist.DistributedMinerWrapper(self.miner)), I want to call something like self.miner(query_embed, labels, doc_embed, labels.clone()) in my implementation.
It doesn't support ref_emb and ref_labels. I'll have to think about how to add that functionality.
I have attempted to implement this functionality and it seems to work correctly. Can I send a pull request for this?
Yes, please do.
Added in v1.6.0
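To illustrate what supporting ref_emb and ref_labels in a distributed wrapper involves, here is a hypothetical, framework-free sketch. It is not the library's implementation: simulated_all_gather and all_pairs_with_ref are illustrative names, and plain Python lists stand in for tensors and for torch.distributed. The idea it shows is that both streams (queries and references) have to be gathered in the same rank order so they stay aligned, and that pairs are then formed between the query stream and the reference stream, mirroring the (a1, p, a2, n) pair format.

```python
from typing import List, Sequence, Tuple


def simulated_all_gather(per_rank: Sequence[Sequence[int]]) -> List[int]:
    # Hypothetical stand-in for an all-gather across processes:
    # concatenates each rank's local items in rank order.
    gathered: List[int] = []
    for local_items in per_rank:
        gathered.extend(local_items)
    return gathered


def all_pairs_with_ref(
    labels: Sequence[int], ref_labels: Sequence[int]
) -> Tuple[List[int], List[int], List[int], List[int]]:
    # Builds (a1, p, a2, n): (a1[k], p[k]) are positive pairs and
    # (a2[k], n[k]) are negative pairs. Anchors index the query stream and
    # p/n index the reference stream, so no self-pairs are excluded.
    a1, p, a2, n = [], [], [], []
    for i, label in enumerate(labels):
        for j, ref_label in enumerate(ref_labels):
            if label == ref_label:
                a1.append(i)
                p.append(j)
            else:
                a2.append(i)
                n.append(j)
    return a1, p, a2, n


# Two simulated ranks, each holding a local batch of query/doc pairs.
# As in the question above, the doc labels are a copy of the query labels.
query_labels_per_rank = [[0, 1], [1, 2]]
doc_labels_per_rank = [list(labels) for labels in query_labels_per_rank]

# Gather both streams in the same rank order so query i and doc i stay aligned.
all_query_labels = simulated_all_gather(query_labels_per_rank)
all_doc_labels = simulated_all_gather(doc_labels_per_rank)

a1, p, a2, n = all_pairs_with_ref(all_query_labels, all_doc_labels)
```

With four gathered queries and four gathered references there are 16 query-to-reference pairs, of which 6 share a label (positives) and 10 do not (negatives). Gathering only the query stream would leave the references local to each rank, which is why the wrapper needs to handle ref_emb and ref_labels explicitly.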