LogarithmicNumbers.jl

Sums with multiple terms

Open gdalle opened this issue 2 years ago • 7 comments

Would it be possible to take advantage of sums with many terms, so that x + y + z is computed as a single logsumexp([x, y, z]) rather than as nested calls logaddexp(x, logaddexp(y, z))?
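To make the difference concrete, here is a minimal sketch on plain Float64 log-values. The helpers `logaddexp` and `logsumexp` below are written for illustration only and are not the LogarithmicNumbers.jl API; the package's own `+` may handle special cases differently.

```julia
# Hypothetical helpers on plain log-values (not the LogarithmicNumbers.jl API).

# log(exp(a) + exp(b)), computed stably; one exp and one log1p per call.
function logaddexp(a::Float64, b::Float64)
    hi, lo = max(a, b), min(a, b)
    isfinite(hi) || return hi          # covers -Inf, +Inf and NaN
    return hi + log1p(exp(lo - hi))
end

# Two-pass log-sum-exp: one maximum, then one exp per term and a single log.
function logsumexp(xs::AbstractVector{Float64})
    m = maximum(xs)
    isfinite(m) || return m            # all -Inf, or some +Inf / NaN
    return m + log(sum(x -> exp(x - m), xs))
end

x, y, z = log(1.0), log(2.0), log(3.0)
nested = logaddexp(x, logaddexp(y, z))   # pairwise reduction, as with chained `+`
fused  = logsumexp([x, y, z])            # single reduction, as proposed here
```

Both results should agree with log(6) up to rounding. The pairwise version pays one logarithm per addition, whereas the fused version pays a single logarithm for the whole sum.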

gdalle avatar May 24 '23 08:05 gdalle

I just tried a small benchmark and it does look a bit faster:

julia> using LogarithmicNumbers, BenchmarkTools

julia> function mysum(a::AbstractVector{ULogarithmic{T}}) where {T}
           m = typemin(T)  # running maximum of the logs seen so far
           se = zero(T)    # running sum of exp(x.log - m)
           for x in a
               if x.log < m
                   se += exp(x.log - m)
               elseif x.log == m
                   se += one(T)
               else
                   # new maximum: rescale the accumulated sum, then add the new term
                   se = muladd(se, exp(m - x.log), one(T))
                   m = x.log
               end
           end
           lse = m + log(se)
           return exp(ULogarithmic, lse)
       end;

julia> a = exp.(ULogarithmic, rand(1_000));

julia> @btime sum($a)
  29.931 μs (0 allocations: 0 bytes)
exp(7.446988417750581)

julia> @btime mysum($a)
  4.841 μs (0 allocations: 0 bytes)
exp(7.446988417750575)

My implementation of logsumexp is from https://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html. I think we had better not use LogExpFunctions.jl because

  • it has several unwanted dependencies
  • it wouldn't easily access the x.log field
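For reference, the streaming update that mysum performs can be written out as follows. The notation is introduced here for illustration: x_i stands for the x.log of the i-th term, m_k for the running maximum, and s_k for the running rescaled sum, starting from m_0 = -∞ and s_0 = 0.

```latex
\begin{aligned}
\text{invariant after } k \text{ terms:}\quad
  & m_k = \max_{i \le k} x_i, \qquad s_k = \sum_{i \le k} e^{x_i - m_k} \\
x_{k+1} \le m_k:\quad
  & m_{k+1} = m_k, \qquad s_{k+1} = s_k + e^{x_{k+1} - m_k} \\
x_{k+1} > m_k:\quad
  & m_{k+1} = x_{k+1}, \qquad s_{k+1} = s_k \, e^{m_k - x_{k+1}} + 1 \\
\text{result:}\quad
  & \log \sum_{i \le n} e^{x_i} = m_n + \log s_n
\end{aligned}
```

Every exponent is non-positive, so no exponential can overflow, and only one logarithm is taken at the end. The separate x.log == m branch in the code adds 1 directly, which avoids evaluating exp(-Inf - (-Inf)) when both values are -Inf.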

gdalle avatar May 24 '23 19:05 gdalle

Cool! How does the benchmark compare when the array has two elements? I'm just wondering if your version is essentially a faster +. I haven't looked at my + for a while but AFAIR it has a lot of branching to deal with special cases (NaN, Inf, 0).

cjdoris avatar May 25 '23 15:05 cjdoris

For two elements my version is slightly worse, and also probably less numerically stable.

gdalle avatar May 26 '23 05:05 gdalle

julia> using UnicodePlots

julia> n_vals = vcat(2, 10 .^ (1:7));

julia> current_implem = zeros(length(n_vals));

julia> my_implem = zeros(length(n_vals));

julia> for (k, n) in enumerate(n_vals)
           a = exp.(ULogarithmic, rand(n))
           current_implem[k] = @belapsed sum($a)
           my_implem[k] = @belapsed mysum($a)
       end

julia> lineplot(
           log10.(n_vals),
           current_implem ./ my_implem; 
           xlabel="log array size", ylabel="perf gain"
       )
               ┌────────────────────────────────────────┐ 
             7 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⡀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠒⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠉⠒⠒⠢│ 
               │⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⠀⠀⠀⢀⠜⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⠀⢀⠇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
   perf gain   │⠀⠀⠀⠀⡸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⢀⠇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⠀⡸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⢀⠇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠀⡸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⢠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               │⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
             0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
               └────────────────────────────────────────┘ 
               ⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀7⠀ 
               ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀log array size⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ 

gdalle avatar May 26 '23 05:05 gdalle

Bumping this, any interest in a PR?

gdalle avatar Sep 06 '23 08:09 gdalle

Sure, what would the PR be? A new sum function?

cjdoris avatar Sep 13 '23 17:09 cjdoris

Yeah probably, but I don't know how to make it general enough to accept arbitrary iterators yet still find a workable eltype.
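One possible way around the eltype question, sketched here on plain log-values rather than on ULogarithmic, is to seed the accumulators from the first element using Julia's `iterate` protocol instead of asking the iterator for its `eltype`. The function name `streaming_logsumexp` is hypothetical, and the sketch assumes every element is a real number of a single floating-point-convertible type; it is not a proposal for the package's actual API.

```julia
# Hypothetical helper: streaming log-sum-exp over any iterable of log-values.
# The accumulator type is taken from the first element, so generators and
# other iterators with an unknown eltype are accepted.
function streaming_logsumexp(itr)
    next = iterate(itr)
    # Without an element there is no type to build a "zero" result from.
    next === nothing && throw(ArgumentError("empty iterator: element type unknown"))
    x, state = next
    m = float(x)   # running maximum of the log-values seen so far
    s = one(m)     # running sum of exp(x_i - m)
    next = iterate(itr, state)
    while next !== nothing
        x, state = next
        if x < m
            s += exp(x - m)
        elseif x == m
            s += one(s)                        # also covers -Inf == -Inf
        else
            s = muladd(s, exp(m - x), one(s))  # rescale to the new maximum
            m = float(x)
        end
        next = iterate(itr, state)
    end
    return m + log(s)
end

total = streaming_logsumexp(log(i) for i in 1:4)   # generator with no concrete eltype
```

The cost of this design is the empty case: with no first element there is nothing to infer a type from, so the sketch throws. A real implementation would probably want to fall back to `eltype` when it is concrete, or accept an explicit initial value.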

gdalle avatar Sep 14 '23 08:09 gdalle