
Let linfa_linear::LinearRegression support weights from a dataset

Open MathisC22 opened this issue 1 year ago • 0 comments

I just stumbled upon the fact that when a dataset is updated with weights, those weights are not respected by linear regression. Looking at the source of the fit() method, only the records X and the targets y are extracted from the dataset. Would it be possible to take the weights into account in the fit() method?
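To show why this matters, here is a minimal sketch in plain Rust (no linfa or ndarray). It fits a single feature plus an intercept by weighted least squares, minimizing sum_i w_i * (y_i - slope * x_i - intercept)^2. `weighted_line_fit` is a hypothetical helper, not part of linfa's API. With every weight set to 1 it reduces to ordinary least squares, which is effectively what fit() does today when it ignores the dataset weights.

```rust
/// Weighted least-squares fit of y ≈ slope * x + intercept (hypothetical helper).
/// Minimizes sum_i w_i * (y_i - slope * x_i - intercept)^2.
fn weighted_line_fit(x: &[f64], y: &[f64], w: &[f64]) -> (f64, f64) {
    assert!(x.len() == y.len() && y.len() == w.len());
    let sw: f64 = w.iter().sum();
    // Weighted means of x and y.
    let mx = x.iter().zip(w).map(|(xi, wi)| xi * wi).sum::<f64>() / sw;
    let my = y.iter().zip(w).map(|(yi, wi)| yi * wi).sum::<f64>() / sw;
    // Weighted covariance and variance (the common 1/sw factor cancels in the ratio).
    let mut sxy = 0.0;
    let mut sxx = 0.0;
    for i in 0..x.len() {
        sxy += w[i] * (x[i] - mx) * (y[i] - my);
        sxx += w[i] * (x[i] - mx) * (x[i] - mx);
    }
    let slope = sxy / sxx;
    let intercept = my - slope * mx;
    (slope, intercept)
}

fn main() {
    // Three points on y = 2x + 1, plus one outlier at x = 3.
    let x = [0.0, 1.0, 2.0, 3.0];
    let y = [1.0, 3.0, 5.0, 100.0];
    // Unit weights: plain OLS, i.e. the result when the weights are ignored.
    let (s_ols, b_ols) = weighted_line_fit(&x, &y, &[1.0; 4]);
    // Zero weight on the outlier removes its influence on the fit.
    let (s_w, b_w) = weighted_line_fit(&x, &y, &[1.0, 1.0, 1.0, 0.0]);
    println!("unweighted: slope={s_ols:.3}, intercept={b_ols:.3}");
    println!("weighted:   slope={s_w:.3}, intercept={b_w:.3}");
}
```

The unweighted fit is pulled far away from y = 2x + 1 by the outlier, while the weighted fit recovers it. This is the behaviour a user expects after attaching weights to the dataset.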

https://github.com/rust-ml/linfa/blob/d4955622a4546ce4e2ea8ba061f826dcdc11a25d/algorithms/linfa-linear/src/ols.rs#L107C4-L132C6

I am still new to Rust, but I guess an updated version would look something like this:

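// snip line 108 and above
let y = dataset.as_single_targets();
// One weight per sample (row of X); use weights of 1 if no weights are provided.
// `weights()` returns Option<&[f32]>, so each weight is cast to the float type F.
let w: Array1<F> = match dataset.weights() {
    Some(w) => w.iter().map(|&wi| F::cast(wi)).collect(),
    None => Array1::ones(X.nrows()),
};
// Weighted least squares = OLS after scaling each row of X and each y_i by sqrt(w_i).
let sqrt_w = w.mapv(|wi| wi.sqrt());

// snip 110-116
    let X = concatenate(Axis(1), &[X.view(), Array2::ones((X.nrows(), 1)).view()]).unwrap();
    let X = &X * &sqrt_w.view().insert_axis(Axis(1)); // scale rows, incl. the intercept column
    let y = &y * &sqrt_w;
 // snip 118-124
    let (X, y) = (X.to_owned(), y.to_owned());
    let X = &X * &sqrt_w.view().insert_axis(Axis(1));
    let y = &y * &sqrt_w;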
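A note on how the weights can be applied: weighted least squares reduces to ordinary least squares on transformed data if every sample is scaled by sqrt(w_i). That means each row of X, including the column of ones added for the intercept, and the matching y_i. Scaling X alone, or scaling by w_i instead of sqrt(w_i), does not give the weighted solution. Below is a small plain-Rust sketch of that reduction for one feature plus an intercept. `ols_two_columns` and `wls_by_row_scaling` are hypothetical helpers, not linfa functions, and the 2x2 system is solved with Cramer's rule rather than the solver linfa uses.

```rust
/// Ordinary least squares for y ≈ X·beta with exactly two columns,
/// solved via the 2x2 normal equations (X^T X) beta = X^T y.
fn ols_two_columns(rows: &[[f64; 2]], y: &[f64]) -> [f64; 2] {
    let (mut a, mut b, mut d) = (0.0, 0.0, 0.0); // X^T X = [[a, b], [b, d]]
    let (mut e, mut f) = (0.0, 0.0); // X^T y = [e, f]
    for (r, &yi) in rows.iter().zip(y) {
        a += r[0] * r[0];
        b += r[0] * r[1];
        d += r[1] * r[1];
        e += r[0] * yi;
        f += r[1] * yi;
    }
    let det = a * d - b * b;
    // Cramer's rule for the 2x2 system.
    [(e * d - b * f) / det, (a * f - b * e) / det]
}

/// Weighted least squares by row scaling: multiply each row of X (feature and
/// intercept column) and each y_i by sqrt(w_i), then run plain OLS.
fn wls_by_row_scaling(x: &[f64], y: &[f64], w: &[f64]) -> [f64; 2] {
    let rows: Vec<[f64; 2]> = x
        .iter()
        .zip(w)
        .map(|(&xi, &wi)| {
            let s = wi.sqrt();
            [xi * s, s] // [feature column, intercept column], both scaled
        })
        .collect();
    let ys: Vec<f64> = y.iter().zip(w).map(|(&yi, &wi)| yi * wi.sqrt()).collect();
    ols_two_columns(&rows, &ys)
}

fn main() {
    let x = [0.0, 1.0, 2.0, 3.0, 4.0];
    let y = [1.2, 2.9, 5.1, 7.0, 30.0];
    let w = [1.0, 2.0, 1.0, 3.0, 0.1];
    let [slope, intercept] = wls_by_row_scaling(&x, &y, &w);
    println!("slope={slope:.4}, intercept={intercept:.4}");
}
```

Because the scaling happens before the solve, the existing OLS code path can stay unchanged. Only the data fed into it needs to be transformed.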

MathisC22, Jan 03 '24 09:01