
Support for "sumExp(x) +=! exp( ... )"

Open · jeanfeydy opened this issue on Mar 19, 2018 · 12 comments

Hi everyone,

First of all, let me thank you for putting so much effort into this accessible, versatile library.

As a mathematician working at the intersection between shape analysis and optimal transport theory, I am interested in computing log-likelihoods of Gaussian models, i.e. expressions of the form

log(μ ★ k)(x) = log( ∫ k(x-x') dμ(x') )

where μ is a measure (say, a sum of atomic Dirac masses) and k is a Gaussian kernel whose density with respect to the Lebesgue measure is proportional to

k(x) = exp( -(x/σ)^2).

These expressions can be used to fit Gaussian Mixture Models or to compute Wasserstein distances between samples - see http://optimaltransport.github.io and log-domain implementations of the Sinkhorn algorithm.

Unfortunately, one cannot simply compute the convolution μ★k between the source measure μ and a Gaussian kernel k and then apply a pointwise logarithm: since k decays so quickly towards zero, μ★k is nearly zero away from the support of μ, and a careless log results in numerical instabilities.

Hence, one should always write

μ = Σ_i μ(x'_i) δ_{x'_i},    log(μ ★ k)(x) = log[ Σ_i exp{ log(μ(x'_i)) + log(k(x-x'_i)) } ]

and use the log-sum-exp trick to compute numerically accurate results.
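For concreteness, the trick in plain PyTorch looks roughly like this (the helper name is mine, just for illustration):

import torch

# Stabilized log-sum-exp: shift by the max before exponentiating, then add it
# back outside the log, so that no term under- or overflows.
def logsumexp(v, dim):
    m, _ = v.max(dim, keepdim=True)
    return m.squeeze(dim) + (v - m).exp().sum(dim).log()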

If μ is given as a list of weights+points, one can compute log(μ★k)(x) with a vanilla PyTorch implementation, or an efficient PyTorch+CUDA library that I am about to release with two colleagues from the shape analysis community.
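Schematically, the point-cloud case boils down to something like the following (a rough, untested sketch, not the library mentioned above; log_mu are the log-weights (M,), y the support points (M,D), x the query points (N,D)):

import torch

# log(mu * k)(x_i) = logsumexp_j [ log mu_j - |x_i - y_j|^2 / sigma^2 ]
def log_conv_points(log_mu, y, x, sigma):
    D2 = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(2)   # (N, M) squared distances
    V  = log_mu.view(1, -1) - D2 / sigma ** 2              # log mu_j + log k(x_i - y_j)
    m, _ = V.max(1, keepdim=True)                          # stabilized log-sum-exp over j
    return m.view(-1) + (V - m).exp().sum(1).log()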

On the other hand, if μ is represented through its density sampled on a grid, i.e. as a bitmap A=(μ(x)), the clever solution would be to implement separable log-convolutions on tensors. That is, define an operation similar to a "Softmax layer" that takes as input:

  • a bitmap log_A(x) (or log_A(x,y), log_A(x,y,z) in 2D/3D),
  • a vector log_K(x) encoding the kernel function

and outputs, in a numerically stable way, the bitmap tensor log_B(x) such that

log_B(x) = log(A ★ K)(x) = log[ Σ_i exp{ log_A(i) + log_K(x-i) } ]
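For reference, here is a dense (memory-hungry, checking-only) PyTorch sketch of what this should compute in the 1D case, with the same index convention as the TC code below (log_K has length 2*X-1 and log_K[X-1] is the kernel value at offset 0; CPU tensors assumed for simplicity):

import torch

# Checking only: this materializes the full (X, X) matrix of log-terms.
def log_conv_1d_reference(log_A, log_K):
    X  = log_A.size(0)
    x  = torch.arange(0, X).long().view(-1, 1)   # output positions
    xp = torch.arange(0, X).long().view(1, -1)   # input positions
    V  = log_A.view(1, -1) + log_K[(X - 1 + x - xp).view(-1)].view(X, X)
    m, _ = V.max(1, keepdim=True)                # stabilized log-sum-exp over xp
    return m.view(-1) + (V - m).exp().sum(1).log()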

We were planning to implement those operations by hand in CUDA... But Tensor Comprehensions could allow us to share these ideas in a standard, maintainable way, which would be much better!

We are thus quite excited by your release, bar a little hitch: I could not implement "log-convolutions" efficiently using TC :-(

import torch
from torch.autograd import Variable
import tensor_comprehensions as tc

use_cuda = torch.cuda.is_available()
dtype    = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Straightforward convolution with a filter of size Xp = 2*X-1
convolution_1D_lang = """
def convolution_1D(float(X) A, float(Xp) K) -> (B) {
    B(x)   +=! A(xp) * K(X-1+x-xp) where xp in 0:X
}
"""

# Safe, step-by-step computation
log_convolution_1D_lang = """
def log_convolution_1D(float(X) log_A, float(Xp) log_K) -> (maxExp,myExp,sumExp,LSE,log_B) {
    maxExp(x) max=!      log_A(xp) + log_K(X-1+x-xp)
    myExp(x,xp)  =  exp( log_A(xp) + log_K(X-1+x-xp) - maxExp(x) )
    sumExp(x)   +=! myExp(x,xp)
    LSE(x)       = log(sumExp(x))
    log_B(x)     = maxExp(x) + LSE(x)
}
"""

# Remove the "useless" variable myExp to get a *linear memory footprint*.
# I just collapse the lines "myExp(x,xp) = ...", "sumExp +=! ..."
# into a single statement.
#
# This simplification is critical as in practice
# (when we implement "log_convolution_3D", etc.),
# `log_A` is a (256,256,256) 3D tensor and
# `log_K` is a (511) vector (if we compute separable log-convolutions).
#
# Creating a full (256,256,256,511) array "myExp" is thus intractable...
log_convolution_1Db_lang = """
def log_convolution_1Db(float(X) log_A, float(Xp) log_K) -> (maxExp,sumExp,LSE,log_B) {
    maxExp(x) max=!      log_A(xp) + log_K(X-1+x-xp)
    sumExp(x)   +=! exp( log_A(xp) + log_K(X-1+x-xp) - maxExp(x) )
    LSE(x)       = log(sumExp(x))
    log_B(x)     = maxExp(x) + LSE(x)
}
"""

# Let's process our TC strings
convolution_1D      = tc.define(convolution_1D_lang,      name="convolution_1D")
log_convolution_1D  = tc.define(log_convolution_1D_lang,  name="log_convolution_1D")
log_convolution_1Db = tc.define(log_convolution_1Db_lang, name="log_convolution_1Db")

# And test our routines on some data
# For the sake of simplicity, we use a dirac signal
#          A = [0,0,0,1,0,0,0,0,0,0]
X = 10
log_A = -1000 * torch.ones(X).type(dtype)
log_A[3] = 0
log_A = Variable(log_A)

# We'll convolve A with a Gaussian kernel K of std=2
sigma = Variable(torch.Tensor([2]).type(dtype))
C = Variable( torch.arange(-X+1,X).type(dtype))
log_K = -(C/sigma)**2

B      = convolution_1D( log_A.exp(), log_K.exp() )
log_B  = log_convolution_1D(  log_A, log_K )[-1]
log_Bb = log_convolution_1Db( log_A, log_K )[-1]

print("A :\n",         log_A.exp().view(1,-1), "\n")
print("K :\n",         log_K.exp().view(1,-1), "\n")
print("B = A ★ K :\n",           B.view(1,-1), "\n")
print("log(B) :\n",          log_B.view(1,-1), "\n")
print("log(B) - buggy :\n", log_Bb.view(1,-1), "\n")

Output:

/home/jean/anaconda3/envs/py36/lib/python3.6/site-packages/torch/cuda/__init__.py:97: UserWarning:
    Found GPU0 GeForce GTX 960M which is of cuda capability 5.0.
    PyTorch no longer supports this GPU because it is too old.

  warnings.warn(old_gpu_warn % (d, name, major, capability[1]))
[WARNING]: No mapping options passed, 'naive' type mapping options will be used and will likely have bad performance. See help(your_layer.__call__) for setting mapping options.
[WARNING]: No mapping options passed, 'naive' type mapping options will be used and will likely have bad performance. See help(your_layer.__call__) for setting mapping options.
[WARNING]: No mapping options passed, 'naive' type mapping options will be used and will likely have bad performance. See help(your_layer.__call__) for setting mapping options.
A :
 Variable containing:
    0     0     0     1     0     0     0     0     0     0
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]


K :
 Variable containing:

Columns 0 to 5
 1.6052e-09  1.1254e-07  4.7851e-06  1.2341e-04  1.9305e-03  1.8316e-02

Columns 6 to 11
 1.0540e-01  3.6788e-01  7.7880e-01  1.0000e+00  7.7880e-01  3.6788e-01

Columns 12 to 17
 1.0540e-01  1.8316e-02  1.9305e-03  1.2341e-04  4.7851e-06  1.1254e-07

Columns 18 to 18
 1.6052e-09
[torch.cuda.FloatTensor of size 1x19 (GPU 0)]


B = A ★ K :
 Variable containing:
 0.1054  0.3679  0.7788  1.0000  0.7788  0.3679  0.1054  0.0183  0.0019  0.0001
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]


log(B) :
 Variable containing:
-2.2500 -1.0000 -0.2500  0.0000 -0.2500 -1.0000 -2.2500 -4.0000 -6.2500 -9.0000
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]


log(B) - buggy :
 Variable containing:
  nan   nan   nan   nan   nan   nan   nan   nan   nan   nan
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]

Unfortunately, the "simplification" did not go as expected: it introduced gibberish values (in this case, NaNs) in log_Bb (the correct values are those computed in log_B).

I see three possibilities for this (unexpected) behavior:

  • There's a bug coming from my "source" PyTorch+TC installation. On my laptop, I'm using a GPU with CC=5.0, which is no longer officially supported by PyTorch, so there could be something fishy in the ATen dependencies, etc. However, all the tests in TC/test_python/test_layers seem to work fine, so this would be a little bit surprising.

  • There's a bug in the TC compiler.

  • Statements such as A(i) += exp( ... ) are not supported (yet?) by TC.

What do you think about this? If we could make log_convolution_1Db work correctly, this would be of utmost interest to all the researchers working on optimal transport and shape analysis in 2D/3D. (In practice, A is typically a 3D MRI/CT scan and K is a Gaussian kernel.)
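For context, here is roughly how we intend to chain the 1D routine in 3D once it works (an illustrative, untested sketch: log_conv_1d_batched is a hypothetical routine mapping a (B,X) log-bitmap and a (2X-1) log-kernel to a (B,X) result, and the volume is assumed to be cubic):

# Separable 3D log-convolution: log-sum-exp nests, so one pass per axis suffices.
def log_conv_3d_separable(log_A, log_K, log_conv_1d_batched):
    out = log_A                                       # shape (N, N, N)
    for _ in range(3):
        N1, N2, N3 = out.size()
        out = log_conv_1d_batched(out.contiguous().view(-1, N3), log_K)
        out = out.view(N1, N2, N3).permute(2, 0, 1)   # cycle axes so the next pass convolves a fresh axis
    return out                                        # after three passes, axes are back in (x, y, z) order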

As requested, here's some technical info:

  • OS: Ubuntu 17.10, freshly installed.
  • How you installed TC (docker, conda, source): from source.
  • Python version: 3.6.4
  • CUDA/cuDNN version: 9.0
  • Conda version (if using conda): 4.4.11
  • GCC/GXX version (if compiling from source): 4.8
  • LLVM/Tapir git hash used (if compiling from source): The suggested command $HOME/clang+llvm-tapir5.0/bin/clang --version didn't output anything. However, simply typing clang --version gives me
clang version 5.0.0 (https://github.com/wsmoses/Tapir-Clang 2637f015d66418964aa0225534c004dd71a174b8) (/opt/conda/conda-bld/git_cache/github.com/wsmoses/Tapir-LLVM ec3ad2b8d3810dde9c0aaccf3f3f971144d90bc2)
Target: x86_64-unknown-linux-gnu
Thread model: posix
InstalledDir: /home/jean/anaconda3/bin
  • Commit hash of the TC repo and submodules (if compiling from source): 044c263449ddbe2037e68acbc90253a0cc9f5cb0, i.e. yesterday's master (Merge pull request #166 from facebookresearch/fix-flags). Should I have chosen the v0.1.1 commit instead?

OK, this "bug report" got longer than I expected it to be... But since you seem to relish user feedback, I hope that's fine with you :-)

Once again, thanks a lot for your hard work on behalf of the research community.

Best regards,

Jean

jeanfeydy avatar Mar 19 '18 16:03 jeanfeydy

@jeanfeydy thanks for your very detailed issue and very interesting use case. We actually had a quick chat with @apaszke about the log-sum-exp trick just yesterday, in the context of logsoftmax, so I'm flagging him.

I need a bit more time to digest all this, but I'll try to make a C++ unit test out of it, because there is no fundamental reason TC shouldn't support it.

nicolasvasilache avatar Mar 19 '18 17:03 nicolasvasilache

Hi @jeanfeydy, thanks for your kind words. We do support +=exp(...) reductions. Check out one example here: https://github.com/facebookresearch/TensorComprehensions/blob/master/test_python/layers/test_softmax.py#L32

Having gibberish values is definitely not expected. May I recommend that you try the conda installation of TC and then inspect the values? ATM, we don't provide official instructions for installing TC with PyTorch built from source; there is a task open for it: https://github.com/facebookresearch/TensorComprehensions/issues/130

If you can let us know whether the conda package works as expected, this rules out a bug in the TC compiler and I can prioritize https://github.com/facebookresearch/TensorComprehensions/issues/130

Thanks so much.

prigoyal avatar Mar 19 '18 17:03 prigoyal

Hi @jeanfeydy, thanks for your report, we will take a look. Some quick answers

Statements such as A(i) += exp( ... ) are not supported (yet?) by TC.

I don't see a reason (other than a bug) why this shouldn't work.

Should I have chosen the v0.1.1 commit instead?

No, there was a fix for max=! (which you need) after that release.

@prigoyal I'm not sure the latest conda binary includes the fix for max=!; if it does not, let's not recommend replacing TC built from source with it here.

ftynse avatar Mar 19 '18 17:03 ftynse

I assigned this to myself to figure out if the issue lies with pytorch source install with TC.

prigoyal avatar Mar 19 '18 17:03 prigoyal

Hi @ftynse, the binary does include the fix for max=! :) My recommendation to try the conda package is just to check whether the gibberish values persist :) Thanks.

prigoyal avatar Mar 19 '18 17:03 prigoyal

Wow, thanks a lot for your quick answers! Unfortunately, it looks like there's something strange happening in the TC compiler: I could reproduce the bug on a completely different machine, using the standard conda packages for PyTorch and TC.

More precisely: on a remote machine with Ubuntu 16.04, NVIDIA driver 387.26, a GeForce GTX 780 Ti GPU and CUDA 9.1, I typed

conda create --name py36 python=3.6
source activate py36
conda install -y -c pytorch -c tensorcomp tensor_comprehensions
python log_convolution_1D.py  # the test script, written in the Opening Post

The result was basically the same, with gibberish values instead of NaNs:

B = A ★ K :
 Variable containing:
 0.1054  0.3679  0.7788  1.0000  0.7788  0.3679  0.1054  0.0183  0.0019  0.0001
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]
 

log(B) :
 Variable containing:
-2.2500 -1.0000 -0.2500  0.0000 -0.2500 -1.0000 -2.2500 -4.0000 -6.2500 -9.0000
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]
 

log(B) - buggy :
 Variable containing:
1.00000e-08 *
  0.0000  0.0000 -2.9802  0.0000 -2.9802  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]

To answer @prigoyal: as suggested, I started my script by copy-pasting the softmax layer. For me, the "dream feature" would be to be able to conveniently replace blocks like

def softmax(float(N, D) I) -> (maxVal, expDistance, expSum) {
  maxVal(n) max=! I(n, d)
  expDistance(n, d) = exp(I(n, d) - maxVal(n))
  expSum(n) +=! expDistance(n, d)
}

with

def softmax(float(N, D) I) -> (maxVal, expSum) {
  maxVal(n) max=! I(n, d)
  expSum(n) +=! exp(I(n, d) - maxVal(n))
}

The above substitution works, which is great! ... But the more sophisticated "log-convolution" does not.
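For reference, the kind of sanity check I have in mind for the collapsed variant looks roughly like this (an untested sketch, not the exact code I ran):

import torch
import tensor_comprehensions as tc
from torch.autograd import Variable

# Compare the collapsed softmax reduction against a step-by-step
# log-sum-exp computed in plain PyTorch.
softmax_lang = """
def softmax(float(N, D) I) -> (maxVal, expSum) {
  maxVal(n) max=! I(n, d)
  expSum(n) +=! exp(I(n, d) - maxVal(n))
}
"""
softmax = tc.define(softmax_lang, name="softmax")

I = Variable(torch.randn(4, 16).cuda())
maxVal, expSum = softmax(I)
ref = (I - I.max(1, keepdim=True)[0]).exp().sum(1)  # same quantity, step by step
print((expSum - ref).abs().max())                   # should be ~0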

Playing around with TC, I was able to find another bug (gibberish convolution outputs when the filter size exceeds ~32), for which I will submit a separate report. I'm wondering whether these bugs could be related, coming from a "memory overflow" in the reduction operation that one rarely encounters in typical CNNs (since people only use very small filter sizes, say 3x3 in VGG).

jeanfeydy avatar Mar 20 '18 14:03 jeanfeydy

Could you please test with filter sizes 31, 32 and 33? This sounds like there may be a problem with warp-based reductions.

ftynse avatar Mar 20 '18 14:03 ftynse

Hi @ftynse, I submitted a bug report for the standard convolution. As for the log-convolution (i.e. this bug report), no choice of mapping options solves the problem: the output of log_convolution_1Db is gibberish, whatever the input size and mapping.

jeanfeydy avatar Mar 20 '18 15:03 jeanfeydy

Hi @jeanfeydy, it looks like issue #179 captures the drilled-down problem, and hence this current issue can be closed?

prigoyal avatar Mar 21 '18 12:03 prigoyal

Hi @prigoyal, I'm not sure about this: bug #179 is related to large convolution filters / loops, whereas this one happens whenever I combine a non-trivial reduction with non-trivial range inference. That is, writing

sumExp(x)   +=! exp( log_A(xp) + log_K(X-1+x-xp) - maxExp(x) )

outputs gibberish results whatever I do: choosing small vectors or smart mapping options does not solve anything.

jeanfeydy avatar Mar 21 '18 12:03 jeanfeydy

Thanks @jeanfeydy. Can you provide us with a gist link for a repro?

prigoyal avatar Mar 21 '18 12:03 prigoyal

Sure, here it is: https://gist.github.com/jeanfeydy/155e0f2bd5bb9b1dd8a2a4f6e7109530 (just a copy-paste of the OP; I didn't know that we could share scripts so easily...)

jeanfeydy avatar Mar 21 '18 13:03 jeanfeydy