linfa

Adding Multi-Task ElasticNet support

Open · YuhanLiin opened this issue 3 years ago · 3 comments

Continuation of #194

YuhanLiin, Aug 14 '22 21:08

Running the new multi-task example gives the following output:

```
intercept:  [182.11111111111111, 35.666666666666664, 55.55555555555556]
params: [[-0.9723742003724933, -0.12992938479472216, 0.20256364290951492],
 [0.017231919622246364, -0.00785311972200309, 0.006638074127064588],
 [0.0269082650844912, 0.021197761913871658, -0.027310155988705367]]
z score: Ok([[-1.0608739975723132, -1.6800812346255631, 2.2563865434388126],
 [0.018800267889146287, -0.10154653698276617, 0.0739424949093583],
 [0.029357297568142263, 0.274102444676571, -0.3042118890982639]], shape=[3, 3], strides=[3, 1], layout=Cc (0x5), const ndim=2)
predicted variance: [-47.348308658688005, -2.278139252532177, -50.96366188947603]
```

The predicted variance values look pretty extreme, but I'm not sure whether that's an issue.
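The "predicted variance" line is produced by `r2`, i.e. the coefficient of determination R² computed once per target. R² is 1 for a perfect fit, 0 for a model no better than the reference mean, and it has no lower bound. The following is a minimal std-only sketch of the textbook per-target formula. The function name `r2_per_target` and the data are hypothetical, and this is not linfa's implementation:

```rust
/// Per-target coefficient of determination (R²), textbook formula.
/// `pred` and `truth` are row-major: one row per sample, one column per target.
/// Only `truth` enters the denominator, so the argument order matters.
/// Assumes `truth` is non-empty and all rows have the same length.
fn r2_per_target(pred: &[Vec<f64>], truth: &[Vec<f64>]) -> Vec<f64> {
    let n = truth.len() as f64;
    let n_targets = truth[0].len();
    (0..n_targets)
        .map(|j| {
            // Mean of reference column j.
            let mean = truth.iter().map(|row| row[j]).sum::<f64>() / n;
            // Residual sum of squares between predictions and reference.
            let ss_res: f64 = pred
                .iter()
                .zip(truth)
                .map(|(p, t)| (p[j] - t[j]).powi(2))
                .sum();
            // Total sum of squares of the reference around its mean.
            let ss_tot: f64 = truth.iter().map(|t| (t[j] - mean).powi(2)).sum();
            1.0 - ss_res / ss_tot
        })
        .collect()
}

fn main() {
    // Hypothetical data: three samples, two targets (not taken from linnerud).
    let truth = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
    // Target 0 is predicted exactly; target 1 is predicted in reverse order.
    let pred = vec![vec![1.0, 30.0], vec![2.0, 20.0], vec![3.0, 10.0]];
    println!("R2 per target: {:?}", r2_per_target(&pred, &truth));
}
```

Here the second target's residual error (800) is four times its total variation (200), so its score is 1 - 4 = -3. Large negative values like the ones above therefore mean "much worse than the mean", not "high variance".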

YuhanLiin, Aug 14 '22 21:08

Codecov Report

Base: 38.68% // Head: 38.59% // Decreases project coverage by 0.09% ⚠️

Coverage data is based on head (9551892) compared to base (44b244c). Patch coverage: 31.70% of modified lines in the pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #238      +/-   ##
==========================================
- Coverage   38.68%   38.59%   -0.10%     
==========================================
  Files          93       93              
  Lines        6087     6223     +136     
==========================================
+ Hits         2355     2402      +47     
- Misses       3732     3821      +89     
```

| Impacted Files | Coverage Δ |
|---|---|
| algorithms/linfa-elasticnet/src/hyperparams.rs | 14.58% <0.00%> (ø) |
| algorithms/linfa-elasticnet/src/lib.rs | 0.00% <0.00%> (ø) |
| algorithms/linfa-elasticnet/src/algorithm.rs | 35.05% <33.33%> (-2.14%) ⬇️ |
| ...rithms/linfa-trees/src/decision_trees/algorithm.rs | 39.73% <0.00%> (+1.78%) ⬆️ |


codecov-commenter, Aug 14 '22 21:08

+1. Can you point me to the test with high variance?

It's not a test; it's the new example I added.

YuhanLiin, Aug 26 '22 06:08

I reviewed the example and made two changes to make the explained variance more usable:

  1. Use more than two samples for validation; otherwise the second target has zero variance.
  2. Compare the estimated values against the validation dataset, i.e. call `r2` on the estimates with the validation set as the argument (R2 is not symmetric).
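Both points can be seen with the plain textbook formula R² = 1 - SS_res / SS_tot, where SS_tot is computed from the reference values only. The sketch below is std-only and illustrative. The function `r2` and all numbers are hypothetical, and linfa's own `r2` may guard the denominator differently. Swapping the arguments changes the score, and a reference column with zero variance makes the plain formula divide by zero:

```rust
/// Textbook R² for a single target: 1 - SS_res / SS_tot.
/// `truth` is the reference; only its variance enters SS_tot.
fn r2(pred: &[f64], truth: &[f64]) -> f64 {
    let mean = truth.iter().sum::<f64>() / truth.len() as f64;
    let ss_res: f64 = pred.iter().zip(truth).map(|(p, t)| (p - t).powi(2)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    // Point 2: R² is not symmetric in its arguments.
    let truth = [1.0, 2.0, 3.0, 4.0];
    let est = [2.0, 2.0, 3.0, 3.0];
    println!("r2(est, truth) = {}", r2(&est, &truth)); // SS_tot of truth is 5
    println!("r2(truth, est) = {}", r2(&truth, &est)); // SS_tot of est is 1

    // Point 1: with only two validation samples a target can be constant,
    // so SS_tot = 0 and the plain formula divides by zero.
    let flat_truth = [5.0, 5.0];
    println!("r2 with constant reference = {}", r2(&[4.0, 6.0], &flat_truth));
}
```

The same squared error (2) gives 0.6 against one reference and -1 against the other, which is why the direction of the call matters.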
```diff
--- a/algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
+++ b/algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
@@ -3,7 +3,7 @@ use linfa_elasticnet::{MultiTaskElasticNet, Result};

 fn main() -> Result<()> {
     // load Diabetes dataset
-    let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.90);
+    let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80);

     // train pure LASSO model with 0.1 penalty
     let model = MultiTaskElasticNet::params()
@@ -18,7 +18,7 @@ fn main() -> Result<()> {

     // validate
     let y_est = model.predict(&valid);
-    println!("predicted variance: {}", valid.r2(&y_est)?);
+    println!("predicted variance: {}", y_est.r2(&valid)?);

     Ok(())
 }
```

which gives

```
predicted variance: [-4.143623744690414, -0.2630142563112303, -0.2542410304293199]
```

So the model is worse than simply predicting the average, but the dataset is really small 😅
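"Worse than taking the average" refers to the baseline that predicts the reference mean for every sample: that baseline scores exactly R² = 0, so any negative score means the model's squared error exceeds it. The following self-contained sketch uses the textbook formula with hypothetical numbers rather than linfa's code:

```rust
/// Arithmetic mean of a non-empty slice.
fn mean(xs: &[f64]) -> f64 {
    xs.iter().sum::<f64>() / xs.len() as f64
}

/// Textbook R² for a single target; `truth` is the reference.
fn r2(pred: &[f64], truth: &[f64]) -> f64 {
    let m = mean(truth);
    let ss_res: f64 = pred.iter().zip(truth).map(|(p, t)| (p - t).powi(2)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - m).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let truth = [3.0, 5.0, 4.0, 8.0];
    // Baseline: predict the validation mean for every sample.
    let baseline = vec![mean(&truth); truth.len()];
    // A hypothetical model whose errors are larger than the spread around the mean.
    let model = [8.0, 2.0, 7.0, 3.0];
    println!("baseline R2 = {}", r2(&baseline, &truth));
    println!("model R2    = {}", r2(&model, &truth));
}
```

For the baseline, SS_res and SS_tot are the same sum, so the ratio is 1 and the score is 0. Any model scoring below that does worse than the constant mean predictor.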

bytesnake, Nov 09 '22 09:11