msm
msm copied to clipboard
Add fast-path to rtnorm
Twice faster for the case when all arguments have length 1.
expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
<bch:expr> <bch:t> <bch:t> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm>
1 test(rtno… 43.8ms 47.3ms 18.6 13.2MB 6.75 58 21 3.11s
2 test(msm:… 107.3ms 114.4ms 8.16 23.2MB 13.1 15 24 1.84s
Code to benchmark (note that the benchmark also checks that the output is identical):
library(msm)
rtnorm_new <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) {
if (length(n) > 1)
n <- length(n)
# sd <- vapply(sd, max, numeric(1), 1e-15, USE.NAMES=FALSE) # Small values of sd break the function.
# Fast-path for frequent case.
if (length(mean) == 1L && length(sd) == 1L && length(lower) == 1L && length(upper) == 1L) {
lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
upper <- (upper - mean) / sd
nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
if (any(nas)) warning("NAs produced")
alg <- if ((lower > upper) && nas) -1L # return NaN
else if ((lower < 0 && upper == Inf) ||
(lower == -Inf && upper > 0) ||
(is.finite(lower) && is.finite(upper) && (lower < 0) && (upper > 0) && (upper - lower > sqrt(2*pi))))
0L # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
else if (lower >= 0 && (upper > lower + 2*sqrt(exp(1)) /
(lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4)))
1L # rejection sampling with exponential proposal. Use if lower >> mean
else if (upper <= 0 && (-lower > -upper + 2*sqrt(exp(1)) /
(-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)))
2L # rejection sampling with exponential proposal. Use if upper << mean.
else 3L # rejection sampling with uniform proposal. Use if bounds are narrow and central.
ret <- rep_len(NaN, n)
if (alg == -1L) {
return(ret)
} else if (alg == 0L) {
ind.no <- seq_len(n)
while (length(ind.no) > 0) {
y <- rnorm(length(ind.no))
done <- which(y >= lower & y <= upper)
ret[ind.no[done]] <- y[done]
ind.no <- setdiff(ind.no, ind.no[done])
}
} else if (alg == 1L) {
ind.expl <- seq_len(n)
a <- (lower + sqrt(lower^2 + 4)) / 2
while (length(ind.expl) > 0) {
z <- rexp(length(ind.expl), a) + lower
u <- runif(length(ind.expl))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper))
ret[ind.expl[done]] <- z[done]
ind.expl <- setdiff(ind.expl, ind.expl[done])
}
} else if (alg == 2L) {
ind.expu <- seq_len(n)
a <- (-upper + sqrt(upper^2 +4)) / 2
while (length(ind.expu) > 0) {
z <- rexp(length(ind.expu), a) - upper
u <- runif(length(ind.expu))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower))
ret[ind.expu[done]] <- -z[done]
ind.expu <- setdiff(ind.expu, ind.expu[done])
}
} else {
ind.u <- seq_len(n)
K <- if (lower > 0) lower^2 else if (upper < 0) upper^2 else 0
while (length(ind.u) > 0) {
z <- runif(length(ind.u), lower, upper)
rho <- exp((K - z^2) / 2)
u <- runif(length(ind.u))
done <- which(u <= rho)
ret[ind.u[done]] <- z[done]
ind.u <- setdiff(ind.u, ind.u[done])
}
}
} else {
mean <- rep(mean, length=n)
sd <- rep(sd, length=n)
lower <- rep(lower, length=n)
upper <- rep(upper, length=n)
lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
upper <- (upper - mean) / sd
ind <- seq(length.out=n)
ret <- numeric(n)
nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
if (any(nas)) warning("NAs produced")
## Different algorithms depending on where upper/lower limits lie.
alg <- ifelse(
((lower > upper) | nas),
-1,# return NaN
ifelse(
((lower < 0 & upper == Inf) |
(lower == -Inf & upper > 0) |
(is.finite(lower) & is.finite(upper) & (lower < 0) & (upper > 0) & (upper-lower > sqrt(2*pi)))
),
0, # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
ifelse(
(lower >= 0 & (upper > lower + 2*sqrt(exp(1)) /
(lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4))),
1, # rejection sampling with exponential proposal. Use if lower >> mean
ifelse(upper <= 0 & (-lower > -upper + 2*sqrt(exp(1)) /
(-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)),
2, # rejection sampling with exponential proposal. Use if upper << mean.
3)))) # rejection sampling with uniform proposal. Use if bounds are narrow and central.
ind.nan <- ind[alg==-1]; ind.no <- ind[alg==0]; ind.expl <- ind[alg==1]; ind.expu <- ind[alg==2]; ind.u <- ind[alg==3]
ret[ind.nan] <- NaN
while (length(ind.no) > 0) {
y <- rnorm(length(ind.no))
done <- which(y >= lower[ind.no] & y <= upper[ind.no])
ret[ind.no[done]] <- y[done]
ind.no <- setdiff(ind.no, ind.no[done])
}
stopifnot(length(ind.no) == 0)
while (length(ind.expl) > 0) {
a <- (lower[ind.expl] + sqrt(lower[ind.expl]^2 + 4)) / 2
z <- rexp(length(ind.expl), a) + lower[ind.expl]
u <- runif(length(ind.expl))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper[ind.expl]))
ret[ind.expl[done]] <- z[done]
ind.expl <- setdiff(ind.expl, ind.expl[done])
}
stopifnot(length(ind.expl) == 0)
while (length(ind.expu) > 0) {
a <- (-upper[ind.expu] + sqrt(upper[ind.expu]^2 +4)) / 2
z <- rexp(length(ind.expu), a) - upper[ind.expu]
u <- runif(length(ind.expu))
done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower[ind.expu]))
ret[ind.expu[done]] <- -z[done]
ind.expu <- setdiff(ind.expu, ind.expu[done])
}
stopifnot(length(ind.expu) == 0)
while (length(ind.u) > 0) {
z <- runif(length(ind.u), lower[ind.u], upper[ind.u])
rho <- ifelse(lower[ind.u] > 0,
exp((lower[ind.u]^2 - z^2) / 2), ifelse(upper[ind.u] < 0,
exp((upper[ind.u]^2 - z^2) / 2),
exp(-z^2/2)))
u <- runif(length(ind.u))
done <- which(u <= rho)
ret[ind.u[done]] <- z[done]
ind.u <- setdiff(ind.u, ind.u[done])
}
stopifnot(length(ind.u) == 0)
}
ret*sd + mean
}
test <- function(f) {
set.seed(42)
unlist(mapply(f, n = sample(c(1,1,10,100), 1000, replace=TRUE), mean = 2*runif(1000), sd = 0.001+runif(1000), lower=0, upper=1))
}
library(bench)
(x <- bench::mark(
test(rtnorm_new),
test(msm::rtnorm),
check=TRUE, min_time=5))
plot(x)
Hi Manuel - Thanks for this work. I'm unsure about putting it in though, because it results in a lot of duplicated code, and it makes the resulting function very long. I was curious what exactly it is it that makes such a difference to the efficiency? Presumably something inside ifelse()
is doing a lot of unnecessary work, and if so, is that a clue to making a version that would handle both cases more cleanly?
Also I get the impression that there are several other implementations of the truncated normal distribution in other R packages - I haven't investigated these, but I'm curious what they all do differently. If there are more advanced/efficient ones out there (C++-based implementations perhaps??), then there probably isn't much point refining the one in msm
. After all it's not really anything to do with multistate models.
There is this one: https://github.com/olafmersmann/truncnorm But I haven't checked it for correctness.