Different results: marginaleffects vs riskRegression/predict (Cox)

csthiago opened this issue

Hi Vincent,

I was testing marginaleffects with a Cox regression. I get slightly different point estimates than when using predict() from survival.


```r
library(survival)
library(riskRegression)
library(marginaleffects)
library(dplyr)  # for mutate()
set.seed(100)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(formula = Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# marginaleffects: survival difference (rx 2 vs 1) at stop = 20
avg_comparisons(fit, variables = "rx", newdata = datagrid(stop = 20),
                type = "survival")

# predict(): mean predicted survival at stop = 20 with everyone set to rx = 1, then rx = 2
rx1_mean <- predict(fit, type = "survival",
                    newdata = dt |> mutate(stop = 20, rx = "1")) |>
  mean()
rx2_mean <- predict(fit, type = "survival",
                    newdata = dt |> mutate(stop = 20, rx = "2")) |>
  mean()
rx2_mean - rx1_mean

# riskRegression: average treatment effect at time 20
ateFit1a <- ate(fit, data = dt, treatment = "rx", times = 20)
summary(ateFit1a)
```

The value I get from avg_comparisons() is [screenshot not reproduced]. The point estimate using predict() is 0.1034387, and the values from riskRegression's ate() are: [screenshot not reproduced]

Do you have any idea why the difference?

Thank you very much

csthiago, Oct 21 '24 13:10

datagrid() sets all unnamed variables to their mean or mode. That could explain the difference, but I don't know what the ate() function does.
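To see that difference concretely, here is a minimal sketch (reusing the model from the original post) that builds both kinds of grid with datagrid(): the default one, and the grid_type = "counterfactual" one that comes up later in this thread. The object names g_typical and g_cf are just illustrative.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Default grid: a single row; stop is fixed at 20 and every other variable
# is set to its mean (numeric) or mode (factor/character).
g_typical <- datagrid(model = fit, stop = 20)

# Counterfactual grid: a full copy of the model data with stop replaced by 20,
# so predictions are later averaged over the observed covariate distribution.
g_cf <- datagrid(model = fit, stop = 20, grid_type = "counterfactual")

nrow(g_typical)
nrow(g_cf)
```

Averaging predictions over g_cf is what the predict() + mutate() code in the original post does; predicting once on g_typical is not the same thing for a nonlinear quantity like survival.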

vincentarelbundock, Oct 21 '24 15:10

The ate() function calculates the average treatment effect using the g-formula (or IPW, or a doubly robust estimator). I redid the calculation using dt |> mutate() and the point estimate is equal (silly me). The CI is different, but riskRegression uses the influence function instead of the delta method. Just to confirm one point: to calculate the risk difference for a Cox regression (at each time), would the syntax be:

```r
avg_comparisons(fit, variables = "rx", by = "time",
                type = "survival")
```

And for the risk ratio, would it be type = "lp"?

Thanks

csthiago, Oct 21 '24 16:10

As always, this function computes a difference on the specified scale. Here, type = "lp" means a difference between the two groups on the linear predictor scale (the log relative hazard).
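A quick way to see what that difference is, sketched under the setup from the original post: on the linear predictor scale the other covariates cancel row by row, so the averaged rx contrast should equal the rx2 coefficient, i.e. the log hazard ratio (not a risk ratio).

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Row-level contrast lp(rx = 2) - lp(rx = 1): size and number stay at each
# row's observed values, so every row contributes the same value.
cmp_lp <- avg_comparisons(fit, variables = "rx", type = "lp")

cmp_lp$estimate       # log hazard ratio for rx 2 vs 1
coef(fit)[["rx2"]]    # same number
exp(cmp_lp$estimate)  # hazard ratio
```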

You can compute ratios instead of differences on any scale: choose the scale with type and add comparison = "ratio", e.g., avg_comparisons(fit, variables = "rx", comparison = "ratio")
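For instance, a sketch assuming the model from the original post: on the type = "risk" scale (the relative hazard exp(lp)), every row's rx 2 vs 1 ratio is exp(coefficient of rx2). So comparison = "ratio" should recover the hazard ratio here, whether the rows' ratios are averaged or the averages are divided.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Ratio of relative hazards, rx = 2 over rx = 1
cmp_ratio <- avg_comparisons(fit, variables = "rx", type = "risk",
                             comparison = "ratio")

cmp_ratio$estimate
exp(coef(fit)[["rx2"]])  # hazard ratio from the fitted coefficient
```

This is a hazard ratio, not a risk ratio at a given time; the latter needs type = "survival" and a custom comparison.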

vincentarelbundock, Oct 21 '24 17:10

To do g-computation with avg_comparisons() and get the same estimate as the ate() function, you only need to add grid_type = "counterfactual" to datagrid(). Running the following should yield the correct estimate:

```r
avg_comparisons(fit, variables = "rx", newdata = datagrid(stop = 20, grid_type = "counterfactual"),
                type = "survival")
```
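In symbols, a sketch in our own notation (not taken from the thread) of what the counterfactual grid estimates at t = 20. Here $A$ is rx, $L_i$ holds the other covariates (size, number) of row $i$ of dt, $n$ is the number of rows, and $\hat\Lambda_0$ is the estimated baseline cumulative hazard:

$$
\hat S(t \mid a, L) = \exp\!\bigl\{-\hat\Lambda_0(t)\,\exp(\hat\beta_a + \hat\gamma^\top L)\bigr\}, \qquad \hat\beta_1 = 0
$$

$$
\hat S_a(t) = \frac{1}{n}\sum_{i=1}^{n} \hat S(t \mid a, L_i), \qquad \hat R_a(t) = 1 - \hat S_a(t)
$$

$$
\widehat{\mathrm{RD}}(t) = \hat R_2(t) - \hat R_1(t) = -\bigl\{\hat S_2(t) - \hat S_1(t)\bigr\}, \qquad \widehat{\mathrm{RR}}(t) = \frac{\hat R_2(t)}{\hat R_1(t)}
$$

The default datagrid() instead evaluates $\hat S(t \mid a, \bar L)$ once at the mean or mode $\bar L$; because $\hat S$ is nonlinear in $L$, that generally differs from the average $\hat S_a(t)$.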

Note that the predictions are on the survival probability scale when setting type = "survival", whereas for ate(), the predictions are on the risk scale (1 - survival). That means to get the risk difference with the correct sign, you need to manually program the comparison:

```r
avg_comparisons(fit, variables = "rx", newdata = datagrid(stop = 20, grid_type = "counterfactual"),
                type = "survival", comparison = \(hi, lo) (1 - mean(hi)) - (1 - mean(lo)))
```
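As a cross-check, here is a minimal sketch that reproduces this risk difference by hand with predict(), in the same spirit as the dt |> mutate() computation in the original post. It assumes the bladder setup from that post; d1, d2, s1, s2, and rd_manual are illustrative names. The last line shows that the risk difference is just the negated survival difference.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Counterfactual copies of the data: everyone at stop = 20, rx forced to 2 or 1
d2 <- dt
d2$stop <- 20
d2$rx <- factor(rep("2", nrow(dt)), levels = levels(dt$rx))
d1 <- d2
d1$rx <- factor(rep("1", nrow(dt)), levels = levels(dt$rx))

# Predicted survival probability at t = 20 for every row under each scenario
s2 <- predict(fit, newdata = d2, type = "survival")
s1 <- predict(fit, newdata = d1, type = "survival")

# g-computation risk difference: mean risk under rx = 2 minus under rx = 1
rd_manual <- (1 - mean(s2)) - (1 - mean(s1))

# Same quantity via marginaleffects, as in the comment above
rd_me <- avg_comparisons(fit, variables = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival",
  comparison = \(hi, lo) (1 - mean(hi)) - (1 - mean(lo)))

rd_manual
rd_me$estimate
-(mean(s2) - mean(s1))  # risk difference = minus the survival difference
```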

To get the risk ratio, you again need to manually program the comparison because the usual method of setting comparison = "ratio" computes the survival ratio. So the risk ratio would be:

```r
avg_comparisons(fit, variables = "rx", newdata = datagrid(stop = 20, grid_type = "counterfactual"),
                type = "survival", comparison = \(hi, lo) (1 - mean(hi)) / (1 - mean(lo)))
```
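To see why a survival ratio is not what ate() reports, here is a sketch that puts the two ratios side by side. It uses an explicit custom function for the survival ratio, mean(hi) / mean(lo), rather than relying on how the built-in "ratio" shortcut aggregates; the setup is assumed from the original post.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Survival ratio: mean survival under rx = 2 over mean survival under rx = 1
sr <- avg_comparisons(fit, variables = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival", comparison = \(hi, lo) mean(hi) / mean(lo))

# Risk ratio: the same averages, converted to risks before dividing
rr <- avg_comparisons(fit, variables = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival", comparison = \(hi, lo) (1 - mean(hi)) / (1 - mean(lo)))

# Higher mean survival in one arm means lower mean risk in that arm,
# so the two ratios fall on opposite sides of 1.
c(survival_ratio = sr$estimate, risk_ratio = rr$estimate)
```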

This uses a symmetric confidence interval around the ratio, which may not be accurate, since the sampling distribution of a ratio is usually skewed. Instead, we can compute a symmetric confidence interval around the log of the risk ratio and then exponentiate its endpoints:

```r
avg_comparisons(fit, variables = "rx", newdata = datagrid(stop = 20, grid_type = "counterfactual"),
                type = "survival", comparison = \(hi, lo) log((1 - mean(hi)) / (1 - mean(lo))), transform = "exp")
```

The point estimate will be the same, and agrees with that of ate(). I can't speak to which method of computing the standard error is more accurate (delta method vs. influence function).
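A sketch of that comparison, assuming the setup from the original post: both calls should give the same point estimate, but only the direct interval is symmetric around it; the exponentiated log-scale interval extends further above the estimate than below.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Delta-method interval computed directly on the risk ratio
rr_direct <- avg_comparisons(fit, variables = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival",
  comparison = \(hi, lo) (1 - mean(hi)) / (1 - mean(lo)))

# Delta-method interval on the log risk ratio, then exponentiated
rr_log <- avg_comparisons(fit, variables = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival",
  comparison = \(hi, lo) log((1 - mean(hi)) / (1 - mean(lo))),
  transform = "exp")

# Same point estimate
c(direct = rr_direct$estimate, log_then_exp = rr_log$estimate)

# Distance from the estimate to each interval endpoint
with(rr_direct, c(below = estimate - conf.low, above = conf.high - estimate))
with(rr_log, c(below = estimate - conf.low, above = conf.high - estimate))
```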

Note: to get the marginal risk predictions using avg_predictions(), you can supply the same arguments (without comparison), but to get risks rather than survival probabilities you would use byfun, e.g.,

```r
avg_predictions(fit, variables = "rx", newdata = datagrid(stop = 20, grid_type = "counterfactual"),
                type = "survival", byfun = \(...) 1 - weighted.mean(...))
```
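An alternative route that avoids byfun, sketched under the same setup: average the survival probabilities per arm (by = "rx" is written out explicitly here as an assumption about the grouping), then convert to risks afterward, since 1 - mean(S) equals mean(1 - S). Only the point estimates are converted below; the standard error of 1 - x equals that of x, but the confidence limits would swap. The predict()-based loop is a hand check, and risk_me and risk_manual are illustrative names.

```r
library(survival)
library(marginaleffects)

dt <- survival::bladder
dt$rx <- as.factor(dt$rx)
fit <- coxph(Surv(stop, event) ~ rx + size + number, data = dt, y = TRUE, x = TRUE)

# Mean survival at t = 20 in each arm, averaged over the counterfactual grid
p <- avg_predictions(fit, variables = "rx", by = "rx",
  newdata = datagrid(stop = 20, grid_type = "counterfactual"),
  type = "survival")

# Marginal risk in each arm is 1 minus the mean survival probability
risk_me <- setNames(1 - p$estimate, as.character(p$rx))

# Hand check with predict(): force rx to each level for every row
risk_manual <- sapply(levels(dt$rx), function(lvl) {
  d <- dt
  d$stop <- 20
  d$rx <- factor(rep(lvl, nrow(dt)), levels = levels(dt$rx))
  1 - mean(predict(fit, newdata = d, type = "survival"))
})

risk_me[names(risk_manual)]
risk_manual
```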

ngreifer, Oct 25 '24 18:10

Thank you very much, @ngreifer !

csthiago, Oct 26 '24 09:10

Thanks @csthiago for raising this and @ngreifer for solving it! I really appreciate it.

I'll leave this open for now as I'll soon be working on a potential coxph vignette. Might refer back to this eventually.

vincentarelbundock, Oct 26 '24 15:10