gotch

PosWeight limitation to integer?

Open · Peter2121 opened this issue 1 year ago · 1 comment

In the `loss.go` file:

```go
type lossFnOptions struct {
	ClassWeights []float64
	Reduction    int64 // 0: "None", 1: "mean", 2: "sum"
	IgnoreIndex  int64
	PosWeight    int64 // index of the weight attributed to positive class. Used in BCELoss
}
```

In the `BCELoss` function, `options.PosWeight` is used as follows:

```go
posWeight = ts.MustOfSlice([]int64{options.PosWeight})
```

The `posWeight` tensor is then passed to `MustBinaryCrossEntropyWithLogits`, which ultimately calls `AtgBinaryCrossEntropyWithLogits`. That function does not appear to restrict the tensor type, and PyTorch's `BCEWithLogitsLoss` accepts a float `pos_weight` tensor as well.
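For reference, the PyTorch documentation gives the unreduced `BCEWithLogitsLoss` with a positive-class weight as below, where $x_{n,c}$ is the logit, $y_{n,c}$ the target, $w_{n,c}$ the optional element weight, $p_c$ the positive weight for class $c$, and $\sigma$ the sigmoid:

$$
\ell_{n,c} = -\,w_{n,c}\left[\,p_c\, y_{n,c} \log \sigma(x_{n,c}) + (1 - y_{n,c}) \log\bigl(1 - \sigma(x_{n,c})\bigr)\right]
$$

Nothing in this expression requires $p_c$ to be an integer: it is a real multiplier on the positive term. A common choice is the ratio of negative to positive examples, which is usually fractional (e.g. 300 negatives and 120 positives give $p_c = 300/120 = 2.5$).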

Why is `PosWeight` forced to be an integer?

Peter2121, Oct 18 '24 19:10

@Peter2121,

Thanks for your report. You are right: `pos_weight` is just an optional scaling factor for the target. Feel free to open a PR. Thanks.

sugarme, Dec 13 '24 02:12