
a more efficient (log)softmax


I've just recently read https://academic.oup.com/imajna/article/41/4/2311/5893596?login=false which shows that log-softmax implemented as

a <- max(x); log(exp(x - a)/sum(exp(x-a)))

is more accurate than this version (which, as far as I understand the hpp files, is the one that Stan uses):

x - matrixStats::logSumExp(x)

(Sorry for the R code, but the point is the same).

I just thought it was worth pointing this out. If there are good reasons for the way softmax is implemented, please just close and ignore; otherwise it might be useful...
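
For concreteness, here is a minimal plain-C++ sketch of the two formulations being compared (illustrative only, not the Stan implementation; the function names are made up). Both shift by max(x) to avoid overflow; the paper's point is about which accumulates less rounding error.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Version from Blanchard, Higham & Higham:
//   log(exp(x - a) / sum(exp(x - a))), with a = max(x).
std::vector<double> log_softmax_divide(const std::vector<double>& x) {
  const double a = *std::max_element(x.begin(), x.end());
  double denom = 0.0;
  for (double xi : x) denom += std::exp(xi - a);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    y[i] = std::log(std::exp(x[i] - a) / denom);
  return y;
}

// Log-sum-exp version: x - (a + log(sum(exp(x - a)))).
std::vector<double> log_softmax_lse(const std::vector<double>& x) {
  const double a = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double xi : x) sum += std::exp(xi - a);
  const double lse = a + std::log(sum);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] - lse;
  return y;
}
```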

bnicenboim avatar Aug 10 '22 08:08 bnicenboim

Just to confirm that Stan does use the latter. See

https://github.com/stan-dev/math/blob/92075708b1d1796eb82e3b284cd11e544433518e/stan/math/prim/fun/log_softmax.hpp#L49

and

https://github.com/stan-dev/math/blob/83b3731b934e4ff1905f33c3ab2ab85371397189/stan/math/rev/fun/log_softmax.hpp#L106

I will defer to other developers on commenting whether the proposed approach would be suitable for Stan.

rok-cesnovar avatar Aug 10 '22 08:08 rok-cesnovar

Thanks @bnicenboim. As a general rule, we should do whatever Higham recommends (his blog on numerical analysis is fantastic)! We're using his matrix exponential function algorithm, too.

Our log-sum-exp function uses the shift,

log_sum_exp(v) = max(v) + log(sum(exp(v - max(v))))

So the prim version of log_softmax isn't so bad. But the reverse mode goes off course and doesn't even use the log-softmax we implemented in prim. At the very least, line 106 in the reverse implementation should call the prim version. But rather than doing that, I'll just code the approach in the Blanchard et al. paper that Bruno cited.

bob-carpenter avatar Aug 10 '22 15:08 bob-carpenter

We're also not computing the derivatives optimally according to that paper. I'll also fix that.
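
For reference, the standard reverse-mode adjoint for y = log_softmax(x) is (this is the textbook formula, not necessarily the exact scheme that paper recommends):

adjoint(x) = adjoint(y) - softmax(x) * sum(adjoint(y))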

Another question: is it OK if I change the boundary condition? Right now, log-sum-exp applied to an empty container throws an exception. Instead, it should return -infinity, because the sum of an empty container is zero, and the log of zero is negative infinity.

bob-carpenter avatar Aug 10 '22 18:08 bob-carpenter

I would be OK with the change in boundary behavior as you outline. That makes sense and it should have been like that in the first place... which brings me to the question: what is the sum of an empty set in Stan?

wds15 avatar Aug 10 '22 18:08 wds15

Sums of empty containers evaluate to 0 and products of empty containers evaluate to 1. These are the usual boundary conditions for empty containers because they generalize properly by induction, like the accumulators in C++11.
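
A standalone C++ illustration of those conventions, and of why log_sum_exp of an empty container would then naturally be -infinity (not Stan code):

```cpp
#include <cmath>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<double> empty;
  // Sum over an empty range returns the initial value (0); product returns 1.
  double sum = std::accumulate(empty.begin(), empty.end(), 0.0);
  double prod = std::accumulate(empty.begin(), empty.end(), 1.0,
                                std::multiplies<double>());
  std::cout << sum << " " << prod << "\n";  // prints: 0 1
  // By the same logic, log_sum_exp of an empty container is log(0) = -inf.
  std::cout << std::log(sum) << "\n";       // prints: -inf
}
```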

bob-carpenter avatar Aug 10 '22 18:08 bob-carpenter

then I am inclined to call the current behavior a bug and the fix is what you propose.

wds15 avatar Aug 10 '22 18:08 wds15

I agree with @wds15.

rok-cesnovar avatar Aug 10 '22 18:08 rok-cesnovar

I just wrote to the authors of that paper to ask what they recommend for log_softmax. They recommend against

softmax(x) = exp(x - max(x) - log_sum_exp(x - max(x)))

but I couldn't find the recommendations for log softmax in their paper. We implement it in the obvious way as:

log_softmax(x) = x - max(x) - log_sum_exp(x - max(x))

@bnicenboim: did you find a mention of log-softmax in their paper? The question is whether we should just implement it as

log_softmax(x) = log(softmax(x))

bob-carpenter avatar Aug 11 '22 19:08 bob-carpenter

No, I didn't, but given that they recommend against:

softmax(x) = exp(x - max(x) - log_sum_exp(x - max(x)))

I assumed that this would also be a bad idea for the log softmax:

x - max(x) - log_sum_exp(x - max(x))

It can't be that the problem is removing the exp(), right? It must be division vs. difference....

But I would wait for the answer of the authors, they should have a more thoughtful answer :)

bnicenboim avatar Aug 11 '22 19:08 bnicenboim

I was going to try to fix this issue, but after spending about 8 hours trying to implement log_softmax(x) as log(softmax(x)) following the advice of Nick Higham et al., I'm giving up.

For reasons I don't understand, softmax is coded completely differently than log_softmax. Someone tried to make log_softmax work for arbitrary containers, but we only need an implementation for Eigen column vectors---that's all that's exposed in the language. I also don't understand why one throws an exception on size zero input (correct behavior) and one just returns the empty vector (wrong behavior because that's not a simplex). I also don't understand how they're supposed to work with mat_var or even if softmax works for mat_var.

I'm hoping this gets easier when @SteveBronder's doc lands, but I fear it's going to be super confusing given that there appear to be a bunch of different ways to code the callback functions.

If you want to start from where I left off, which includes a lot of cleanup on testing, it's on branch bugfix/2802-softmax-arith. Sadly, the last commit message is, "failed attempt to compile log_softmax". Everything but log_softmax seems to be working, but the whole point of doing this is to fix that function.

bob-carpenter avatar Aug 15 '22 19:08 bob-carpenter

@bob-carpenter I think I managed to fix the compile issue for you. The test under mix for log_softmax now compiles and runs for me ok.

wds15 avatar Aug 16 '22 07:08 wds15

Thanks, @wds15. I'm giving up on this issue. I can't keep up with the C++ in the math lib with the limited amount of time I have to code.

To summarize, the minimum fix is to redefine the double-based value of log_softmax in reverse mode to be implemented as log(softmax(x)) rather than with the unfolded arithmetic (see the sketch after the list below). Where I got stuck was on other nice-to-have features:

  1. softmax and log_softmax having identical signatures. They should. Just Eigen::VectorXd is fine.
  2. softmax and log_softmax throwing with size zero input. Only log_softmax does now.
  3. having log_softmax delegate to softmax rather than reimplement.
  4. stop binding the double-based value of a matrix and instead recompute it in the callback---it's a huge memory sink to save it and I think we should be conservative with memory. This is a "bug" throughout the new reverse mode code.
  5. both should work for var-mat, but I have no idea how to do that.
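
For what it's worth, here is a minimal Eigen-only sketch of that minimum fix, i.e. the double-based value computed as log(softmax(x)) with the shift-and-divide softmax from the paper. It is illustrative only: the function names are made up and all of the rev-mode callback machinery is left out.

```cpp
#include <Eigen/Dense>

// Illustrative sketch only, not the stan::math API: softmax via the
// shift-and-divide algorithm from Blanchard et al., then log_softmax
// defined as log(softmax(x)).
inline Eigen::VectorXd softmax_sketch(const Eigen::VectorXd& x) {
  const double a = x.maxCoeff();                       // shift to avoid overflow
  Eigen::VectorXd e = (x.array() - a).exp().matrix();  // exp(x - max(x))
  return e / e.sum();                                  // divide by the normalizer
}

inline Eigen::VectorXd log_softmax_sketch(const Eigen::VectorXd& x) {
  return softmax_sketch(x).array().log().matrix();     // log(softmax(x))
}
```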

bob-carpenter avatar Aug 16 '22 16:08 bob-carpenter