libm icon indicating copy to clipboard operation
libm copied to clipboard

Separate implementation of hex float parsing for performance

Open CrazyboyQCD opened this issue 9 months ago • 2 comments

Currently, hex float parsing uses u128 as storage for all float types, but f16, f32, and f64 all fit in u64. Using u64 as storage gives a 25%–35% performance increase in a simple test: https://play.rust-lang.org/?version=nightly&mode=release&edition=2021&gist=cd4e078bb43278130a09c522d8b24820

#![feature(f16, f128)]

use hex_float::*;
mod hex_float {
    use super::T;

    /// Construct a 16-bit float from hex float representation (C-style).
    ///
    /// # Panics
    ///
    /// Panics if `s` is malformed or not exactly representable as an `f16`.
    pub fn hf16(s: &str) -> f16 {
        // `parse_any` only sets the low 16 bits here, so the narrowing cast is lossless.
        f16::from_bits(parse_any(s, 16, 10) as u16)
    }

    /// Construct a 32-bit float from hex float representation (C-style).
    ///
    /// # Panics
    ///
    /// Panics if `s` is malformed or not exactly representable as an `f32`.
    pub fn hf32(s: &str) -> f32 {
        // `parse_any` only sets the low 32 bits here, so the narrowing cast is lossless.
        f32::from_bits(parse_any(s, 32, 23) as u32)
    }

    /// Construct a 64-bit float from hex float representation (C-style).
    ///
    /// # Panics
    ///
    /// Panics if `s` is malformed or not exactly representable as an `f64`.
    pub fn hf64(s: &str) -> f64 {
        f64::from_bits(parse_any(s, 64, 52) as u64)
    }

    /// Parse any float from hex to its bitwise representation.
    ///
    /// `bits` is the total width of the target format and `sig_bits` is the number of
    /// explicitly stored significand bits (e.g. `64, 52` for `f64`). The result occupies
    /// the low `bits` bits of the returned [`T`].
    ///
    /// Infinities get an all-ones exponent field. NaN gets an all-ones exponent field with
    /// only the top stored significand bit set. Both carry the parsed sign.
    ///
    /// # Panics
    ///
    /// Panics if `bits` exceeds the width of [`T`], if `s` is malformed, or if the value is
    /// not exactly representable (too precise, too huge, or too tiny).
    pub fn parse_any(s: &str, bits: u32, sig_bits: u32) -> T {
        // Every shift below (notably `<< (bits - 1)` for the sign) must stay within `T`.
        assert!(bits <= T::BITS, "float format is wider than the storage type");

        let exp_bits: u32 = bits - sig_bits - 1;
        let max_msb: i32 = (1 << (exp_bits - 1)) - 1;
        // The exponent of one ULP in the subnormals
        let min_lsb: i32 = 1 - max_msb - sig_bits as i32;

        let exp_mask = ((1 << exp_bits) - 1) << sig_bits;

        let (neg, mut sig, exp) = match parse_hex(s.as_bytes()) {
            Parsed::Finite { neg, sig: 0, .. } => return (neg as T) << (bits - 1),
            Parsed::Finite { neg, sig, exp } => (neg, sig, exp),
            Parsed::Infinite { neg } => return ((neg as T) << (bits - 1)) | exp_mask,
            Parsed::Nan { neg } => {
                return ((neg as T) << (bits - 1)) | exp_mask | (1 << (sig_bits - 1));
            }
        };

        // exponents of the least and most significant bits in the value
        // (`sig` is non-zero here: the zero case returned above)
        let lsb = sig.trailing_zeros() as i32;
        let msb = sig.ilog2() as i32;
        let sig_bits = sig_bits as i32;

        assert!(msb - lsb <= sig_bits, "the value is too precise");
        assert!(msb + exp <= max_msb, "the value is too huge");
        assert!(lsb + exp >= min_lsb, "the value is too tiny");

        // The parsed value is X = sig * 2^exp
        // Expressed as a multiple U of the smallest subnormal value:
        // X = U * 2^min_lsb, so U = sig * 2^(exp-min_lsb)
        let mut uexp = exp - min_lsb;

        let shift = if uexp + msb >= sig_bits {
            // normal, shift msb to position sig_bits
            sig_bits - msb
        } else {
            // subnormal, shift so that uexp becomes 0
            uexp
        };

        // The precision/tiny asserts above guarantee a right shift drops only zero bits.
        if shift >= 0 {
            sig <<= shift;
        } else {
            sig >>= -shift;
        }
        uexp -= shift;

        // the most significant bit is like having 1 in the exponent bits
        // add any leftover exponent to that
        assert!(uexp >= 0 && uexp < (1 << exp_bits) - 2);
        sig += (uexp as T) << sig_bits;

        // finally, set the sign bit if necessary
        sig | ((neg as T) << (bits - 1))
    }

    /// A parsed floating point number.
    enum Parsed {
        /// Absolute value `sig * 2^exp`
        Finite {
            neg: bool,
            sig: T,
            exp: i32,
        },
        Infinite {
            neg: bool,
        },
        Nan {
            neg: bool,
        },
    }

    /// Parse a hexadecimal float of the form `[+-]0x<hex digits>[.<hex digits>]p[+-]<dec digits>`,
    /// or `[+-]inf` / `[+-]nan` (case-insensitive).
    ///
    /// # Panics
    ///
    /// Panics on malformed input, on more significant hex digits than fit in [`T`], or on an
    /// exponent that overflows `i32`.
    const fn parse_hex(mut b: &[u8]) -> Parsed {
        let mut neg = false;
        let mut sig: T = 0;
        let mut exp: i32 = 0;

        if let &[c @ (b'-' | b'+'), ref rest @ ..] = b {
            b = rest;
            neg = c == b'-';
        }

        match *b {
            [b'i' | b'I', b'n' | b'N', b'f' | b'F'] => return Parsed::Infinite { neg },
            [b'n' | b'N', b'a' | b'A', b'n' | b'N'] => return Parsed::Nan { neg },
            _ => (),
        }

        if let &[b'0', b'x' | b'X', ref rest @ ..] = b {
            b = rest;
        } else {
            panic!("no hex indicator");
        }

        let mut seen_point = false;
        let mut some_digits = false;
        // Zero digits not yet multiplied into `sig`. Deferring them means trailing zeros
        // (e.g. `0x1.0000000000000000p0`) cannot overflow the significand; any left over
        // at the end are folded into the exponent instead.
        let mut pending_zeros: i32 = 0;

        while let &[c, ref rest @ ..] = b {
            b = rest;

            match c {
                b'.' => {
                    assert!(!seen_point);
                    seen_point = true;
                    continue;
                }
                b'p' | b'P' => break,
                c => {
                    let digit = hex_digit(c);
                    some_digits = true;
                    // up until the fractional point, the value grows
                    // with more digits, but after it the exponent is
                    // compensated to match.
                    if seen_point {
                        exp -= 4;
                    }
                    if digit == 0 {
                        pending_zeros += 1;
                    } else {
                        // A non-zero digit follows, so the deferred zeros are significant.
                        while pending_zeros > 0 {
                            sig = push_hex_digit(sig, 0);
                            pending_zeros -= 1;
                        }
                        sig = push_hex_digit(sig, digit);
                    }
                }
            }
        }
        assert!(some_digits, "at least one digit is required");
        // Each deferred zero digit would have scaled `sig` by 16 = 2^4.
        exp += 4 * pending_zeros;
        some_digits = false;

        let mut negate_exp = false;
        if let &[c @ (b'-' | b'+'), ref rest @ ..] = b {
            b = rest;
            negate_exp = c == b'-';
        }

        let mut pexp: i32 = 0;
        while let &[c, ref rest @ ..] = b {
            b = rest;
            let digit = dec_digit(c);
            some_digits = true;
            // Check both the multiply and the add: either can overflow `i32`.
            let Some(next) = pexp.checked_mul(10) else {
                panic!("too many exponent digits")
            };
            let Some(next) = next.checked_add(digit as i32) else {
                panic!("too many exponent digits")
            };
            pexp = next;
        }
        assert!(some_digits, "at least one exponent digit is required");

        let combined = if negate_exp {
            exp.checked_sub(pexp)
        } else {
            exp.checked_add(pexp)
        };
        let Some(exp) = combined else {
            panic!("exponent out of range")
        };

        Parsed::Finite { neg, sig, exp }
    }

    /// Append one hex digit to `sig` (i.e. `sig * 16 + digit`), panicking on overflow.
    const fn push_hex_digit(sig: T, digit: u8) -> T {
        match sig.checked_mul(16) {
            Some(shifted) => shifted | digit as T,
            None => panic!("too many digits"),
        }
    }

    /// Value of an ASCII decimal digit; panics on any other byte.
    const fn dec_digit(c: u8) -> u8 {
        match c {
            b'0'..=b'9' => c - b'0',
            _ => panic!("bad char"),
        }
    }

    /// Value of an ASCII hex digit (either case); panics on any other byte.
    const fn hex_digit(c: u8) -> u8 {
        match c {
            b'0'..=b'9' => c - b'0',
            b'a'..=b'f' => c - b'a' + 10,
            b'A'..=b'F' => c - b'A' + 10,
            _ => panic!("bad char"),
        }
    }
}
/// Integer storage for float bit patterns during hex float parsing.
///
/// `u64` holds every bit pattern of `f16`, `f32`, and `f64`, and avoids the cost of
/// `u128` arithmetic. Supporting `f128` would require widening this to `u128`.
type T = u64;

// `hf64` builds a full 64-bit pattern in `T`, so it must be at least 64 bits wide.
const _: () = assert!(T::BITS >= 64);

/// Micro-benchmark: parse one hex float per format `n` times and print the elapsed time.
///
/// Up to three command-line arguments override the default `f16`, `f32`, and `f64` inputs,
/// in that order. Each input goes through `std::hint::black_box` so the optimizer cannot
/// treat it as a constant and hoist or fold the parsing out of the timed loop.
fn main() {
    use std::hint::black_box;
    use std::time::Instant;

    let mut args = std::env::args().skip(1);
    let a = args.next().unwrap_or_else(|| "0x1.92p+1".to_owned());
    let b = args.next().unwrap_or_else(|| "0x1.921fbp+1".to_owned());
    let c = args
        .next()
        .unwrap_or_else(|| "0x1.921fb54442d18p+1".to_owned());
    let n = 100_000;

    let t = Instant::now();
    let mut v = 0.0;
    for _ in 0..n {
        v += hf16(black_box(a.as_str()));
    }
    // Printing the accumulated value keeps the loop from being eliminated as dead code.
    println!("16: {} {:?}", v > 10.0, t.elapsed());

    let t = Instant::now();
    let mut v = 0.0;
    for _ in 0..n {
        v += hf32(black_box(b.as_str()));
    }
    println!("32: {} {:?}", v > 10.0, t.elapsed());

    let t = Instant::now();
    let mut v = 0.0;
    for _ in 0..n {
        v += hf64(black_box(c.as_str()));
    }
    println!("64: {} {:?}", v > 10.0, t.elapsed());
}

CrazyboyQCD avatar Feb 20 '25 07:02 CrazyboyQCD