Add special case for the derivative of modified_bessel_first_kind(0, x) to greatly improve model estimation speed
Description
As described here, evaluating the derivative of `modified_bessel_first_kind(0, x)` can be sped up dramatically with a simple change. Issue 3008 described how to do this for `von_mises_lpdf`, where the derivative is hand-coded, but that fix doesn't help custom models that call `modified_bessel_first_kind(0, x)` directly.
In the forward-mode (fwd) and reverse-mode (rev) implementations, the derivative of `modified_bessel_first_kind(v, x)` for general order $v$ is calculated as:
$$ \frac{d I_v(x)}{d x} = I_{v-1}(x) - \frac{v}{x}I_v(x) $$
For $I_0(x)$, this results in the calculation:
$$ \frac{d I_0(x)}{d x} = I_{-1}(x) - \frac{0}{x}I_0(x) $$
Since $I_{-1}(x) = I_1(x)$ (for integer order $n$, $I_{-n}(x) = I_n(x)$) and the second term is 0, we have (see 10.29.3):
$$ \frac{d I_0(x)}{d x} = I_1(x) $$
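As a quick numerical sanity check of this derivation, here is a minimal sketch. It evaluates $I_n(x)$ with a truncated power series, using a hypothetical helper `bessel_i_series` that stands in for Stan's `modified_bessel_first_kind`, is not Stan's or Boost's algorithm, and is only meant for moderate $x$. The general-order formula at $v = 0$ is compared with $I_1(x)$ and with a central finite difference of $I_0(x)$.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical helper (not Stan's implementation): truncated power series
// I_n(x) = sum_k (x/2)^(2k+n) / (k! (k+n)!), using I_{-n}(x) = I_n(x)
// for integer n. Adequate for moderate x only.
double bessel_i_series(int n, double x) {
  n = std::abs(n);
  const double half = x / 2.0;
  double term = 1.0;  // k = 0 term: (x/2)^n / n!
  for (int i = 1; i <= n; ++i) term *= half / i;
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    // ratio of consecutive terms: (x/2)^2 / (k (k + n))
    term *= half * half / (static_cast<double>(k) * (k + n));
    sum += term;
    if (std::abs(term) <= 1e-17 * std::abs(sum)) break;
  }
  return sum;
}

// General-order derivative as currently used in fwd/rev:
// dI_v/dx = I_{v-1}(x) - (v / x) I_v(x). Divides by x, so x must be nonzero.
double dI_general(int v, double x) {
  assert(x != 0.0);
  return bessel_i_series(v - 1, x) - (v / x) * bessel_i_series(v, x);
}

// Proposed special case for order 0: dI_0/dx = I_1(x).
double dI0_special(double x) { return bessel_i_series(1, x); }
```

Both functions return the same value for $v = 0$. The only difference is that `dI_general` evaluates two Bessel functions (orders $-1$ and $0$) where one of order $1$ is enough.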
As described here, calculating `modified_bessel_first_kind(1, x)` is about 10 times faster than calculating `modified_bessel_first_kind(-1, x)`. So the current code, while correct for any order, is very inefficient for models that use `modified_bessel_first_kind(0, x)`, which is the most common order (at least in my field). It unnecessarily calculates $I_0(x)$ even though that term vanishes, and it calculates $I_{-1}(x)$ instead of the equal but cheaper $I_1(x)$.
Example
In a model I'm currently building, which has the likelihood:
$$ f(\theta, c, k) = \exp\left(\frac{c \exp(k \cos\theta)}{2 \pi I_0(k)}\right) \Big/ Z(c, k) $$
after many other optimizations, 90% of the time is now spent calculating $I_0(k)$. This shows up when profiling the model with Stan `profile` blocks (results read with cmdstanr's `$profiles()` method):
```stan
functions {
  real sdm_lpdf(vector y, vector mu, vector kappa) {
    // declared outside the profile block so it stays in scope afterwards
    vector[num_elements(kappa)] be;
    profile("lpdf_be") {
      be = modified_bessel_first_kind(0, kappa);
    }
    // code for calculating the rest of the likelihood and returning it
  }
}
// other code
model {
  // other code
  profile("model_lpdf_total") {
    target += sdm_lpdf(Y | mu, kappa);
  }
  // other code
}
```
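which gives:

| name | thread_id | total_time | forward_time | reverse_time | chain_stack | no_chain_stack | autodiff_calls | no_autodiff_calls |
|---|---|---|---|---|---|---|---|---|
| model_lpdf_total | 1 | 1233.100 | 76.9433000 | 1.15616e+03 | 1462836268 | 1462164925 | 60918 | 1 |
| lpdf_be | 1 | 1201.490 | 52.2222000 | 1.14927e+03 | 731053200 | 0 | 60918 | 1 |

The vast majority of that time is spent in the reverse autodiff pass (for `lpdf_be`, a reverse_time of 1149.27 out of a total_time of 1201.49).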
Requested change
I envision two possibilities:
1. Add a conditional statement to the fwd and rev implementations that handles the derivative for the special case `modified_bessel_first_kind(0, x)`.
2. Replace the derivative formula with

   $$ \frac{d I_v(x)}{d x} = I_{v+1}(x) + \frac{v}{x}I_v(x) $$

   which is equivalent (see 10.29.2) to the current formula but avoids the inefficient negative-order calculation for $v = 0$. This option has two downsides. First, it still calculates an extra Bessel function, even though that term is cancelled by multiplication by 0 (is this correct? I'm not sure how autodiff handles such cases). Second, it would make the derivative less efficient for models that use $I_1(x)$ instead.
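On whether the multiplication by 0 saves any work: in C++ both operands of `*` are evaluated, so `0.0 * f(x)` still calls `f(x)`. Under standard floating-point settings the compiler also cannot fold the product to 0, because `0.0 * inf` is NaN. As far as I can tell, the Bessel values in the rev adjoint update are plain `double` computations, so the same applies there. The minimal sketch below checks this. It uses a hypothetical series helper `bessel_i` with illustrative call counters (none of these names are Stan APIs) and verifies that the two forms from 10.29.2 agree numerically. It also counts which orders each form evaluates for $v = 0$.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Illustrative counters: total Bessel evaluations and negative-order ones.
int g_total_calls = 0;
int g_negative_order_calls = 0;

// Hypothetical series helper for I_n(x), instrumented to count calls.
double bessel_i(int n, double x) {
  ++g_total_calls;
  if (n < 0) ++g_negative_order_calls;
  n = std::abs(n);  // I_{-n}(x) = I_n(x) for integer n
  const double half = x / 2.0;
  double term = 1.0;  // (x/2)^n / n!
  for (int i = 1; i <= n; ++i) term *= half / i;
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    term *= half * half / (static_cast<double>(k) * (k + n));
    sum += term;
    if (std::abs(term) <= 1e-17 * std::abs(sum)) break;
  }
  return sum;
}

// Current form: I_v'(x) = I_{v-1}(x) - (v / x) I_v(x)
double dI_lower(int v, double x) {
  assert(x != 0.0);
  return bessel_i(v - 1, x) - (v / x) * bessel_i(v, x);
}

// Option 2 form: I_v'(x) = I_{v+1}(x) + (v / x) I_v(x)
double dI_upper(int v, double x) {
  assert(x != 0.0);
  return bessel_i(v + 1, x) + (v / x) * bessel_i(v, x);
}
```

For $v = 0$, both forms still make two Bessel calls, even though one of them is multiplied by zero. Only the current form makes a negative-order call.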
When I reran the model after manually patching my Stan installation with option 1, estimation was about 4 times faster (e.g., from 11 h down to 3 h!).
Expected Output
For option 1, a possible implementation is to change the following in rev:
```cpp
bvi_->adj_
    += adj_
       * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
          + modified_bessel_first_kind(ad_ - 1, bvi_->val_));
```
to the following. I've written it as a simple `if`/`else`, but I'm not sure whether that is the most efficient way to branch here:
```cpp
if (ad_ == 0) {
  // d/dx I_0(x) = I_1(x): avoids evaluating I_0 and I_{-1}
  bvi_->adj_ += adj_ * modified_bessel_first_kind(1, bvi_->val_);
} else {
  // general order: d/dx I_v(x) = I_{v-1}(x) - (v / x) I_v(x)
  bvi_->adj_
      += adj_
         * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
            + modified_bessel_first_kind(ad_ - 1, bvi_->val_));
}
```
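The proposed rev branch can also be checked outside Stan. The minimal sketch below uses a free function, the hypothetical `bessel_i_adjoint_increment` (not part of Stan), which returns the amount that would be added to the operand's adjoint. Its parameters `v`, `x`, and `adj` stand in for `ad_`, `bvi_->val_`, and `adj_`, and a series helper `bessel_i` stands in for `modified_bessel_first_kind`. The sketch counts Bessel evaluations per branch and compares the result with the general formula and with a finite difference.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

int g_bessel_calls = 0;  // illustrative counter of Bessel evaluations

// Hypothetical series helper standing in for modified_bessel_first_kind.
double bessel_i(int n, double x) {
  ++g_bessel_calls;
  n = std::abs(n);  // I_{-n}(x) = I_n(x) for integer n
  const double half = x / 2.0;
  double term = 1.0;  // (x/2)^n / n!
  for (int i = 1; i <= n; ++i) term *= half / i;
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    term *= half * half / (static_cast<double>(k) * (k + n));
    sum += term;
    if (std::abs(term) <= 1e-17 * std::abs(sum)) break;
  }
  return sum;
}

// Hypothetical stand-in for the rev adjoint update: returns the increment
// that would be added to the operand's adjoint (bvi_->adj_ in Stan).
double bessel_i_adjoint_increment(int v, double x, double adj) {
  if (v == 0) {
    // dI_0/dx = I_1(x): a single, non-negative-order evaluation.
    return adj * bessel_i(1, x);
  }
  // General order: dI_v/dx = I_{v-1}(x) - (v / x) I_v(x).
  assert(x != 0.0);
  return adj * (-v * bessel_i(v, x) / x + bessel_i(v - 1, x));
}
```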
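And change the following in fwd (the $v = 0$ branch must still return $I_0(z)$ as the value):
```cpp
if (v == 0) {
  // value is still I_0(z); only the tangent uses d/dz I_0(z) = I_1(z)
  return fvar<T>(modified_bessel_first_kind_z,
                 z.d_ * modified_bessel_first_kind(1, z.val_));
}
return fvar<T>(modified_bessel_first_kind_z,
               -v * z.d_ * modified_bessel_first_kind_z / z.val_
                   + z.d_ * modified_bessel_first_kind(v - 1, z.val_));
```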
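In forward mode, the $v = 0$ branch has to keep $I_0(z)$ in the value slot and change only the tangent to $z' I_1(z)$. Passing just the derivative to the constructor would put it in the value slot instead. The minimal sketch below illustrates this with a dual-number struct `Dual`, a hypothetical stand-in for Stan's `fvar<double>` that reuses the member names `val_` and `d_`. It also uses a series helper `bessel_i` in place of `modified_bessel_first_kind`, and checks the tangent against a finite difference.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical minimal dual number standing in for fvar<double>.
struct Dual {
  double val_;  // value
  double d_;    // tangent (directional derivative)
};

// Hypothetical series helper standing in for modified_bessel_first_kind.
double bessel_i(int n, double x) {
  n = std::abs(n);  // I_{-n}(x) = I_n(x) for integer n
  const double half = x / 2.0;
  double term = 1.0;  // (x/2)^n / n!
  for (int i = 1; i <= n; ++i) term *= half / i;
  double sum = term;
  for (int k = 1; k < 500; ++k) {
    term *= half * half / (static_cast<double>(k) * (k + n));
    sum += term;
    if (std::abs(term) <= 1e-17 * std::abs(sum)) break;
  }
  return sum;
}

// Forward-mode I_v(z) with the proposed order-0 special case.
Dual bessel_i_fwd(int v, const Dual& z) {
  const double value = bessel_i(v, z.val_);
  if (v == 0) {
    // Value stays I_0(z); tangent is z' * I_1(z).
    return Dual{value, z.d_ * bessel_i(1, z.val_)};
  }
  // General order: tangent is z' * (I_{v-1}(z) - (v / z) I_v(z)).
  assert(z.val_ != 0.0);
  return Dual{value,
              -v * z.d_ * value / z.val_ + z.d_ * bessel_i(v - 1, z.val_)};
}
```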
Current Version:
v4.8.0
@andrjohns I can try to implement this after our discussion in the other issue. Do you think the conditional approach (checking whether the order of the Bessel function is 0) is appropriate?