linfa
Let linfa_linear::LinearRegression support weights from a dataset
I just noticed that when I update a dataset with weights, those weights are not respected when fitting a linear regression. Looking at the source of the `fit()` method, only X and y are extracted from the dataset. Would it be possible to take the weights into account in the `fit()` method?
https://github.com/rust-ml/linfa/blob/d4955622a4546ce4e2ea8ba061f826dcdc11a25d/algorithms/linfa-linear/src/ols.rs#L107C4-L132C6
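For reference, the usual way per-sample weights enter a linear fit is weighted least squares (assuming that is the behaviour wanted here). Row $i$ of $X$ and entry $y_i$ are both scaled by $\sqrt{w_i}$, not by $w_i$, and then ordinary least squares is solved on the scaled data:

```math
\hat\beta = \arg\min_\beta \sum_{i=1}^{n} w_i \,(y_i - x_i^\top \beta)^2 = \arg\min_\beta \lVert W^{1/2}(y - X\beta) \rVert_2^2, \qquad W = \operatorname{diag}(w_1, \dots, w_n)
```

```math
\tilde X = W^{1/2} X,\quad \tilde y = W^{1/2} y \;\Rightarrow\; \tilde X^\top \tilde X \hat\beta = \tilde X^\top \tilde y \iff X^\top W X \hat\beta = X^\top W y
```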
I am still a newbie in Rust, but I guess an updated version would look somewhat like this:
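```rust
// snip line 108 and above
let y = dataset.as_single_targets();
// Weights are per sample (one per row of X); fall back to 1.0 when none are set.
// Assumes `dataset.weights()` returns `Option<&[f32]>` and F: FromPrimitive.
// Weighted least squares scales each row of X and each entry of y by sqrt(w_i).
let sqrt_w: Array1<F> = match dataset.weights() {
    Some(w) => w.iter().map(|&wi| F::from_f32(wi.sqrt()).unwrap()).collect(),
    None => Array1::ones(X.nrows()),
};
// Shape (n_samples, 1), so it broadcasts across the columns of X
let sqrt_w_col = sqrt_w.view().insert_axis(Axis(1));
// snip 110-116
let X = concatenate(Axis(1), &[X.view(), Array2::ones((X.nrows(), 1)).view()]).unwrap();
let (X, y) = (X * &sqrt_w_col, y.to_owned() * &sqrt_w); // the intercept column is scaled too
// snip 118-124
let (X, y) = (X.to_owned() * &sqrt_w_col, y.to_owned() * &sqrt_w);
```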
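A dependency-free sketch of the same idea for a straight line (plain std Rust, no ndarray or linfa): both the design rows, including the intercept column of ones, and the targets are scaled by sqrt(w_i), and then ordinary least squares is solved via the 2x2 normal equations. `ols_two_columns` and `wls_line` are hypothetical helpers for illustration only, not part of linfa's API.

```rust
/// Ordinary least squares for a design with two columns [c0, c1]:
/// minimizes sum_i (y_i - beta0 * c0_i - beta1 * c1_i)^2 via the normal equations.
fn ols_two_columns(c0: &[f64], c1: &[f64], y: &[f64]) -> (f64, f64) {
    let dot = |a: &[f64], b: &[f64]| a.iter().zip(b).map(|(p, q)| p * q).sum::<f64>();
    // Gram matrix X^T X and right-hand side X^T y
    let (a00, a01, a11) = (dot(c0, c0), dot(c0, c1), dot(c1, c1));
    let (b0, b1) = (dot(c0, y), dot(c1, y));
    let det = a00 * a11 - a01 * a01;
    assert!(det.abs() > 1e-12, "design matrix is rank deficient");
    // Cramer's rule for the 2x2 system
    ((b0 * a11 - a01 * b1) / det, (a00 * b1 - a01 * b0) / det)
}

/// Weighted least squares for y ≈ intercept + slope * x (hypothetical helper).
/// Every row is scaled by sqrt(w_i), then plain OLS is applied.
fn wls_line(x: &[f64], y: &[f64], w: &[f64]) -> (f64, f64) {
    assert!(x.len() == y.len() && y.len() == w.len());
    let sw: Vec<f64> = w.iter().map(|wi| wi.sqrt()).collect();
    // The intercept column of ones becomes sqrt(w_i) after scaling
    let c1: Vec<f64> = x.iter().zip(&sw).map(|(xi, s)| xi * s).collect();
    let ys: Vec<f64> = y.iter().zip(&sw).map(|(yi, s)| yi * s).collect();
    ols_two_columns(&sw, &c1, &ys)
}

fn main() {
    let x = [0.0, 1.0, 2.0, 3.0];
    let y = [1.0, 3.0, 5.0, 20.0]; // the last sample is an outlier
    // Unit weights reproduce ordinary least squares: intercept ≈ -1.600, slope ≈ 5.900
    let (a, b) = wls_line(&x, &y, &[1.0; 4]);
    println!("unit weights:       intercept = {a:.3}, slope = {b:.3}");
    // Weight 0 removes the outlier, leaving y = 1 + 2x exactly
    let (a, b) = wls_line(&x, &y, &[1.0, 1.0, 1.0, 0.0]);
    println!("outlier weighted 0: intercept = {a:.3}, slope = {b:.3}");
}
```

An integer weight k gives the same fit as repeating that sample k times, which makes a quick sanity check for any weighted `fit()`.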