gotch
PosWeight limitation to integer?
In loss.go file:
type lossFnOptions struct {
ClassWeights []float64
Reduction int64 // 0: "None", 1: "mean", 2: "sum"
IgnoreIndex int64
PosWeight int64 // index of the weight attributed to positive class. Used in BCELoss
}
In the BCELoss function, options.PosWeight is used as follows:
posWeight = ts.MustOfSlice([]int64{options.PosWeight})
Then, the posWeight tensor is passed to the MustBinaryCrossEntropyWithLogits call, which ultimately calls AtgBinaryCrossEntropyWithLogits. That function appears to place no restriction on the tensor's element type, and PyTorch itself accepts a float tensor for pos_weight in BCEWithLogitsLoss.
Why do you force PosWeight to be integer?
@Peter2121 ,
Thanks for your report. You are right: pos_weight is just an optional scaling factor for the target. Feel free to open a PR. Thanks.
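For reference, one possible shape for such a PR (a sketch only, assuming the existing ts.MustOfSlice helper, not a tested patch) is to widen the field and build a float tensor:

```
type lossFnOptions struct {
	ClassWeights []float64
	Reduction    int64   // 0: "None", 1: "mean", 2: "sum"
	IgnoreIndex  int64
	PosWeight    float64 // scaling factor for the positive class; may be fractional
}

// and in BCELoss:
posWeight = ts.MustOfSlice([]float64{options.PosWeight})
```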