odin

Support Matrix Multiplication

Open CGMossa opened this issue 2 years ago • 6 comments

Benchmarking an R + deSolve implementation against the equivalent odin model gave a surprising result: the compiled odin model was about 2.5 times slower (median 305 ms vs 121 ms):

# A tibble: 2 × 13
  expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory     time      
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>     <list>    
1 odin_c        303ms    305ms      3.28    8.59MB      0       2     0      610ms <NULL> <Rprofmem> <bench_tm>
2 deSolve_r     121ms    121ms      8.23   38.04MB     24.7     1     3      121ms <NULL> <Rprofmem> <bench_tm>
# ℹ 1 more variable: gc <list>

I suspect the culprit is the lack of matrix multiplication in odin (or perhaps I just don't know how to invoke it). In the deSolve version I write:

between_sites <- transmission_rate * S * (foi_matrix %*% I)

In the odin version, however, I write:

foi[, ] <- foi_mat[i, j] * I[j]
delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i, ])
deriv(S[]) <- -delta_transmission[i]
deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]

Here I've omitted the parts I don't think are necessary.
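The odin excerpt builds an N × N temporary (`foi`) and then sums each row, which is exactly the matrix-vector product `foi_mat %*% I` that the deSolve version computes in one call. The following C sketch (function names are hypothetical; it assumes R-style column-major storage) checks that both routes produce the same vector. Both do O(N²) arithmetic; the direct route only saves the N² temporary, while a BLAS call would additionally bring an optimized kernel, so this alone does not show where odin's time goes.

```c
#include <assert.h>
#include <stdlib.h>

/* Column-major index of element (i, j) in an n x n matrix, as R stores it */
#define IDX(i, j, n) ((size_t)(i) + (size_t)(j) * (size_t)(n))

/* Route 1, mirroring the odin excerpt: fill foi[i, j] = A[i, j] * infected[j],
   then out[i] = sum(foi[i, ]). Needs an n x n temporary. */
static void foi_via_temp(const double *A, const double *infected, double *out, int n) {
  assert(n > 0);
  double *foi = malloc((size_t)n * (size_t)n * sizeof *foi);
  assert(foi != NULL);
  for (int j = 0; j < n; j++)
    for (int i = 0; i < n; i++)
      foi[IDX(i, j, n)] = A[IDX(i, j, n)] * infected[j];
  for (int i = 0; i < n; i++) {
    out[i] = 0.0;
    for (int j = 0; j < n; j++)
      out[i] += foi[IDX(i, j, n)];
  }
  free(foi);
}

/* Route 2: out = A %*% infected directly, with no temporary. The inner loop
   walks down a column, which is contiguous memory in column-major storage. */
static void foi_via_matvec(const double *A, const double *infected, double *out, int n) {
  assert(n > 0);
  for (int i = 0; i < n; i++)
    out[i] = 0.0;
  for (int j = 0; j < n; j++)
    for (int i = 0; i < n; i++)
      out[i] += A[IDX(i, j, n)] * infected[j];
}
```

For example, with the 3 × 3 matrix stored column-major as {1, 2, ..., 9} and an infected vector of (1, 0.5, 2), both functions return (17, 20.5, 24).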

  • Is matrix multiplication supported? Unfortunately, R's C API does matrix multiplication through the Fortran BLAS routines, which have an awkward calling convention, so it might not be supported yet.

I'll work on a minimal test case to check whether this is indeed the problem.

Under Details are more complete excerpts of my code:

Details

odin::odin({
  foi[, ] <- foi_mat[i, j] * I[j]
  delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i,])
  delta_recovery[] <- recovery_rate * I[i]
  deriv(S[]) <- -delta_transmission[i]
  deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]
  deriv(R[]) <- delta_recovery[i]

  source_id <- user()
  target_id <- user()
  Total[] <- S[i] + I[i] + R[i]
  output(source_prevalence) <- I[as.integer(source_id)] / Total[as.integer(source_id)]
  output(target_prevalence) <- I[as.integer(target_id)] / Total[as.integer(target_id)]

  transmission_rate <- user(0.05)
  recovery_rate <- user(0.01)
  S0[] <- user()
  I0[] <- user()
  foi_mat[,] <- user()
  initial(S[]) <- S0[i]
  initial(I[]) <- I0[i]
  initial(R[]) <- Total[i] - S0[i] - I0[i]
  dim(S0) <- user()
  dim(S) <- N
  dim(I) <- N
  dim(R) <- N
  dim(I0) <- N
  dim(foi_mat) <- c(N, N)
  dim(foi) <- c(N, N)
  dim(Total) <- N
  dim(delta_transmission) <- N
  dim(delta_recovery) <- N
  N <- length(S0)
},
verbose = TRUE, validate = TRUE, target = "c", pretty = TRUE,
skip_cache = FALSE) ->
  model_generator

model <-
  model_generator$new(S0 = site_S,
                      I0 = site_I,
                      foi_mat = foi_matrix,
                      source_id = as.integer(source_site_id),
                      target_id = as.integer(target_site_id))
model$set_user(transmission_rate = 0.05, recovery_rate = 0.01)

Benchmarking:

bench::mark(
  odin_c = model$run(0:100),
  deSolve_r = {
    
    site_R <- site_S
    site_R[] <- 0
    deSolve::ode(
      y = c(S = site_S, I = site_I, R = site_R),
      times = 0:100,
      func = function(time, state, parms) {
        with(parms, {
          S <- state[1:N]
          I <- state[(N + 1):(2 * N)]
          R <- state[(2 * N + 1):(3 * N)]
          
          Total <- S + I + R
          
          between_sites <- transmission_rate * S * (foi_matrix %*% I)
          
          
          source_prevalence <- I[[source_id]] / Total[[source_id]]
          target_prevalence <- I[[target_id]] / Total[[target_id]]
          
          list(c(
            dS = -transmission_rate * S * I - between_sites,
            dI = +transmission_rate * S * I + between_sites - recovery_rate * I,
            dR = recovery_rate * I
          ),
          source_prevalence = source_prevalence,
          target_prevalence = target_prevalence)
        })
      },
      parms = list(
        transmission_rate = 0.05,
        recovery_rate = 0.01,
        source_id = as.integer(source_site_id),
        target_id = as.integer(target_site_id),
        N = length(site_S)
      )
    )
  },
  check = FALSE
) %>% 
  print()

CGMossa, May 05 '23 10:05

Unfortunately, this is not that surprising: if the model is dominated by a matrix multiplication, the version that uses an optimized linear algebra library (as R's %*% does, via BLAS) will be much faster.

Supporting this properly has been on the back burner for a long time (#38, #134, #213; these mostly concern multinomial distributions, but the syntactic issue in #134 is shared and is the primary blocker). The actual calling convention is not that bad, though it does mean that models need a working copy of gfortran to compile, which is quite annoying in practice, particularly for people on Macs.

richfitz, May 05 '23 10:05

I'm glad you agree. For my use case, I can work around this by being a little more clever. But to stay on topic, and since you already know this area:

  • Why can't odin inherit the platform-specific settings that R itself uses on each platform? Naively, I would think you could use R CMD config to get the right compile flags on each platform:
C:\Users\minin>R CMD config LAPACK_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRlapack

Focusing just on BLAS (Basic Linear Algebra Subprograms), the header R_ext/BLAS.h says:

 R packages that use these should have PKG_LIBS in src/Makevars include
   $(BLAS_LIBS) $(FLIBS)

So on my Windows machine, these are:

C:\Users\minin>R CMD config FLIBS
-lgfortran -lm -lquadmath

C:\Users\minin>R CMD config BLAS_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRblas

Apparently dgemm is the Fortran routine that does this; here is its prototype from the header:


/* DGEMM - perform one of the matrix-matrix operations    */
/* C := alpha*op( A )*op( B ) + beta*C */
BLAS_extern void
F77_NAME(dgemm)(const char *transa, const char *transb, const int *m,
		const int *n, const int *k, const double *alpha,
		const double *a, const int *lda,
		const double *b, const int *ldb,
		const double *beta, double *c, const int *ldc 
		FCLEN FCLEN);
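To make the prototype concrete, here is a naive C reference of what dgemm computes. It is a sketch only: the name `ref_dgemm` is hypothetical, and it takes plain values instead of the Fortran-style pointers. It assumes column-major storage, as in R: `lda`, `ldb` and `ldc` are the leading dimensions (the number of rows of the stored arrays), and `transa`/`transb` select op(X) = X for "N" or op(X) = Xᵀ for "T". A real BLAS computes the same result with heavily optimized, blocked kernels.

```c
#include <assert.h>
#include <stddef.h>

/* op(X)[r, c] for a column-major X with leading dimension ld:
   'N' reads X[r, c], 'T' reads X[c, r]. */
static double op_elem(char trans, const double *X, int ld, int r, int c) {
  return (trans == 'N') ? X[r + (size_t)c * ld] : X[c + (size_t)r * ld];
}

/* Naive reference for C := alpha * op(A) * op(B) + beta * C, where
   op(A) is m x k, op(B) is k x n and C is m x n (all column-major). */
static void ref_dgemm(char transa, char transb, int m, int n, int k,
                      double alpha, const double *A, int lda,
                      const double *B, int ldb,
                      double beta, double *C, int ldc) {
  assert(transa == 'N' || transa == 'T');
  assert(transb == 'N' || transb == 'T');
  assert(ldc >= (m > 1 ? m : 1));
  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      double acc = 0.0;
      for (int p = 0; p < k; p++)
        acc += op_elem(transa, A, lda, i, p) * op_elem(transb, B, ldb, p, j);
      double *cij = &C[i + (size_t)j * ldc];
      /* As in BLAS, C is not read when beta == 0 */
      *cij = (beta == 0.0) ? alpha * acc : alpha * acc + beta * *cij;
    }
  }
}
```

With both flags set to "N", alpha = 1 and beta = 0, this is plain C = A B, which is what a call like `dgemm("N", "N", ...)` requests.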

Finally, I asked ChatGPT about this, and it suggested the following code for calling it:

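#define USE_FC_LEN_T  /* pass hidden Fortran string lengths to BLAS calls */
#include <R.h>
#include <Rinternals.h>
#include <R_ext/BLAS.h>
#ifndef FCONE
# define FCONE
#endif

/* Return a %*% b for two double matrices, computed with BLAS dgemm. */
SEXP matrix_mult(SEXP a, SEXP b) {
  if (!isReal(a) || !isReal(b) || !isMatrix(a) || !isMatrix(b))
    error("Both arguments must be double matrices.");

  int nrow_a = nrows(a);
  int ncol_a = ncols(a);
  int nrow_b = nrows(b);
  int ncol_b = ncols(b);

  if (ncol_a != nrow_b)
    error("Matrix dimensions do not match for multiplication."); /* error() does not return */

  SEXP result = PROTECT(allocMatrix(REALSXP, nrow_a, ncol_b));
  double *res = REAL(result);

  if (nrow_a == 0 || ncol_a == 0 || ncol_b == 0) {
    /* dgemm needs leading dimensions >= 1; an empty inner dimension gives zeros */
    for (R_xlen_t i = 0; i < XLENGTH(result); i++) res[i] = 0.0;
  } else {
    double alpha = 1.0;
    double beta = 0.0;
    /* "N", "N": use a and b as stored (no transpose); R matrices are column-major */
    F77_CALL(dgemm)("N", "N", &nrow_a, &ncol_b, &ncol_a, &alpha, REAL(a), &nrow_a,
                    REAL(b), &nrow_b, &beta, res, &nrow_a FCONE FCONE);
  }

  UNPROTECT(1);
  return result;
}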

I don't know what the "N" arguments mean. There is more than one such routine, and this one is specifically matrix-matrix (whereas I apparently need matrix-vector). Presumably it is those SEXPTYPEs that differentiate them.
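For the matrix-vector case, BLAS provides a separate routine, dgemv, which computes y := alpha * op(A) * x + beta * y, with "N" or "T" again choosing whether A is transposed. The sketch below is a naive C reference of those semantics (the name `ref_dgemv` is hypothetical and only positive strides are handled). In this model, `foi_matrix %*% I` would correspond to trans = "N", m = n = N, lda = N and incx = incy = 1; the `incx`/`incy` strides let a routine read a vector that is not stored contiguously.

```c
#include <assert.h>
#include <stddef.h>

/* Naive reference for y := alpha * op(A) * x + beta * y, where A is an
   m x n column-major matrix with leading dimension lda. For trans == 'N'
   x has n entries and y has m; for 'T' it is the other way round.
   incx / incy are the strides between consecutive vector entries. */
static void ref_dgemv(char trans, int m, int n, double alpha,
                      const double *A, int lda, const double *x, int incx,
                      double beta, double *y, int incy) {
  assert(trans == 'N' || trans == 'T');
  assert(lda >= (m > 1 ? m : 1) && incx > 0 && incy > 0);
  int len_y = (trans == 'N') ? m : n;
  int len_x = (trans == 'N') ? n : m;
  for (int i = 0; i < len_y; i++) {
    double acc = 0.0;
    for (int p = 0; p < len_x; p++) {
      double a = (trans == 'N') ? A[i + (size_t)p * lda] : A[p + (size_t)i * lda];
      acc += a * x[(size_t)p * incx];
    }
    double *yi = &y[(size_t)i * incy];
    /* As in BLAS, y is not read when beta == 0 */
    *yi = (beta == 0.0) ? alpha * acc : alpha * acc + beta * *yi;
  }
}
```

A matrix-vector product can also be expressed as dgemm with n = 1, but dgemv is the routine intended for this case.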

From searching, BLAS should be available on macOS too. I don't know how it relates to LAPACK, or where either one is selected or swapped out.

Details

Usage: R CMD config [options] [VAR]

Get the value of a basic R configure variable VAR which must be among
those listed in the 'Variables' section below, or the header and
library flags necessary for linking a front-end against R.

Options:
  -h, --help            print short help message and exit
  -v, --version         print version info and exit
      --cppflags        print pre-processor flags required to compile a
                        C/C++ file as part of a front-end using R as a library
      --ldflags         print linker flags needed for linking a front-end
                        against the R library
      --no-user-files   ignore customization files under ~/.R
      --no-site-files   ignore site customization files under R_HOME/etc
      --all             print names and values of all variables below

Variables:
  AR            command to make static libraries
  BLAS_LIBS     flags needed for linking against external BLAS libraries
  CC            C compiler command
  CFLAGS        C compiler flags
  CC17          Ditto for the C17 or earlier compiler
  C17FLAGS
  CC23          Ditto for the C23 or later compiler
  C23FLAGS
  CPICFLAGS     special flags for compiling C code to be included in a
                shared library
  CPPFLAGS      C/C++ preprocessor flags, e.g. -I<dir> if you have
                headers in a nonstandard directory <dir>
  CXX           default compiler command for C++ code
  CXXFLAGS      compiler flags for CXX
  CXXPICFLAGS   special flags for compiling C++ code to be included in a
                shared library
  CXX11         compiler command for C++11 code
  CXX11STD      flag used with CXX11 to enable C++11 support
  CXX11FLAGS    further compiler flags for CXX11
  CXX11PICFLAGS
                special flags for compiling C++11 code to be included in
                a shared library
  CXX14         compiler command for C++14 code
  CXX14STD      flag used with CXX14 to enable C++14 support
  CXX14FLAGS    further compiler flags for CXX14
  CXX14PICFLAGS
                special flags for compiling C++14 code to be included in
                a shared library
  CXX17         compiler command for C++17 code
  CXX17STD      flag used with CXX17 to enable C++17 support
  CXX17FLAGS    further compiler flags for CXX17
  CXX17PICFLAGS
                special flags for compiling C++17 code to be included in
                a shared library
  CXX20         compiler command for C++20 code
  CXX20STD      flag used with CXX20 to enable C++20 support
  CXX20FLAGS    further compiler flags for CXX20
  CXX23         compiler command for C++23 code
  CXX23STD      flag used with CXX23 to enable C++23 support
  CXX23FLAGS    further compiler flags for CXX23
  CXX23PICFLAGS
                special flags for compiling C++23 code to be included in
                a shared library
  DYLIB_EXT     file extension (including '.') for dynamic libraries
  DYLIB_LD      command for linking dynamic libraries which contain
                object files from a C or Fortran compiler only
  DYLIB_LDFLAGS
                special flags used by DYLIB_LD
  FC            Fortran compiler command
  FFLAGS        fixed-form Fortran compiler flags
  FCFLAGS       free-form Fortran 9x compiler flags
  FLIBS         linker flags needed to link Fortran code
  FPICFLAGS     special flags for compiling Fortran code to be turned
                into a shared library
  JAR           Java archive tool command
  JAVA          Java interpreter command
  JAVAC         Java compiler command
  JAVAH         Java header and stub generator command
  JAVA_HOME     path to the home of Java distribution
  JAVA_LIBS     flags needed for linking against Java libraries
  JAVA_CPPFLAGS C preprocessor flags needed for compiling JNI programs
  LAPACK_LIBS   flags needed for linking against external LAPACK libraries
  LIBnn         location for libraries, e.g. 'lib' or 'lib64' on this platform
  LDFLAGS       linker flags, e.g. -L<dir> if you have libraries in a
                nonstandard directory <dir>
  LTO LTO_FC LTO_LD  flags for Link-Time Optimization
  MAKE          Make command
  NM            comand to display symbol tables
  OBJC          Objective C compiler command
  OBJCFLAGS     Objective C compiler flags
  RANLIB        command to index static libraries
  SAFE_FFLAGS   Safe (as conformant as possible) Fortran compiler flags
  SHLIB_CFLAGS  additional CFLAGS used when building shared objects
  SHLIB_CXXFLAGS
                additional CXXFLAGS used when building shared objects
  SHLIB_CXXLD   command for linking shared objects which contain
                object files from a C++ compiler (and CXX11 CXX14 CXX17 CXX20 CXX23)
  SHLIB_CXXLDFLAGS
                special flags used by SHLIB_CXXLD (and CXX11 CXX14 CXX17 CXX20 CXX23)
  SHLIB_EXT     file extension (including '.') for shared objects
  SHLIB_FFLAGS  additional FFLAGS used when building shared objects
  SHLIB_LD      command for linking shared objects which contain
                object files from a C or Fortran compiler only
  SHLIB_LDFLAGS
                special flags used by SHLIB_LD
  TCLTK_CPPFLAGS
                flags needed for finding the tcl.h and tk.h headers
  TCLTK_LIBS    flags needed for linking against the Tcl and Tk libraries

Windows only:
  COMPILED_BY   name and version of compiler used to build R
  LOCAL_SOFT    absolute path to '/usr/local' software collection
  R_TOOLS_SOFT  absolute path to 'R tools' software collection
  OBJDUMP       command to dump objects

Report bugs at <https://bugs.R-project.org>.

CGMossa, May 05 '23 11:05

I've also found this entry in the R for macOS FAQ, which might be helpful:

https://cran.r-project.org/bin/macosx/RMacOSX-FAQ.html#Which-BLAS-is-used-and-how-can-it-be-changed_003f

And, of course, the linear algebra section of the R Installation and Administration manual:

https://cran.r-project.org/doc/manuals/r-release/R-admin.html#Linear-algebra

CGMossa, May 05 '23 12:05

Thanks. That part is straightforward, and we do it elsewhere (for example https://github.com/mrc-ide/eigen1/blob/master/src/util.c#L16-L17). The pain comes when users have not correctly installed the Fortran parts of the toolchain, and on Macs that changes every couple of years as Apple and R-core change how things get installed.

The real blocker is the odin syntax, which has been unresolved for about five years, so I doubt we will get to it soon!

richfitz, May 05 '23 12:05

Good. I won't comment on the syntax just yet, especially since I don't know anything about parsers. I guess the problem is that line order currently doesn't matter, but for a three-step definition it would need to? In any case, thanks for indulging this conversation.

For my own understanding: on Windows we have Rblas.dll, and I had hoped it would be possible to link against that directly without needing a Fortran compiler. Then again, Windows has Rtools, which most likely includes a Fortran compiler, so I don't really have experience with this. I would have guessed that -shared plus linking against Rblas.dll (or its equivalent on other platforms) would be enough.

CGMossa, May 05 '23 12:05

Windows tends to be fine because R-core controls the whole toolchain. On macOS, you get errors at link time if libgfortran is not found.

Line order won't matter for this either: the intention is to support y <- A %*% x and convert it to the appropriate BLAS call based on what we know about y, A and x. The issue is that (inevitably) people will want to apply these operations to higher-order arrays, looping over part of y at each step, so at the moment we're thinking about syntax like:

y[., ] <- A[j, ., .] %*% x[., j]
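The proposed syntax is not specified further in this thread, so the following is only a sketch of one plausible reading: for each j, y[, j] = A[j, , ] %*% x[, j], with A stored as a J × n × n array and x, y as n × J matrices, all column-major as in R. The names `matvec` and `sliced_matvec` are hypothetical. The sketch illustrates one wrinkle a code generator would face: element A[j, r, c] sits at offset j + J·r + J·n·c, so the slice A[j, , ] has a row stride of J and, for J > 1, is not the contiguous column-major matrix that BLAS expects. Here it is copied into a buffer first.

```c
#include <assert.h>
#include <stdlib.h>

/* out = M %*% v for an n x n column-major, contiguous matrix M */
static void matvec(const double *M, const double *v, double *out, int n) {
  for (int i = 0; i < n; i++)
    out[i] = 0.0;
  for (int c = 0; c < n; c++)
    for (int r = 0; r < n; r++)
      out[r] += M[r + (size_t)c * n] * v[c];
}

/* Hypothetical reading of  y[., ] <- A[j, ., .] %*% x[., j]:
   for each j, y[, j] = A[j, , ] %*% x[, j], where A is J x n x n and
   x, y are n x J, all column-major. */
static void sliced_matvec(const double *A, const double *x, double *y,
                          int J, int n) {
  assert(J > 0 && n > 0);
  double *slice = malloc((size_t)n * (size_t)n * sizeof *slice);
  assert(slice != NULL);
  for (int j = 0; j < J; j++) {
    /* A[j, r, c] lives at j + J*r + J*n*c: gather the strided slice
       into a contiguous n x n column-major buffer */
    for (int c = 0; c < n; c++)
      for (int r = 0; r < n; r++)
        slice[r + (size_t)c * n] = A[j + (size_t)J * r + (size_t)J * n * c];
    /* column j of x and of y is contiguous, starting at offset j * n */
    matvec(slice, &x[(size_t)j * n], &y[(size_t)j * n], n);
  }
  free(slice);
}
```

If the array were instead indexed with j last (A[., ., j]), each slice would already be a contiguous n × n matrix with leading dimension n and the copy could be skipped; which layout, if any, odin would adopt is not settled here.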

richfitz, May 05 '23 13:05