mlr3proba
Problems with Survival SVMs [discussion]
Hi,
I put some effort into training and tuning survival SVMs on a small dataset. Using a simple AutoTuner example, I found that the survival SVM learner can either fail (I suspect a fault in the optimization solvers) or get stuck (training never ends, CPU at 100%). I used lrn('surv.kaplan') as a fallback learner and set learner$timeout to deal with these issues, but I think this instability is a bad sign for a learner. The issues mostly relate to the choice of type: whenever it is not regression, there is a high chance that you will run into them (in the example below, the C-index is exactly 0.5 for configurations where all five folds failed and the Kaplan-Meier fallback was used instead). I have also seen the learner fail with type = 'regression', though less often.
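Why exactly 0.5? The Kaplan-Meier fallback predicts the same survival curve for every test subject, so every comparable pair is tied on predicted risk, and Harrell's C scores a tied pair as half-concordant. The base-R sketch below uses a hypothetical helper `harrell_c` written only for illustration; it is not mlr3proba's `surv.cindex` implementation, which may differ in details such as the handling of tied survival times.

```r
# Hypothetical helper: a minimal Harrell's C-index, for illustration only.
# A pair (i, j) is comparable when subject i has an observed event and a
# strictly shorter time than subject j. Higher risk should mean shorter survival.
harrell_c <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next              # censored subjects cannot anchor a pair
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1    # correctly ranked pair
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied risks count as half
        }
      }
    }
  }
  concordant / comparable
}

obs_time   <- c(5, 8, 12, 20, 25)
obs_status <- c(1, 1, 0, 1, 0)

# A constant risk score, as implied by a single Kaplan-Meier curve for everyone,
# ties every comparable pair.
harrell_c(obs_time, obs_status, risk = rep(1, 5))          # returns 0.5
# A risk score that decreases with survival time ranks every pair correctly.
harrell_c(obs_time, obs_status, risk = c(5, 4, 3, 2, 1))   # returns 1
```

This matches the log: the batches where all five folds errored score exactly 0.5, while the batch with one failed fold out of five lands between 0.5 and the SVM's own score.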
I am posting the following tuning example so that others can benefit from this investigation. Commenting out the learner$fallback and learner$timeout lines reproduces the issues I mentioned.
library(mlr3verse)
#> Loading required package: mlr3
library(mlr3proba)
library(survivalsvm)
#> Loading required package: survival
set.seed(42)
task = as_task_surv(x = veteran, time = 'time', event = 'status')
poe = po('encode')
task = poe$train(list(task))[[1]]
train_indxs = sample(seq_len(nrow(veteran)), 120)
test_indxs = setdiff(seq_len(nrow(veteran)), train_indxs)
learner = lrn('surv.svm',
type = to_tune(c('regression', 'vanbelle1', 'vanbelle2', 'hybrid')),
diff.meth = to_tune(c('makediff1', 'makediff2', 'makediff3')),
gamma.mu = to_tune(ps(
gamma = p_dbl(1e-03, 10, logscale = TRUE),
mu = p_dbl(1e-03, 10, logscale = TRUE, depends = type == 'hybrid'),
.extra_trafo = function(x, param_set) {
list(gamma.mu = c(x$gamma, x$mu))
},
.allow_dangling_dependencies = TRUE
)),
kernel = to_tune(c('lin_kernel', 'add_kernel', 'rbf_kernel', 'poly_kernel'))
)
# fall back to the Kaplan-Meier estimator when the SVM crashes
learner$fallback = lrn('surv.kaplan')
# abort training after 1 second when the SVM gets stuck (no limit on predict)
learner$timeout = c('train' = 1, 'predict' = Inf)
#learner$param_set$values$eig.tol = 1e-03
#learner$param_set$values$conv.tol = 1e-03
#learner$param_set$values$posd.tol = 1e-03
#learner$param_set$values$opt.meth = 'ipop'
#learner$param_set$values$sigf = 2
#generate_design_random(learner$param_set$search_space(), 20)
generate_design_random(learner$param_set$search_space(), 3)$transpose()
#> [[1]]
#> [[1]]$type
#> [1] "hybrid"
#>
#> [[1]]$diff.meth
#> [1] "makediff3"
#>
#> [[1]]$kernel
#> [1] "lin_kernel"
#>
#> [[1]]$gamma.mu
#> [1] 0.01853109 0.97598798
#>
#>
#> [[2]]
#> [[2]]$type
#> [1] "vanbelle2"
#>
#> [[2]]$diff.meth
#> [1] "makediff3"
#>
#> [[2]]$kernel
#> [1] "add_kernel"
#>
#> [[2]]$gamma.mu
#> [1] 0.01089036
#>
#>
#> [[3]]
#> [[3]]$type
#> [1] "hybrid"
#>
#> [[3]]$diff.meth
#> [1] "makediff3"
#>
#> [[3]]$kernel
#> [1] "lin_kernel"
#>
#> [[3]]$gamma.mu
#> [1] 0.931249 1.488555
ssvm_at = AutoTuner$new(
learner = learner,
resampling = rsmp('cv', folds = 5),
measure = msr('surv.cindex'),
terminator = trm('evals', n_evals = 10),
tuner = tnr('random_search'))
ssvm_at$train(task)
#> INFO [15:25:11.388] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
#> INFO [15:25:11.436] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:11.462] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:11.503] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:11.844] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:12.144] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:12.450] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:12.765] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:13.232] [mlr3] Finished benchmark
#> INFO [15:25:13.261] [bbotk] Result of batch 1:
#> INFO [15:25:13.263] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:13.263] [bbotk] regression <NA> 1.99383 NA lin_kernel 0.6893636 0 0
#> INFO [15:25:13.263] [bbotk] runtime_learners uhash
#> INFO [15:25:13.263] [bbotk] 1.597 16669f80-5de6-4c79-a768-08928c934405
#> INFO [15:25:13.271] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:13.289] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:13.293] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:16.780] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:20.174] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:23.868] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:27.330] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:30.758] [mlr3] Finished benchmark
#> INFO [15:25:30.785] [bbotk] Result of batch 2:
#> INFO [15:25:30.787] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:30.787] [bbotk] vanbelle2 makediff2 -3.545851 NA rbf_kernel 0.5 0 5
#> INFO [15:25:30.787] [bbotk] runtime_learners uhash
#> INFO [15:25:30.787] [bbotk] NA c98a2b5b-9cb6-4b2d-8669-4aea5e396cde
#> INFO [15:25:30.796] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:30.823] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:30.828] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:31.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:32.367] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:33.067] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:33.807] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:34.653] [mlr3] Finished benchmark
#> INFO [15:25:34.679] [bbotk] Result of batch 3:
#> INFO [15:25:34.681] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:34.681] [bbotk] hybrid makediff1 -6.114898 1.024288 rbf_kernel 0.5238242 0 0
#> INFO [15:25:34.681] [bbotk] runtime_learners uhash
#> INFO [15:25:34.681] [bbotk] 3.703 3edbcfe6-a153-47d5-b6f2-818d504735b4
#> INFO [15:25:34.692] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:34.714] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:34.719] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:38.152] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:41.555] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:25:45.237] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:25:48.701] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:25:52.225] [mlr3] Finished benchmark
#> INFO [15:25:52.255] [bbotk] Result of batch 4:
#> INFO [15:25:52.256] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:25:52.256] [bbotk] vanbelle2 makediff2 1.982577 NA rbf_kernel 0.5 0 5
#> INFO [15:25:52.256] [bbotk] runtime_learners uhash
#> INFO [15:25:52.256] [bbotk] NA 0f54c084-42ca-4e08-bc4b-818bc7922e5f
#> INFO [15:25:52.265] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:25:52.286] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:25:52.292] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:25:55.867] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:25:59.550] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:03.482] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:07.046] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:10.614] [mlr3] Finished benchmark
#> INFO [15:26:10.642] [bbotk] Result of batch 5:
#> INFO [15:26:10.644] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:10.644] [bbotk] vanbelle2 makediff2 -3.050726 NA lin_kernel 0.5 0 5
#> INFO [15:26:10.644] [bbotk] runtime_learners uhash
#> INFO [15:26:10.644] [bbotk] NA 1d365cc0-c18b-42bc-92b0-fecac6aaac4d
#> INFO [15:26:10.653] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:10.670] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:10.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:10.932] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:11.186] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:11.435] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:11.676] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:11.925] [mlr3] Finished benchmark
#> INFO [15:26:11.954] [bbotk] Result of batch 6:
#> INFO [15:26:11.955] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:11.955] [bbotk] regression <NA> -5.757422 NA lin_kernel 0.6854107 0 0
#> INFO [15:26:11.955] [bbotk] runtime_learners uhash
#> INFO [15:26:11.955] [bbotk] 1.127 cde21793-0b12-4566-8e8b-2bb756563a27
#> INFO [15:26:11.965] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:11.988] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:11.996] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:12.304] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:12.608] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:12.900] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:13.192] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:13.475] [mlr3] Finished benchmark
#> INFO [15:26:13.503] [bbotk] Result of batch 7:
#> INFO [15:26:13.504] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:13.504] [bbotk] regression <NA> 0.2568419 NA lin_kernel 0.6893636 0 0
#> INFO [15:26:13.504] [bbotk] runtime_learners uhash
#> INFO [15:26:13.504] [bbotk] 1.352 f4e9fa94-2b73-4d39-8e67-f864d3a7c71b
#> INFO [15:26:13.513] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:13.531] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:13.536] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:14.563] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:15.517] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:16.501] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:17.924] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:19.006] [mlr3] Finished benchmark
#> INFO [15:26:19.041] [bbotk] Result of batch 8:
#> INFO [15:26:19.043] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:19.043] [bbotk] hybrid makediff3 -1.907343 -6.24123 add_kernel 0.5645394 0 1
#> INFO [15:26:19.043] [bbotk] runtime_learners uhash
#> INFO [15:26:19.043] [bbotk] NA 8ae16fcb-7969-420b-af39-56a0ce68a74c
#> INFO [15:26:19.053] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:19.072] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:19.077] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:22.564] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:26.089] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:29.908] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:33.430] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:36.925] [mlr3] Finished benchmark
#> INFO [15:26:36.953] [bbotk] Result of batch 9:
#> INFO [15:26:36.955] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:36.955] [bbotk] vanbelle2 makediff2 -1.883382 NA lin_kernel 0.5 0 5
#> INFO [15:26:36.955] [bbotk] runtime_learners uhash
#> INFO [15:26:36.955] [bbotk] NA 5a765b5a-0741-4f75-95c3-9096c3916b65
#> INFO [15:26:36.965] [bbotk] Evaluating 1 configuration(s)
#> INFO [15:26:36.983] [mlr3] Running benchmark with 5 resampling iterations
#> INFO [15:26:36.988] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 1/5)
#> INFO [15:26:37.058] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 2/5)
#> INFO [15:26:37.137] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 3/5)
#> INFO [15:26:37.210] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 4/5)
#> INFO [15:26:37.284] [mlr3] Applying learner 'surv.svm' on task 'veteran' (iter 5/5)
#> INFO [15:26:37.357] [mlr3] Finished benchmark
#> INFO [15:26:37.386] [bbotk] Result of batch 10:
#> INFO [15:26:37.388] [bbotk] type diff.meth gamma mu kernel surv.cindex warnings errors
#> INFO [15:26:37.388] [bbotk] vanbelle1 makediff1 -5.990032 NA rbf_kernel 0.5337007 0 0
#> INFO [15:26:37.388] [bbotk] runtime_learners uhash
#> INFO [15:26:37.388] [bbotk] 0.242 d249648b-6561-4155-9347-243b96263347
#> INFO [15:26:37.410] [bbotk] Finished optimizing after 10 evaluation(s)
#> INFO [15:26:37.410] [bbotk] Result:
#> INFO [15:26:37.412] [bbotk] type diff.meth gamma mu kernel learner_param_vals x_domain
#> INFO [15:26:37.412] [bbotk] regression <NA> 1.99383 NA lin_kernel <list[3]> <list[3]>
#> INFO [15:26:37.412] [bbotk] surv.cindex
#> INFO [15:26:37.412] [bbotk] 0.6893636
Created on 2022-08-15 by the reprex package (v2.0.1)
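The fallback and timeout settings above guard against two different failure modes: an error during training, and a training run that never finishes. A rough base-R analogue of that pattern is sketched below, using a hypothetical `fit_with_fallback` helper; this is not how mlr3 implements encapsulation. It assumes the slow code returns control to R periodically: `setTimeLimit()` is only checked at points where R could be interrupted, so compiled code that never checks for interrupts will not be stopped. Running the fit in a separate process, as mlr3's `callr` encapsulation does, avoids that limitation.

```r
# Hypothetical helper: run the primary fit under an elapsed-time limit and,
# on any error (including hitting the limit), return the fallback fit instead.
fit_with_fallback <- function(fit, fallback, timeout = 1) {
  run_limited <- function() {
    setTimeLimit(elapsed = timeout, transient = TRUE)
    on.exit(setTimeLimit(), add = TRUE)   # always clear the limit again
    fit()
  }
  tryCatch(
    list(model = run_limited(), used_fallback = FALSE),
    error = function(e) {
      message("primary fit failed: ", conditionMessage(e))
      list(model = fallback(), used_fallback = TRUE)
    }
  )
}

# Stand-in for a solver failure: the fit throws an error.
res_error <- fit_with_fallback(function() stop("solver failed"),
                               function() "kaplan-meier")
# Stand-in for a stuck optimizer: an endless R-level loop, stopped after 1 s.
res_stuck <- fit_with_fallback(function() { i <- 0; repeat i <- i + 1 },
                               function() "kaplan-meier", timeout = 1)
# A fit that succeeds keeps its own result.
res_ok <- fit_with_fallback(function() "svm", function() "kaplan-meier")
```

The design choice mirrors the mlr3 settings: a cheap, always-working baseline is substituted for the failed model, so a tuning run keeps going, at the price of that configuration receiving the baseline's score.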
Heya, thanks for raising the issue. To be very honest, I've never managed to tune {survivalsvm} successfully (even outside of this package). It's been buggy for ages and I'm unconvinced by the underlying implementation.
Just looking at your code above, some quick comments:
1) I'd always recommend using hybrid and never tuning over type, as all the other types are just special cases of hybrid where gamma or mu is 0.
2) I've noticed that the choice of kernel can affect crashing.
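Following the first suggestion, the search space reduces to type = 'hybrid', with gamma and mu drawn on a log scale. The base-R sketch below uses a hypothetical `sample_hybrid_design` helper; it mimics, but is not, what `p_dbl(1e-03, 10, logscale = TRUE)` does in the search space above. It samples log(gamma) and log(mu) uniformly between log(1e-3) and log(10) and exponentiates them into the gamma.mu pair. This is also why the gamma and mu columns in the tuning log can be negative: they are log-scale values, so a logged gamma of 1.99383 corresponds to roughly 7.3 on the original scale.

```r
# Hypothetical helper: draw n random hybrid configurations for survivalsvm.
# log(gamma) and log(mu) are sampled uniformly on [log(lower), log(upper)]
# and then exponentiated, so each decade of the range is explored equally.
sample_hybrid_design <- function(n, lower = 1e-3, upper = 10,
                                 kernels = c("lin_kernel", "add_kernel",
                                             "rbf_kernel", "poly_kernel"),
                                 diff_meths = c("makediff1", "makediff2",
                                                "makediff3")) {
  lapply(seq_len(n), function(i) {
    log_gamma_mu <- runif(2, min = log(lower), max = log(upper))
    list(
      type      = "hybrid",                 # type is fixed, not tuned
      diff.meth = sample(diff_meths, 1),
      kernel    = sample(kernels, 1),
      gamma.mu  = exp(log_gamma_mu)         # c(gamma, mu) on the original scale
    )
  })
}

set.seed(42)
design <- sample_hybrid_design(3)
str(design[[1]])
```

Given the second comment, the `kernels` argument can be narrowed if a particular kernel keeps crashing on a dataset.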
Would you mind experimenting with {survivalsvm} directly and not via mlr3proba to see if the problem persists?
Hi Raphael,
Great to find another person who has found survival SVMs unstable. I wouldn't recommend this learner to anyone unless the hyperparameters are hand-picked rather than properly tuned (which is, well, not nice).
I ran some tests with type = 'hybrid', tuning gamma.mu and kernel, and it seems that the polynomial kernel is the one causing the issue (though of course that may depend on the dataset or other factors; I have no idea). Here is an example hyperparameter configuration that fails:
library(survivalsvm)
#> Loading required package: survival
fit = survivalsvm(Surv(time, status) ~ ., data = veteran, type = 'hybrid',
gamma.mu = c(0.76, 0.09), diff.meth = 'makediff3',
kernel = 'poly_kernel')
#> Error in tcrossprod(K, Dc): non-conformable arguments
Created on 2022-08-21 by the reprex package (v2.0.1)
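For readers unfamiliar with the message: `tcrossprod(x, y)` computes `x %*% t(y)` and therefore requires `x` and `y` to have the same number of columns, and "non-conformable arguments" means that check failed. The poly_kernel failure above therefore looks like a shape mismatch between `K` and `Dc` inside {survivalsvm} rather than a solver that failed to converge. This is an inference from the error message alone; the sketch below does not trace which objects `K` and `Dc` are in the package source, it only reproduces the base-R error.

```r
x     <- matrix(1, nrow = 2, ncol = 3)
y_ok  <- matrix(1, nrow = 4, ncol = 3)   # same number of columns as x
y_bad <- matrix(1, nrow = 2, ncol = 2)   # different number of columns

dim(tcrossprod(x, y_ok))                 # 2 4, same as dim(x %*% t(y_ok))
err <- tryCatch(tcrossprod(x, y_bad), error = function(e) conditionMessage(e))
err                                      # the "non-conformable arguments" error
```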
Yup, buggy! I'm going to close the issue here. I don't think we should add a warning to the learner, as in practice it will just perform badly and people will choose other learners. You might want to consider opening an issue in {survivalsvm}, though?