msm icon indicating copy to clipboard operation
msm copied to clipboard

Add fast-path to rtnorm

Open MLopez-Ibanez opened this issue 1 year ago • 4 comments

MLopez-Ibanez avatar May 18 '23 15:05 MLopez-Ibanez

About twice as fast in the case where all arguments have length 1.

image

  expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
  <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
1 test(rtno…  43.8ms  47.3ms     18.6     13.2MB     6.75    58    21      3.11s
2 test(msm:… 107.3ms 114.4ms      8.16    23.2MB    13.1     15    24      1.84s

MLopez-Ibanez avatar May 18 '23 15:05 MLopez-Ibanez

Code to benchmark (note that the benchmark also checks that the output is identical):

library(msm)

#' Draw random values from a truncated normal distribution.
#'
#' Sampling is done on the standardised (mean 0, sd 1) scale by rejection
#' sampling, and the draws are rescaled with `ret * sd + mean` at the end.
#' One of four proposal schemes is chosen from the standardised bounds:
#'   0: plain normal proposal (wide bounds),
#'   1: shifted exponential proposal (lower bound well above the mean),
#'   2: mirrored exponential proposal (upper bound well below the mean),
#'   3: uniform proposal (narrow, central bounds).
#' When `mean`, `sd`, `lower` and `upper` all have length 1 a scalar fast path
#' picks the scheme once with `if`/`else` instead of nested `ifelse()` calls.
#'
#' @param n Number of draws. If `length(n) > 1`, `length(n)` is used instead.
#' @param mean,sd Mean and standard deviation of the untruncated normal.
#' @param lower,upper Truncation limits on the original scale.
#' @return A numeric vector of length `n`. Entries whose parameters are NA, or
#'   whose standardised lower limit exceeds the upper limit, are not sampled:
#'   they start as NaN before rescaling (a warning "NAs produced" is raised
#'   when any parameter is NA).
rtnorm_new <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) {
    if (length(n) > 1)
        n <- length(n)
    # sd <- vapply(sd, max, numeric(1), 1e-15, USE.NAMES=FALSE) # Small values of sd break the function.
    scalar_args <- length(mean) == 1L && length(sd) == 1L &&
        length(lower) == 1L && length(upper) == 1L
    if (scalar_args) {
        ## Fast path for the frequent case of scalar parameters.
        lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
        upper <- (upper - mean) / sd
        nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
        if (any(nas)) warning("NAs produced")
        ## `nas` must be tested first and combined with OR (as in the
        ## vectorised path): `lower > upper` is NA when a parameter is NA, and
        ## an inverted interval without NAs must not reach the samplers, where
        ## no proposal would ever be accepted.
        alg <- if (nas || lower > upper) -1L # return NaN
        else if ((lower < 0 && upper == Inf) ||
                 (lower == -Inf && upper > 0) ||
                 (is.finite(lower) && is.finite(upper) && (lower < 0) && (upper > 0) && (upper - lower > sqrt(2*pi))))
            0L # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
        else if (lower >= 0 && (upper > lower + 2*sqrt(exp(1)) /
                                (lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4)))
            1L # rejection sampling with exponential proposal. Use if lower >> mean
        else if (upper <= 0 && (-lower > -upper + 2*sqrt(exp(1)) /
                                (-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)))
            2L # rejection sampling with exponential proposal. Use if upper << mean.
        else 3L # rejection sampling with uniform proposal. Use if bounds are narrow and central.

        ## Every slot starts as NaN; accepted draws overwrite it. For
        ## alg == -1 nothing is sampled and the NaNs are returned (rescaled).
        ret <- rep_len(NaN, n)
        if (alg == 0L) {
            ind.no <- seq_len(n)
            while (length(ind.no) > 0) {
                y <- rnorm(length(ind.no))
                done <- which(y >= lower & y <= upper)
                ret[ind.no[done]] <- y[done]
                ind.no <- setdiff(ind.no, ind.no[done])
            }
        } else if (alg == 1L) {
            ind.expl <- seq_len(n)
            a <- (lower + sqrt(lower^2 + 4)) / 2 # rate of the exponential proposal
            while (length(ind.expl) > 0) {
                z <- rexp(length(ind.expl), a) + lower
                u <- runif(length(ind.expl))
                done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper))
                ret[ind.expl[done]] <- z[done]
                ind.expl <- setdiff(ind.expl, ind.expl[done])
            }
        } else if (alg == 2L) {
            ## Mirror image of algorithm 1: sample on the negated scale and
            ## flip the sign of accepted draws.
            ind.expu <- seq_len(n)
            a <- (-upper + sqrt(upper^2 + 4)) / 2
            while (length(ind.expu) > 0) {
                z <- rexp(length(ind.expu), a) - upper
                u <- runif(length(ind.expu))
                done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower))
                ret[ind.expu[done]] <- -z[done]
                ind.expu <- setdiff(ind.expu, ind.expu[done])
            }
        } else if (alg == 3L) {
            ind.u <- seq_len(n)
            ## K is the squared bound nearest to zero (0 if the interval
            ## contains zero), so the acceptance ratio rho is at most 1.
            K <- if (lower > 0) lower^2 else if (upper < 0) upper^2 else 0
            while (length(ind.u) > 0) {
                z <- runif(length(ind.u), lower, upper)
                rho <- exp((K - z^2) / 2)
                u <- runif(length(ind.u))
                done <- which(u <= rho)
                ret[ind.u[done]] <- z[done]
                ind.u <- setdiff(ind.u, ind.u[done])
            }
        }
    } else {
        ## General path: recycle every parameter to length n and pick an
        ## algorithm per element.
        mean <- rep(mean, length.out = n)
        sd <- rep(sd, length.out = n)
        lower <- rep(lower, length.out = n)
        upper <- rep(upper, length.out = n)
        lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale
        upper <- (upper - mean) / sd
        ind <- seq(length.out = n)
        ret <- numeric(n)
        nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper)
        if (any(nas)) warning("NAs produced")
        ## Different algorithms depending on where upper/lower limits lie.
        alg <- ifelse(
                      ((lower > upper) | nas),
                      -1,# return NaN
                      ifelse(
                             ((lower < 0 & upper == Inf) |
                              (lower == -Inf & upper > 0) |
                              (is.finite(lower) & is.finite(upper) & (lower < 0) & (upper > 0) & (upper-lower > sqrt(2*pi)))
                              ),
                             0, # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide.
                             ifelse(
                                    (lower >= 0 & (upper > lower + 2*sqrt(exp(1)) /
                                     (lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4))),
                                    1, # rejection sampling with exponential proposal. Use if lower >> mean
                                    ifelse(upper <= 0 & (-lower > -upper + 2*sqrt(exp(1)) /
                                           (-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4)),
                                           2, # rejection sampling with exponential proposal. Use if upper << mean.
                                           3)))) # rejection sampling with uniform proposal. Use if bounds are narrow and central.

        ## Indices still waiting for an accepted draw, grouped by algorithm.
        ind.nan <- ind[alg == -1]
        ind.no <- ind[alg == 0]
        ind.expl <- ind[alg == 1]
        ind.expu <- ind[alg == 2]
        ind.u <- ind[alg == 3]
        ret[ind.nan] <- NaN
        while (length(ind.no) > 0) {
            y <- rnorm(length(ind.no))
            done <- which(y >= lower[ind.no] & y <= upper[ind.no])
            ret[ind.no[done]] <- y[done]
            ind.no <- setdiff(ind.no, ind.no[done])
        }
        while (length(ind.expl) > 0) {
            a <- (lower[ind.expl] + sqrt(lower[ind.expl]^2 + 4)) / 2
            z <- rexp(length(ind.expl), a) + lower[ind.expl]
            u <- runif(length(ind.expl))
            done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper[ind.expl]))
            ret[ind.expl[done]] <- z[done]
            ind.expl <- setdiff(ind.expl, ind.expl[done])
        }
        while (length(ind.expu) > 0) {
            a <- (-upper[ind.expu] + sqrt(upper[ind.expu]^2 + 4)) / 2
            z <- rexp(length(ind.expu), a) - upper[ind.expu]
            u <- runif(length(ind.expu))
            done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower[ind.expu]))
            ret[ind.expu[done]] <- -z[done]
            ind.expu <- setdiff(ind.expu, ind.expu[done])
        }
        while (length(ind.u) > 0) {
            z <- runif(length(ind.u), lower[ind.u], upper[ind.u])
            rho <- ifelse(lower[ind.u] > 0,
                          exp((lower[ind.u]^2 - z^2) / 2), ifelse(upper[ind.u] < 0,
                                                                exp((upper[ind.u]^2 - z^2) / 2),
                                                                exp(-z^2/2)))
            u <- runif(length(ind.u))
            done <- which(u <= rho)
            ret[ind.u[done]] <- z[done]
            ind.u <- setdiff(ind.u, ind.u[done])
        }
    }
    ## Back-transform from the standardised scale.
    ret*sd + mean
}

# Run sampler `f` on a fixed, seeded set of 1000 parameter combinations and
# return all draws flattened with unlist(). The seed is reset on every call,
# so two samplers that consume random numbers identically give identical output.
test <- function(f) {
  set.seed(42)
  n_cases <- 1000
  # Drawn in this order (sizes, means, sds) so the RNG stream is reproducible.
  sizes <- sample(c(1, 1, 10, 100), n_cases, replace = TRUE)
  means <- 2 * runif(n_cases)
  sds <- 0.001 + runif(n_cases)
  draws <- mapply(f, n = sizes, mean = means, sd = sds, lower = 0, upper = 1)
  unlist(draws)
}

library(bench)
# Benchmark the fast-path implementation against msm::rtnorm on the same
# seeded workload. check=TRUE asks bench::mark to verify that both expressions
# return the same result; min_time=5 sets the minimum time spent on each
# expression. The outer parentheses print the result table as well as
# assigning it to x.
(x <- bench::mark(
               test(rtnorm_new),
               test(msm::rtnorm),
               check=TRUE, min_time=5))
# Plot the timing results collected above.
plot(x)

MLopez-Ibanez avatar May 18 '23 15:05 MLopez-Ibanez

Hi Manuel - Thanks for this work. I'm unsure about putting it in though, because it results in a lot of duplicated code, and it makes the resulting function very long. I was curious what exactly it is that makes such a difference to the efficiency? Presumably something inside ifelse() is doing a lot of unnecessary work, and if so, is that a clue to making a version that would handle both cases more cleanly?

Also I get the impression that there are several other implementations of the truncated normal distribution in other R packages - I haven't investigated these, but I'm curious what they all do differently. If there are more advanced/efficient ones out there (C++-based implementations perhaps??), then there probably isn't much point refining the one in msm. After all it's not really anything to do with multistate models.

chjackson avatar May 22 '23 12:05 chjackson

There is this one: https://github.com/olafmersmann/truncnorm But I haven't checked it for correctness.

MLopez-Ibanez avatar May 22 '23 13:05 MLopez-Ibanez