pytorch-metric-learning
Does DistributedWrapper support two-stream input?
Hi, how does DistributedWrapper support two-stream input? In the original implementation, the DistributedWrapper class's forward only takes two inputs, `def forward(self, embeddings, labels)`, but I want to call something like `self.miner(query_embed, labels, doc_embed, labels.clone())` in my implementation after wrapping the miner/loss (e.g. `self.miner = pml_dist.DistributedMinerWrapper(self.miner)`).
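To make the two-stream call concrete: as I understand the `ref_emb`/`ref_labels` convention, anchors are taken from the first stream (`query_embed`, `labels`) and positives/negatives from the reference stream (`doc_embed`, `ref_labels`). The sketch below is a hypothetical pure-Python helper, not the library's own code, that enumerates all anchor/positive and anchor/negative pairs across the two streams. Passing `labels.clone()` as `ref_labels` means query `i` is paired as a positive with document `i`.

```python
from typing import List, Sequence, Tuple


def all_pairs_indices(
    labels: Sequence[int], ref_labels: Sequence[int]
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Hypothetical helper: return (a1, p, a2, n) for two separate streams.

    a1/a2 index into the query stream (`labels`); p/n index into the
    reference stream (`ref_labels`). Because the streams are distinct,
    no i == j "self-pair" is excluded.
    """
    a1: List[int] = []
    p: List[int] = []
    a2: List[int] = []
    n: List[int] = []
    for i, anchor_label in enumerate(labels):
        for j, ref_label in enumerate(ref_labels):
            if anchor_label == ref_label:
                # Same label: (query i, doc j) is a positive pair.
                a1.append(i)
                p.append(j)
            else:
                # Different label: (query i, doc j) is a negative pair.
                a2.append(i)
                n.append(j)
    return a1, p, a2, n
```

With `labels = [0, 1, 2]` and `ref_labels = [0, 1, 2]`, each query is positive only with its own document (`a1 == p == [0, 1, 2]`), and every other document is a negative for it.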
It doesn't support `ref_emb` and `ref_labels`. I'll have to think about how to add that functionality.
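One way to picture what such support involves: a distributed wrapper has to gather the reference stream from every process as well as the query stream, and keep each stream's embeddings aligned with its own labels. The minimal sketch below simulates that with plain Python lists; `simulated_all_gather` and `gather_two_streams` are hypothetical names standing in for `torch.distributed.all_gather` followed by concatenation, and it assumes labels are already globally unique across ranks. It is not the library's actual implementation.

```python
from typing import List, Tuple

# (embeddings, labels) held by one process; embeddings are lists of floats here.
Batch = Tuple[List[List[float]], List[int]]


def simulated_all_gather(per_rank: List[Batch]) -> Batch:
    """Concatenate every rank's batch in rank order.

    Hypothetical stand-in for torch.distributed.all_gather + concatenation;
    no real communication happens here.
    """
    emb: List[List[float]] = []
    labels: List[int] = []
    for rank_emb, rank_labels in per_rank:
        if len(rank_emb) != len(rank_labels):
            raise ValueError("each rank needs exactly one label per embedding")
        emb.extend(rank_emb)
        labels.extend(rank_labels)
    return emb, labels


def gather_two_streams(
    query_per_rank: List[Batch], ref_per_rank: List[Batch]
) -> Tuple[Batch, Batch]:
    """Gather the query stream and the reference stream independently.

    Gathering each stream on its own, in the same rank order, keeps
    embeddings paired with their labels (and ref_emb with ref_labels).
    """
    return simulated_all_gather(query_per_rank), simulated_all_gather(ref_per_rank)
```

For example, with two ranks holding queries labelled `[0, 1]` and `[2]` and matching documents with the same labels, both gathered label lists come out as `[0, 1, 2]`, so the two streams can be mined against each other exactly as in the single-process case.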
I have attempted to implement this functionality and it seems to work correctly. Can I send a pull request for this?
Yes, please do.
Added in v1.6.0