
Some thoughts about the Rust implementation

Open Pzixel opened this issue 4 years ago • 2 comments

I read your article and would like to suggest one small change. As you have seen, the allocator is a huge pain in non-GC languages, so it is no surprise that a GC language outperforms Rust here. I've seen this multiple times: people come from Haskell chats saying "hey, this is my small Haskell program which outperforms Rust, what the heck?", and you can be sure there are a lot of boxes, small heap allocations and other non-idiomatic stuff.

Rust is a systems language (unlike Go), so sometimes it doesn't perform well in "generic" cases, but it always gives you ways to use the information you have. For example, the Rust implementation you wrote probably outperforms Go on large strings, but it loses on strings 0..16 characters long.

Solution? Express this in types. For example, this small change:


use heapless::Vec; // fixed capacity `std::Vec`
use heapless::consts::U16; // type level integer used to specify capacity

pub fn levenshtein_distance(source: &str, target: &str) -> usize {
    // Count characters, not bytes, so non-ASCII input is measured and indexed correctly.
    let target_len = target.chars().count();

    if source.is_empty() {
        return target_len;
    }

    if target.is_empty() {
        return source.chars().count();
    }

    // cache[j] holds the distance to the first j characters of `target`.
    // Its target_len + 1 entries must fit into the 16 slots.
    let mut cache: Vec<_, U16> = (0..=target_len).collect();

    for (i, source_char) in source.chars().enumerate() {
        let mut next_dist = i + 1;

        for (j, target_char) in target.chars().enumerate() {
            let current_dist = next_dist;

            let mut dist_if_substitute = cache[j];
            if source_char != target_char {
                dist_if_substitute += 1;
            }

            let dist_if_insert = current_dist + 1;
            let dist_if_delete = cache[j + 1] + 1;

            next_dist = std::cmp::min(
                dist_if_substitute,
                std::cmp::min(dist_if_insert, dist_if_delete),
            );

            cache[j] = current_dist;
        }

        cache[target_len] = next_dist;
    }

    cache[target_len]
}

This drops execution time on my machine from 2.06 sec to 1.169 sec. Of course, it doesn't work once the target is longer than 15 characters (the 16-slot cache holds one entry more than the target has characters), but this is the price of a systems language. You can then write a function that picks one implementation or the other based on string length, etc. (the smallvec crate already does this).

That said, Rust will not outperform a GC language where you create lots of small objects on the heap; it's just impossible. Sometimes Rust can outperform Go's GC, but it won't be able to outperform a superior GC with an immense amount of effort spent on it, like the JVM one. It's exactly the scenario GC was made for, so it's no surprise that the naive code of someone's generic program can't outperform what the best engineers have been building for years.
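To make the "pick an implementation based on string length" idea concrete, here is a minimal sketch that uses only the standard library. It is not the code that was benchmarked: the names `levenshtein_dispatch`, `distance_with_cache` and `STACK_SLOTS` are hypothetical, and a plain stack array stands in for the heapless vector. Short targets get a stack-allocated cache, longer ones fall back to a heap-allocated `Vec`.

```rust
/// Number of cache slots kept on the stack (hypothetical threshold,
/// chosen to mirror the 16-slot cache discussed above).
const STACK_SLOTS: usize = 16;

/// Single-row Levenshtein distance over a caller-provided cache.
/// `cache` must have exactly `target.chars().count() + 1` slots.
fn distance_with_cache(source: &str, target: &str, cache: &mut [usize]) -> usize {
    let target_len = cache.len() - 1;

    // Row 0: distance from the empty string to each prefix of `target`.
    for (j, slot) in cache.iter_mut().enumerate() {
        *slot = j;
    }

    for (i, source_char) in source.chars().enumerate() {
        let mut next_dist = i + 1;
        for (j, target_char) in target.chars().enumerate() {
            let current_dist = next_dist;
            let substitute = cache[j] + usize::from(source_char != target_char);
            let insert = current_dist + 1;
            let delete = cache[j + 1] + 1;
            next_dist = substitute.min(insert).min(delete);
            cache[j] = current_dist;
        }
        cache[target_len] = next_dist;
    }

    cache[target_len]
}

/// Chooses the stack or the heap cache depending on the target length.
pub fn levenshtein_dispatch(source: &str, target: &str) -> usize {
    let slots = target.chars().count() + 1;
    if slots <= STACK_SLOTS {
        // Short target: no heap allocation at all.
        let mut stack_cache = [0usize; STACK_SLOTS];
        distance_with_cache(source, target, &mut stack_cache[..slots])
    } else {
        // Long target: fall back to an ordinary heap-allocated vector.
        let mut heap_cache = vec![0usize; slots];
        distance_with_cache(source, target, &mut heap_cache)
    }
}
```

Calling `levenshtein_dispatch("kitten", "sitting")` takes the stack path, while a 20-character target takes the heap path; both go through the same inner loop, so only the allocation strategy differs.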


P.S. Also, it's more idiomatic to use local functions and iterators, so this code:

let benchmark = || {
    for _ in 0..10000 {
        let pairs = lines.iter().zip(
            std::iter::once(&"").chain(lines.iter()));
        for (l1, l2) in pairs {
            levenshtein_distance(l1, l2);
        }
    }
};

This makes it clearer that we compare each string with the previous one (with "" as the previous one for the first string); it took me some time to figure out what happens with last_value in the original code.
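As a small illustration of what that `zip`/`chain` combination produces, the sketch below collects the pairs into a vector so they can be inspected. The helper name `with_previous` is hypothetical and exists only for this example; the benchmark itself iterates over the pairs directly.

```rust
/// Pairs every line with the line before it; the first line is paired with "".
fn with_previous<'a>(lines: &'a [&'a str]) -> Vec<(&'a str, &'a str)> {
    // "" followed by all lines: the "previous" sequence, shifted by one.
    let previous = std::iter::once("").chain(lines.iter().copied());
    // `zip` stops at the shorter side, so the extra trailing element is dropped.
    lines.iter().copied().zip(previous).collect()
}
```

For the input `["a", "b", "c"]` this yields `("a", "")`, `("b", "a")`, `("c", "b")`, which is the same sequence of comparisons the explicit `last_value` bookkeeping performs.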


P.P.S.

I hadn't seen #2, but I think everything I said still stands.


P.P.P.S.

Applying both changes, @BurntSushi's code becomes:

#[derive(Debug, Default)]
pub struct LevenshteinDistance<N1: ArrayLength<char>, N2: ArrayLength<usize>> {
    source: Vec<char, N1>,
    target: Vec<char, N1>,
    cache: Vec<usize, N2>,
}

This gives a small improvement and results in 1.09 sec on my machine.

Full code

use std::cmp::min;
use std::time::Instant;
use heapless::{Vec, ArrayLength}; // fixed capacity `std::Vec`
use heapless::consts::U16; // type level integer used to specify capacity

#[derive(Debug, Default)]
pub struct LevenshteinDistance<N1: ArrayLength<char>, N2: ArrayLength<usize>> {
    source: Vec<char, N1>,
    target: Vec<char, N1>,
    cache: Vec<usize, N2>,
}

impl<N1: ArrayLength<char>, N2: ArrayLength<usize>> LevenshteinDistance<N1, N2> {
    pub fn distance(&mut self, source: &str, target: &str) -> usize {
        // Return character counts, not byte lengths, to match the loop below.
        if source.is_empty() {
            return target.chars().count();
        }
        if target.is_empty() {
            return source.chars().count();
        }

        self.source.clear();
        self.source.extend(source.chars());
        self.target.clear();
        self.target.extend(target.chars());
        self.cache.clear();
        self.cache.extend(0..=self.target.len());

        for (i, source_char) in self.source.iter().enumerate() {
            let mut next_dist = i + 1;

            for (j, target_char) in self.target.iter().enumerate() {
                let current_dist = next_dist;

                let mut dist_if_substitute = self.cache[j];
                if source_char != target_char {
                    dist_if_substitute += 1;
                }

                let dist_if_insert = current_dist + 1;
                let dist_if_delete = self.cache[j + 1] + 1;

                next_dist = min(
                    dist_if_delete,
                    min(dist_if_insert, dist_if_substitute),
                );

                self.cache[j] = current_dist;
            }

            // Index by the number of characters (self.target), not the byte length of the &str.
            self.cache[self.target.len()] = next_dist;
        }

        self.cache[self.target.len()]
    }
}

type LevDistance = LevenshteinDistance<U16, U16>;

fn main() {
    let lines: std::vec::Vec<&str> = include_str!("../sample.txt").split('\n').collect();

    let benchmark = || {
        let mut levenshtein = LevDistance::default();
        for _ in 0..10000 {
            let pairs = lines.iter().zip(
                std::iter::once(&"").chain(lines.iter()));
            for (l1, l2) in pairs {
                levenshtein.distance(l1, l2);
            }
        }
    };

    let now = Instant::now();

    {
        benchmark();
    }

    let elapsed = now.elapsed();
    // Elapsed wall-clock time in seconds, including the fractional part.
    let sec = elapsed.as_secs_f64();
    print!("{}", sec);

    // check
    let answers: std::vec::Vec<String> = (0..lines.len() - 1)
        .map(|i| LevDistance::default().distance(lines[i], lines[i + 1]))
        .map(|dist| dist.to_string())
        .collect();
    eprintln!("{}", answers.join(","));
}

Pzixel avatar May 04 '20 09:05 Pzixel