spatialsample icon indicating copy to clipboard operation
spatialsample copied to clipboard

Optimize distance calculations: consider replacing sf::st_distance

Open atsyplenkov opened this issue 1 year ago • 2 comments

Hi there!

Just a short note on the distance calculations. In my experience, they are one of the major bottlenecks in the package. For example, I can't apply spatialsample to my datasets of 40k points at all: a dense 40k × 40k distance matrix of doubles alone needs about 12.8 GB of memory.

https://github.com/tidymodels/spatialsample/blob/ded16914d80c5bb118d8a61ffd93e07be5aad93e/R/buffer.R#L27

What if we replaced sf::st_distance() with a faster function? For example, there is some evidence that Rfast::Dist() runs two to three times faster than sf::st_distance(). Perhaps it would be worth adding one more package to the dependency list for the sake of a speed boost?

Unfortunately, both algorithms appear to have O(n²) time complexity, which is not good, and Rfast is not a silver bullet. Additionally, for longlat coordinates, sf::st_distance() may still be preferable because it computes great-circle distances.

I can prepare a PR if you like the approach.

See some benchmarks below:

library(sf)
#> Linking to GEOS 3.12.1, GDAL 3.8.4, PROJ 9.3.1; sf_use_s2() is TRUE
suppressPackageStartupMessages(library(Rfast))
library(ggplot2)
library(dplyr)
library(tidyr)
library(bench)

# Function to create points: draw `n` sample points inside a fixed
# rectangle in the projected CRS EPSG:2193, returned as an sfc geometry.
create_points <- function(n) {
  corners <- c(
    xmin = 1400000, xmax = 2100000,
    ymin = 5400000, ymax = 6200000
  )
  study_area <- sf::st_as_sfc(sf::st_bbox(corners))
  sf::st_crs(study_area) <- 2193
  sf::st_sample(study_area, n)
}

# Run benchmarks for different n: for each sample size, draw a reproducible
# set of points (seeded by n) and time sf::st_distance() against
# Rfast::Dist() on the same coordinates. Each bench::mark() table is tagged
# with its n and stored in `results` under the name of that n.
ns <- seq(1000, 10000, by = 1000)

run_benchmark <- function(n) {
  set.seed(n)
  pts <- create_points(n)

  timings <- bench::mark(
    sf = sf::st_distance(pts, which = "Euclidean"),
    Rfast = pts |>
      sf::st_coordinates() |>
      Rfast::Dist(method = "euclidean"),
    time_unit = "ms",
    iterations = 5,
    check = FALSE
  )
  timings$n <- n
  timings
}

results <- lapply(ns, run_benchmark)
names(results) <- as.character(ns)

# Combine the per-n benchmark tables into a single table
benchmark_df <- do.call(rbind, results)

# Long format for faceting: one row per (n, method, metric). The metric
# factor is ordered so the time panel comes before the memory panel.
metric_levels <- c("time", "mem_alloc")
metric_labels <- c("Time (milliseconds)", "Memory (bytes)")

plot_df <- dplyr::transmute(
  benchmark_df,
  n,
  method = as.character(expression),
  time = as.numeric(median),
  mem_alloc = as.numeric(mem_alloc)
)
plot_df <- tidyr::pivot_longer(
  plot_df,
  cols = c(time, mem_alloc),
  names_to = "metric",
  values_to = "value"
)
plot_df$metric <- factor(
  plot_df$metric,
  levels = metric_levels,
  labels = metric_labels
)


# Plot the results: each metric against the number of points, one facet per
# metric with its own y scale, plus a smoothed trend line per method.
benchmark_plot <- ggplot2::ggplot(
  data = plot_df,
  mapping = ggplot2::aes(x = n, y = value, color = method)
) +
  ggplot2::geom_smooth(se = FALSE) +
  ggplot2::geom_point() +
  ggplot2::scale_x_continuous(breaks = ns) +
  ggplot2::scale_y_continuous(
    breaks = scales::pretty_breaks(n = 5),
    # abbreviate large values (K, M, ...) to keep the axis labels short
    labels = scales::label_number(scale_cut = scales::cut_short_scale())
  ) +
  ggplot2::facet_wrap(~metric, scales = "free_y", nrow = 2) +
  ggplot2::labs(
    title = "sf vs Rfast Distance Calculations",
    x = "Number of Points",
    y = "",
    color = "Method"
  ) +
  ggplot2::theme_minimal() +
  ggplot2::theme(
    legend.position = "bottom",
    panel.grid.minor = ggplot2::element_blank(),
    strip.text = ggplot2::element_text(face = "bold")
  )

benchmark_plot
#> `geom_smooth()` using method = 'loess' and formula = 'y ~ x'



# Compare results: check that both approaches agree on 1,000 points
set.seed(123)
pts <- create_points(1000)

# Euclidean distance, computed both ways
sf_example <- sf::st_distance(pts, which = "Euclidean")
Rfast_example <- Rfast::Dist(sf::st_coordinates(pts), method = "euclidean")

# as.double() drops attributes (units, dim, dimnames), so only the
# distance values themselves are compared
waldo::compare(as.double(sf_example), as.double(Rfast_example))
#> ✔ No differences

# Session Info: print the R version, platform and package versions of the
# session that produced the results above
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 Patched (2024-08-05 r86984 ucrt)
#>  os       Windows 10 x64 (build 19045)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_United States.utf8
#>  ctype    English_United States.utf8
#>  tz       Pacific/Auckland
#>  date     2024-11-13
#>  pandoc   3.2 @ c:\\scoop\\apps\\positron\\2024.11.0-140\\resources\\app\\quarto\\bin\\tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  bench        * 1.1.3      2023-05-04 [1] RSPM
#>  cachem         1.0.8      2023-05-01 [1] CRAN (R 4.4.1)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.4.1)
#>  classInt       0.4-10     2023-09-05 [1] CRAN (R 4.4.1)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.4.1)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.4.1)
#>  DBI            1.2.2      2024-02-16 [1] CRAN (R 4.4.1)
#>  devtools       2.4.5      2022-10-11 [1] RSPM (R 4.4.0)
#>  digest         0.6.35     2024-03-11 [1] CRAN (R 4.4.1)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.4.1)
#>  e1071          1.7-14     2023-12-06 [1] RSPM
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.4.1)
#>  evaluate       0.23       2023-11-01 [1] CRAN (R 4.4.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.4.1)
#>  farver         2.1.1      2022-07-06 [1] CRAN (R 4.4.1)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.4.1)
#>  fs             1.6.4      2024-04-25 [1] CRAN (R 4.4.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.4.1)
#>  ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.4.1)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.4.1)
#>  gtable         0.3.5      2024-04-22 [1] CRAN (R 4.4.1)
#>  htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.4.1)
#>  htmlwidgets    1.6.4      2023-12-06 [1] CRAN (R 4.4.1)
#>  httpuv         1.6.15     2024-03-26 [1] CRAN (R 4.4.1)
#>  KernSmooth     2.23-22    2023-07-10 [1] CRAN (R 4.4.1)
#>  knitr          1.46       2024-04-06 [1] CRAN (R 4.4.1)
#>  later          1.3.2      2023-12-06 [1] CRAN (R 4.4.1)
#>  lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.4.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.4.1)
#>  Matrix         1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>  memoise        2.0.1      2021-11-26 [1] CRAN (R 4.4.1)
#>  mgcv           1.9-1      2023-12-21 [2] CRAN (R 4.4.1)
#>  mime           0.12       2021-09-28 [1] CRAN (R 4.4.0)
#>  miniUI         0.1.1.1    2018-05-18 [1] CRAN (R 4.4.1)
#>  munsell        0.5.1      2024-04-01 [1] CRAN (R 4.4.1)
#>  nlme           3.1-164    2023-11-27 [1] CRAN (R 4.4.1)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.4.1)
#>  pkgbuild       1.4.4      2024-03-17 [1] CRAN (R 4.4.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.4.1)
#>  pkgload        1.3.4      2024-01-16 [1] RSPM (R 4.4.0)
#>  profmem        0.6.0      2020-12-13 [1] RSPM
#>  profvis        0.3.8      2023-05-02 [1] CRAN (R 4.4.1)
#>  promises       1.3.0      2024-04-05 [1] CRAN (R 4.4.1)
#>  proxy          0.4-27     2022-06-09 [1] RSPM
#>  purrr          1.0.2      2023-08-10 [1] CRAN (R 4.4.1)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.4.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.4.0)
#>  R.oo           1.26.0     2024-01-24 [1] CRAN (R 4.4.0)
#>  R.utils        2.12.3     2023-11-18 [1] CRAN (R 4.4.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.4.1)
#>  Rcpp         * 1.0.12     2024-01-09 [1] CRAN (R 4.4.1)
#>  RcppParallel * 5.1.7      2023-02-27 [1] CRAN (R 4.4.1)
#>  RcppZiggurat * 0.1.6      2020-10-20 [1] RSPM
#>  remotes        2.5.0.9000 2024-10-01 [1] Github (r-lib/remotes@5b7eb08)
#>  reprex         2.1.0      2024-01-11 [1] CRAN (R 4.4.1)
#>  Rfast        * 2.1.0      2023-11-09 [1] RSPM
#>  rlang          1.1.4      2024-06-04 [1] CRAN (R 4.4.1)
#>  rmarkdown      2.28       2024-08-17 [1] RSPM
#>  scales         1.3.0      2023-11-28 [1] CRAN (R 4.4.1)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.4.1)
#>  sf           * 1.0-19     2024-11-05 [1] RSPM
#>  shiny          1.8.1.1    2024-04-02 [1] RSPM (R 4.4.0)
#>  stringi        1.8.3      2023-12-11 [1] CRAN (R 4.4.1)
#>  stringr        1.5.1      2023-11-14 [1] CRAN (R 4.4.1)
#>  styler         1.10.3     2024-04-07 [1] CRAN (R 4.4.1)
#>  tibble         3.2.1      2023-03-20 [1] CRAN (R 4.4.1)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.4.1)
#>  tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.4.1)
#>  units          0.8-5      2023-11-28 [1] CRAN (R 4.4.1)
#>  urlchecker     1.0.1.9000 2024-09-04 [1] https://r-lib.r-universe.dev (R 4.4.1)
#>  usethis        3.0.0      2024-07-29 [1] CRAN (R 4.4.1)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.4.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.4.1)
#>  waldo          0.5.2      2023-11-02 [1] CRAN (R 4.4.1)
#>  withr          3.0.0      2024-01-16 [1] CRAN (R 4.4.1)
#>  xfun           0.43       2024-03-25 [1] CRAN (R 4.4.1)
#>  xtable         1.8-4      2019-04-21 [1] CRAN (R 4.4.1)
#>  yaml           2.3.8      2023-12-11 [1] CRAN (R 4.4.1)
#> 
#>  [1] C:/Users/TsyplenkovA/AppData/Local/R/win-library/4.4
#>  [2] C:/Program Files/R/R-4.4.1patched/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Created on 2024-11-13 with reprex v2.1.0

atsyplenkov avatar Nov 12 '24 21:11 atsyplenkov

Thanks for opening the issue! I agree that function can be quite the pain. I'm not super willing to adopt Rfast as a dependency, because that winds up causing all sorts of knock-on effects with how we handle units in arguments throughout the package. I'm going to leave this open, though, because it'd be worth exploring C++-based alternatives to this distance function.

Depending on what specific function you're looking to use, you might be interested in rsample::clustering_cv(). This function can take any distance function as an argument, so is a lot more flexible than the spatialsample variants:

nc <- sf::read_sf(system.file("shape/nc.shp", package = "sf"))

# Precomputed pairwise distance matrix between the features of `nc` (or use
# Rfast here, or other functions). Keep it n x n: as.numeric() would flatten
# it into a length-n^2 vector, so only the units class is dropped.
nc_dist <- units::drop_units(sf::st_distance(nc))

rsample::clustering_cv(
  sf::st_drop_geometry(nc),
  AREA, # this is a dummy value to force the function to run
  distance_function = \(x) x,
  # Extra arguments to clustering_cv() (here `x`) are passed on to
  # cluster_function, so it must name `x` to use the precomputed matrix; a
  # function that only accepts `...` would silently cluster on AREA instead.
  cluster_function = \(dists, v, x, ...) {
    stats::kmeans(x, centers = v)$cluster
  },
  x = nc_dist
)
#> # 10-cluster cross-validation 
#> # A tibble: 10 × 2
#>    splits          id    
#>    <list>          <chr> 
#>  1 <split [88/12]> Fold01
#>  2 <split [91/9]>  Fold02
#>  3 <split [96/4]>  Fold03
#>  4 <split [92/8]>  Fold04
#>  5 <split [88/12]> Fold05
#>  6 <split [89/11]> Fold06
#>  7 <split [84/16]> Fold07
#>  8 <split [84/16]> Fold08
#>  9 <split [92/8]>  Fold09
#> 10 <split [96/4]>  Fold10

Created on 2024-11-12 with reprex v2.1.1

mikemahoney218 avatar Nov 12 '24 22:11 mikemahoney218

Totally understand your take on Rfast. Perhaps we could add something like this? I adapted the code from a Stack Overflow discussion about distance calculations that I had several months ago. Back then, it was the fastest way to compute distances.

It only covers Euclidean distance, but it produces results identical to sf::st_distance(). I still need to figure out what to do with longlat coordinates, though.

P.S. Thanks for the tip on how to use a custom distance function! I appreciate it.

  library(sf)
#> Linking to GEOS 3.12.1, GDAL 3.8.4, PROJ 9.3.1; sf_use_s2() is TRUE
  library(Rcpp)
  library(units)
#> udunits database from C:/Users/TsyplenkovA/AppData/Local/R/win-library/4.4/units/share/udunits/udunits2.xml

  # Compile distance_matrix_cpp(): given a numeric matrix whose first two
  # columns are x and y, return the n x n matrix of Euclidean distances
  # between its rows. Only pairs with j > i are computed; each value is
  # written to both (i, j) and (j, i), and the diagonal is set to 0.
  # The R string is single-quoted so the C++ code can use "..." literals.
  cppFunction('NumericMatrix distance_matrix_cpp(NumericMatrix points) {
  int n = points.nrow();

  // Columns 0 and 1 are always read below; fail with a clear error
  // instead of indexing past the last column of a narrower matrix.
  if (n > 0 && points.ncol() < 2) {
    Rcpp::stop("points must have at least two columns (x and y)");
  }

  NumericMatrix distances(n, n);

  for(int i = 0; i < n; i++) {
    distances(i,i) = 0;

    for(int j = i+1; j < n; j++) {
      double dx = points(i,0) - points(j,0);
      double dy = points(i,1) - points(j,1);
      double dist = sqrt(dx*dx + dy*dy);

      distances(i,j) = dist;
      distances(j,i) = dist;
    }
  }

  return distances;
}')

  # Wrapper function to add units.
  #
  # Euclidean distance matrix for `pts` (an sf/sfc object), shaped like
  # sf::st_distance(pts, which = "Euclidean"): an n x n matrix with dimnames
  # "1".."n" and, when the CRS reports a length unit, that unit attached via
  # the units package.
  #
  # The C++ fast path only makes sense for projected POINT geometries. For a
  # geographic (longlat) CRS, a planar distance on degrees would be wrong.
  # For non-POINT geometries, st_coordinates() returns one row per vertex
  # rather than one per feature. Those inputs, and empty input, are handed to
  # sf::st_distance() instead.
  cpp_distance_units <- function(pts) {
    geom_types <- sf::st_geometry_type(pts)
    use_fast_path <- length(geom_types) > 0L &&
      all(geom_types == "POINT") &&
      !isTRUE(sf::st_is_longlat(pts))
    if (!use_fast_path) {
      return(sf::st_distance(pts))
    }

    crs <- sf::st_crs(pts)
    coords <- sf::st_coordinates(pts)
    dist_matrix <- distance_matrix_cpp(coords)
    ids <- as.character(seq_len(nrow(coords)))
    dimnames(dist_matrix) <- list(ids, ids)

    # crs$units may be NULL (e.g. no CRS set); leave the matrix unitless then
    # rather than trying to assign a NULL unit.
    if (!is.null(crs$units)) {
      units(dist_matrix) <- crs$units
    }
    dist_matrix
  }

  # Benchmark on 5,000 points (create_points() is the function from the
  # previous message). relative = TRUE reports every column as a ratio to
  # the best expression; check = FALSE skips comparing the three results.
  pts <- create_points(5000)

  bench::mark(
    sf = sf::st_distance(pts, which = "Euclidean"),
    cpp = cpp_distance_units(pts),
    rfast = pts |>
      sf::st_coordinates() |>
      Rfast::Dist(method = "euclidean"),
    iterations = 10,
    relative = TRUE,
    check = FALSE
  )
#> # A tibble: 3 × 6
#>   expression   min median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr> <dbl>  <dbl>     <dbl>     <dbl>    <dbl>
#> 1 sf          2.51   2.61      1         2.00     2.34
#> 2 cpp         1.08   1.11      2.35      1        1   
#> 3 rfast       1      1         2.61      1.01     1.67

  # Confirm the Rcpp wrapper reproduces sf's full result object (values and
  # attributes such as units and dimnames), not just the numbers
  sf_dist <- sf::st_distance(pts, which = "Euclidean")
  cpp_dist <- cpp_distance_units(pts)
  waldo::compare(sf_dist, cpp_dist)
#> ✔ No differences

Created on 2024-11-13 with reprex v2.1.0

atsyplenkov avatar Nov 12 '24 23:11 atsyplenkov