CFA_for_anomaly_localization
Error when training on own database
Please help me resolve this problem. I don't know what this error is.
I also ran into this problem, and here's my solution:
utils/metric.py
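```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def get_threshold(gt, score):
    # Binarize the ground-truth masks at 0.5: precision_recall_curve expects
    # binary labels, but masks can contain intermediate values (e.g. after
    # resizing). Building a new array also leaves the caller's `gt` unmodified.
    gt_mask = (np.asarray(gt) >= 0.5).astype(np.uint8)
    precision, recall, thresholds = precision_recall_curve(
        gt_mask.ravel(), np.asarray(score).ravel()
    )
    # F1 = 2PR / (P + R); write 0 where P + R == 0 instead of dividing by zero.
    a = 2 * precision * recall
    b = precision + recall
    f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    # f1 has one more entry than thresholds, but that last entry (recall = 0)
    # is always 0, so argmax picks a valid index into thresholds.
    threshold = thresholds[np.argmax(f1)]
    return threshold
```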
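The workaround binarizes the ground-truth masks at 0.5 and then picks the score threshold that maximizes pixel-level F1 = 2PR / (P + R) along the precision-recall curve. The thread doesn't show the traceback, so this assumes the original error came from non-binary masks. The sketch below reproduces the same idea with NumPy only, so the threshold search is visible step by step. `f1_best_threshold` is a hypothetical helper, not part of the CFA repository or of scikit-learn.

```python
import numpy as np


def f1_best_threshold(gt, score):
    """Hypothetical NumPy-only sketch of F1-optimal threshold selection.

    Binarizes gt at 0.5 (as in the workaround), then evaluates precision and
    recall for every distinct score value t used as a threshold (score >= t).
    """
    y = (np.asarray(gt, dtype=float).ravel() >= 0.5).astype(int)
    s = np.asarray(score, dtype=float).ravel()

    # Sort scores from high to low so cumulative sums count the pixels
    # predicted positive at each candidate threshold.
    order = np.argsort(-s, kind="mergesort")
    s_sorted, y_sorted = s[order], y[order]

    # Keep only the last index of each run of equal scores (ties share a threshold).
    last_of_run = np.append(np.diff(s_sorted) != 0, True)
    tp = np.cumsum(y_sorted)[last_of_run]
    fp = np.cumsum(1 - y_sorted)[last_of_run]
    thresholds = s_sorted[last_of_run]

    precision = tp / (tp + fp)  # tp + fp >= 1 at every kept index
    recall = tp / max(y.sum(), 1)  # guard against masks with no positives
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom,
                   out=np.zeros_like(denom), where=denom != 0)
    return thresholds[np.argmax(f1)]


# Toy 2x2 mask with interpolated values, e.g. from bilinear resizing.
gt = [[0.0, 0.4], [0.6, 1.0]]
score = [[0.1, 0.3], [0.7, 0.9]]
print(f1_best_threshold(gt, score))  # 0.7
```

Here the binarized mask is [0, 0, 1, 1]. At threshold 0.7 both positive pixels are detected with no false positives, so precision = recall = 1 and F1 peaks at 1.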
@wangxiang0722 Thank you very much. I'll try your workaround.