xview2_1st_place_solution
Question about validate function of classification stage
Hi, I have a question about the validate function of the classification stage.
The code is in `validate(net, data_loader)` in train50_cls_cce.py:
```python
for j in range(msks.shape[0]):  # msks.shape[0] is batch size
    # index 4 accumulates the localization (building vs. background) counts
    tp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] > 0).sum()
    fn[4] += np.logical_and(msks[j, 0] < 1, msk_pred[j] > 0).sum()
    fp[4] += np.logical_and(msks[j, 0] > 0, msk_pred[j] < 1).sum()
    # ground-truth damage labels at ground-truth building pixels
    targ = lbl_msk[j][msks[j, 0] > 0]
    # predicted damage class = argmax over the damage channels
    pred = msk_damage_pred[j].argmax(axis=0)
    # set pixels not predicted as building to 0
    pred = pred * (msk_pred[j] > _thr)
    # keep only ground-truth building pixels
    pred = pred[msks[j, 0] > 0]
    for c in range(4):
        tp[c] += np.logical_and(pred == c, targ == c).sum()
        fn[c] += np.logical_and(pred != c, targ == c).sum()
        fp[c] += np.logical_and(pred == c, targ != c).sum()
```

I was wondering why the code compares `pred == c` and `targ == c` when `c` ranges from 0 to 3 while the values of `targ` range from 1 to 4. Or did I get it wrong? Please explain. Thanks a million!
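To make the concern concrete, here is a minimal sketch on made-up data. The helper `per_class_counts` and all arrays are hypothetical, not taken from the repo; the sketch only assumes the damage head has 4 channels, so `argmax(axis=0)` yields 0..3. It shows that if the labels were stored as 1..4, the per-class counts would only line up after shifting them with `targ - 1`.

```python
import numpy as np


def per_class_counts(pred, targ, num_classes=4):
    """Per-class TP/FN/FP counts, mirroring the inner loop quoted above.

    Assumes pred and targ use the same class-index convention (0..num_classes-1).
    """
    tp = np.zeros(num_classes, dtype=np.int64)
    fn = np.zeros(num_classes, dtype=np.int64)
    fp = np.zeros(num_classes, dtype=np.int64)
    for c in range(num_classes):
        tp[c] = np.logical_and(pred == c, targ == c).sum()
        fn[c] = np.logical_and(pred != c, targ == c).sum()
        fp[c] = np.logical_and(pred == c, targ != c).sum()
    return tp, fn, fp


# Made-up scores for 4 damage channels (rows) at 5 building pixels (columns).
scores = np.array([
    [0.90, 0.10, 0.10, 0.20, 0.10],
    [0.05, 0.80, 0.10, 0.10, 0.10],
    [0.03, 0.05, 0.70, 0.10, 0.10],
    [0.02, 0.05, 0.10, 0.60, 0.70],
])
pred = scores.argmax(axis=0)  # argmax over 4 channels gives values in 0..3

# Hypothetical ground truth stored as 1..4 (the convention the question assumes).
targ_1_based = np.array([1, 2, 3, 4, 3])

# Comparing directly: the indices are off by one, so most matches are missed.
tp_raw, _, _ = per_class_counts(pred, targ_1_based)

# Shifting the labels to 0..3 makes both sides use the same convention.
tp, fn, fp = per_class_counts(pred, targ_1_based - 1)
print(tp_raw, tp, fn, fp)  # prints [0 0 0 1] [1 1 1 1] [0 0 1 0] [0 0 0 1]
```

Whether `lbl_msk` in the repo actually holds 0..3 or 1..4 depends on how the dataset class builds it, which the quoted snippet alone does not show.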