Some questions about 3DInfomax

Open • happyzhanglol opened this issue 2 years ago • 1 comment

Dear professor, I have some questions about 3DInfomax. I want to compute evaluation metrics such as precision, so I used the functions provided in your metric.py, such as TruePositiveRate() and TrueNegativeRate(). However, on every OGB dataset I tried, metrics such as precision, accuracy, and recall came out much worse than expected (see the results below). I hope you can reply as soon as possible. Thank you, professor.
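Note that TPR and TNR alone do not determine precision: precision also depends on how many positives and negatives there are, since Precision = TPR·N_pos / (TPR·N_pos + (1 − TNR)·N_neg). A minimal sketch of that relation follows; `precision_from_rates` is a hypothetical helper for illustration, not a function from 3DInfomax.

```python
def precision_from_rates(tpr: float, tnr: float, num_pos: int, num_neg: int) -> float:
    """Hypothetical helper: rebuild precision = TP / (TP + FP) from TPR, TNR and class counts."""
    tp = tpr * num_pos          # TPR = TP / (TP + FN) = TP / num_pos
    fp = (1.0 - tnr) * num_neg  # TNR = TN / (TN + FP), so FP = (1 - TNR) * num_neg
    if tp + fp == 0:
        return 0.0  # no positive predictions: precision is undefined, report 0 here
    return tp / (tp + fp)


# Same TPR and TNR, but negatives now dominate: precision drops sharply
print(precision_from_rates(0.5, 0.75, 8, 8))
print(precision_from_rates(0.5, 0.95, 10, 990))
```

With identical rates, the balanced case gives about 0.667 while the case with 99% negatives gives about 0.092, which is why precision on an imbalanced dataset can look poor even when TPR and TNR look acceptable.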

Here are the metrics for the HIV and BBBP datasets:

| Metric | HIV | BBBP |
|---|---|---|
| Precision | 0.008995866402983665 | 0.44607841968536377 |
| Accuracy | 0.9988852739334106 | 0.6127931475639343 |
| Recall | 0.002496626228094101 | 0.005654983688145876 |
| F1_score | 0.003908519633114338 | 0.011168383993208408 |
| ROC_AUC | 0.7427065372467041 | 0.6745756268501282 |
| PR_AUC | 0.2141391634941101 | 0.6546612977981567 |
| OGB evaluator (ogbg-molhiv / ogbg-molbbbp) | 0.742706502636204 | 0.6745756172839505 |
| BCEWithLogitsLoss | 0.17792926660992883 | 1.1453146849359785 |
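As a quick sanity check on the reported numbers: F1 is the harmonic mean of precision $P$ and recall $R$, and the reported F1 values match it for both datasets, so the near-zero F1 follows directly from the near-zero recall.

$$
\begin{aligned}
F_1 &= \frac{2PR}{P + R} \\
F_1^{\text{HIV}} &\approx \frac{2 \cdot 0.008996 \cdot 0.002497}{0.008996 + 0.002497} \approx 0.00391 \\
F_1^{\text{BBBP}} &\approx \frac{2 \cdot 0.4461 \cdot 0.005655}{0.4461 + 0.005655} \approx 0.0112
\end{aligned}
$$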

Here is my metric code:

```python
from typing import Optional

import torch
from torch import Tensor, nn


class Precision(nn.Module):
    def __init__(self, threshold: float = 0.5) -> None:
        super().__init__()
        self.threshold = threshold

    def forward(self, x1: Tensor, x2: Tensor, pos_mask: Optional[Tensor] = None) -> Tensor:
        batch_size, _ = x1.size()
        if x1.shape != x2.shape and pos_mask is None:
            x2 = x2[:batch_size]  # keep one row of x2 per row of x1

        # Pairwise cosine similarity between every row of x1 and every row of x2
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2)
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        sim_matrix = sim_matrix / torch.einsum('i,j->ij', x1_abs, x2_abs)

        # Map similarity from [-1, 1] to [0, 1], then threshold
        preds: Tensor = (sim_matrix + 1) / 2 > self.threshold
        if pos_mask is None:  # if we are comparing global with global
            pos_mask = torch.eye(batch_size, device=x1.device)
        neg_mask = 1 - pos_mask  # needed in both branches, so defined outside the if

        num_positives = pos_mask.count_nonzero()

        # A positive pair predicted as negative is a false negative
        false_negatives = ((preds.long() - pos_mask) * pos_mask).count_nonzero()
        true_positives = num_positives - false_negatives
        # A negative pair predicted as positive is a false positive
        false_positives = (((~preds).long() - neg_mask) * neg_mask).count_nonzero()

        pre = true_positives / (true_positives + false_positives)
        return pre
```
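The `Precision` class above scores every pair in a cosine-similarity matrix between the rows of `x1` and `x2`, treating the diagonal (or `pos_mask`) as the positive pairs, which matches a contrastive embedding setup rather than one 0/1 label per molecule. For a binary OGB task, precision and recall are usually computed directly from thresholded predictions. A minimal sketch, assuming `logits` are the raw model outputs (the same tensor fed to `BCEWithLogitsLoss`) and `labels` are 0/1 with NaN marking missing labels; `binary_metrics` is a hypothetical helper, not part of 3DInfomax or OGB.

```python
import torch
from torch import Tensor


def binary_metrics(logits: Tensor, labels: Tensor, threshold: float = 0.5) -> dict:
    """Hypothetical helper: precision, recall, F1 and accuracy for one binary task."""
    logits = logits.reshape(-1).float()
    labels = labels.reshape(-1).float()
    valid = ~torch.isnan(labels)              # ignore entries with missing labels
    probs = torch.sigmoid(logits[valid])      # logits -> probabilities
    target = labels[valid].bool()
    preds = probs > threshold

    tp = (preds & target).sum().item()
    fp = (preds & ~target).sum().item()
    fn = (~preds & target).sum().item()
    tn = (~preds & ~target).sum().item()

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    accuracy = (tp + tn) / max(tp + fp + fn + tn, 1)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
labels = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(binary_metrics(logits, labels))
```

Unlike ROC-AUC and PR-AUC, which summarize performance over all thresholds, these four numbers depend on the chosen `threshold`; on an imbalanced dataset a fixed 0.5 cut-off can yield very few positive predictions and therefore a low recall.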

happyzhanglol, Jan 26 '23 03:01