Some questions about 3DInfomax
Dear Professor, I have some questions about 3DInfomax. I want to compute evaluation metrics such as precision, so I used the functions provided in your metric.py, such as TruePositiveRate() and TrueNegativeRate(), to derive them. However, after trying all of the OGB datasets, I found that precision, accuracy, and recall look far worse than expected: recall and F1 stay below 0.02 on both HIV and BBBP, even though ROC-AUC is between 0.67 and 0.74. I hope you can reply soon. Thank you, Professor.
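Note that TruePositiveRate() and TrueNegativeRate() return rates, and precision cannot be recovered from those two rates alone: it also depends on how many positives and negatives there are. A minimal sketch of the arithmetic, where `precision_from_rates` is a hypothetical helper (not part of 3DInfomax's metric.py):

```python
def precision_from_rates(tpr: float, tnr: float, num_pos: int, num_neg: int) -> float:
    """Hypothetical helper: precision = TP / (TP + FP), rebuilt from rates and class counts."""
    tp = tpr * num_pos            # TPR = TP / P
    fp = (1.0 - tnr) * num_neg    # TNR = TN / N, so FP = N - TN
    if tp + fp == 0:
        return 0.0                # nothing was predicted positive
    return tp / (tp + fp)


# Same TPR and TNR, different class balance, very different precision.
print(precision_from_rates(0.8, 0.9, num_pos=100, num_neg=100))   # about 0.89
print(precision_from_rates(0.8, 0.9, num_pos=10, num_neg=1000))   # about 0.074
```

With many more negatives than positives, even a high TNR leaves enough false positives to pull precision down.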
Here are the metrics for the HIV dataset:

- Precision: 0.008995866402983665
- Accuracy: 0.9988852739334106
- Recall: 0.002496626228094101
- F1_score: 0.003908519633114338
- ROC_AUC: 0.7427065372467041
- PR_AUC: 0.2141391634941101
- ogbg-molhiv: 0.742706502636204
- BCEWithLogitsLoss: 0.17792926660992883
Here are the metrics for the BBBP dataset:

- Precision: 0.44607841968536377
- Accuracy: 0.6127931475639343
- Recall: 0.005654983688145876
- F1_score: 0.011168383993208408
- ROC_AUC: 0.6745756268501282
- PR_AUC: 0.6546612977981567
- ogbg-molbbbp: 0.6745756172839505
- BCEWithLogitsLoss: 1.1453146849359785
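If these numbers come from the pairwise metrics (cosine similarity between two batches of embeddings, with the default `pos_mask = torch.eye(batch_size)`), keep in mind that the pair matrix is dominated by negatives: only `batch_size` of the `batch_size * batch_size` pairs are positive. A minimal arithmetic sketch mirroring `num_positives` and `num_negatives` in the metric code; `all_negative_scores` is a hypothetical helper. A predictor that marks every pair negative still scores accuracy `1 - 1/batch_size` with recall 0:

```python
def all_negative_scores(batch_size: int) -> tuple[float, float]:
    """Hypothetical helper: accuracy and recall of a predictor that marks every
    pair negative, on a batch_size x batch_size matrix with diagonal positives."""
    num_positives = batch_size                      # diagonal entries
    num_negatives = batch_size * (batch_size - 1)   # off-diagonal entries
    accuracy = num_negatives / (num_positives + num_negatives)  # equals 1 - 1/batch_size
    recall = 0 / num_positives                      # no positive pair is recovered
    return accuracy, recall


for b in (32, 256, 1024):
    print(b, all_negative_scores(b))
# 32 (0.96875, 0.0)
# 256 (0.99609375, 0.0)
# 1024 (0.9990234375, 0.0)
```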
Here is my metric code:

```python
from typing import Optional

import torch
from torch import nn, Tensor


class Precision(nn.Module):
    def __init__(self, threshold: float = 0.5) -> None:
        super().__init__()
        self.threshold = threshold

    def forward(self, x1: Tensor, x2: Tensor, pos_mask: Optional[Tensor] = None) -> Tensor:
        batch_size, _ = x1.size()
        # If x2 has extra rows appended (e.g. noisy samples), keep only the first batch_size rows.
        if x1.shape != x2.shape and pos_mask is None:
            x2 = x2[:batch_size]
        # Pairwise cosine similarity between the rows of x1 and the rows of x2.
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2)
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        sim_matrix = sim_matrix / torch.einsum('i,j->ij', x1_abs, x2_abs)
        # Rescale similarity from [-1, 1] to [0, 1] and threshold it.
        preds: Tensor = (sim_matrix + 1) / 2 > self.threshold
        if pos_mask is None:  # if we are comparing global with global
            pos_mask = torch.eye(batch_size, device=x1.device)
        pos = pos_mask.bool()
        neg = ~pos

        true_positives = (preds & pos).sum()    # positive pairs predicted positive
        false_positives = (preds & neg).sum()   # negative pairs predicted positive
        false_negatives = (~preds & pos).sum()  # positive pairs predicted negative
        true_negatives = (~preds & neg).sum()   # negative pairs predicted negative
        # Precision = TP / (TP + FP); clamp avoids 0 / 0 when nothing is predicted positive.
        pre = true_positives / (true_positives + false_positives).clamp(min=1)
        return pre
```
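The Precision class above scores pairs of embeddings by cosine similarity, with positive pairs on the diagonal, which matches a contrastive setting rather than a downstream classifier. For an OGB binary task such as ogbg-molhiv, precision and recall are more typically computed from the model's logits and the 0/1 labels. A minimal sketch, assuming sigmoid outputs thresholded at 0.5 and logits/labels of shape `[num_graphs, num_tasks]`; `binary_classification_metrics` is a hypothetical helper, not part of 3DInfomax or OGB:

```python
import torch
from torch import Tensor


def binary_classification_metrics(logits: Tensor, labels: Tensor, threshold: float = 0.5) -> dict:
    """Hypothetical helper: precision, recall, F1 and accuracy for binary labels.

    Entries whose label is NaN (missing labels, as in some OGB datasets) are ignored.
    """
    labels = labels.float()
    valid = ~torch.isnan(labels)
    preds = torch.sigmoid(logits[valid]) > threshold  # predicted class per entry
    targets = labels[valid].bool()

    tp = (preds & targets).sum().item()
    fp = (preds & ~targets).sum().item()
    fn = (~preds & targets).sum().item()
    tn = (~preds & ~targets).sum().item()

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    accuracy = (tp + tn) / max(tp + fp + fn + tn, 1)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


logits = torch.tensor([[2.0], [-1.0], [0.5], [-3.0]])
labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])
print(binary_classification_metrics(logits, labels))
# {'precision': 0.5, 'recall': 0.5, 'f1': 0.5, 'accuracy': 0.5}
```

Because HIV is heavily imbalanced, the 0.5 threshold itself matters a lot for precision and recall, whereas ROC-AUC and PR-AUC do not depend on any threshold.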