dcgan.torch

linear interpolation?

dribnet opened this issue 8 years ago · 13 comments

Hi - I've been doing a lot of work lately with interpolation in latent space, and I think linear interpolation might not be the best interpolation operator for high-dimensional spaces. Though it is admittedly common practice, this seemed as good a place as any to discuss it, since the dcgan code does exactly that here:

noiseL = torch.FloatTensor(opt.nz):uniform(-1, 1)
noiseR = torch.FloatTensor(opt.nz):uniform(-1, 1)
if opt.noisemode == 'line' then
   -- do a linear interpolation in Z space between point A and point B
   -- each sample in the mini-batch is a point on the line
    line  = torch.linspace(0, 1, opt.batchSize)
    for i = 1, opt.batchSize do
        noise:select(1, i):copy(noiseL * line[i] + noiseR * (1 - line[i]))
    end

I'm starting with the assumption that torch.FloatTensor(opt.nz):uniform(-1, 1) is a valid way to uniformly sample from the prior in the latent space. In the examples below, I'll leave the nz dimension at the default of 100. Let's do an experiment and see what the expected lengths of these vectors are.

[histogram of latent vector lengths]

I see a gaussian with a mean of about 5.76 and a standard deviation of about 0.25. I believe this means that >99% of vectors would be expected to have a length between 4.8 and 6.8 (4 standard deviations out). This result should not be a big surprise if we think about taking 100 independent random numbers and running them through the distance formula.
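
As a quick back-of-envelope check: each coordinate is uniform on (-1, 1), so E[z_i^2] = 1/3 and the expected squared length of a 100-d vector is 100/3, i.e. a typical length of sqrt(100/3) ≈ 5.77, which matches the histogram.

import numpy as np

# sanity check on the mean length: E[z_i^2] = 1/3 for Uniform(-1, 1),
# so a 100-d sample has expected squared length 100/3
n = 100
print(np.sqrt(n / 3.0))  # ~5.77, close to the empirical mean of ~5.76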

But now let's think about the effects of linear interpolation between these random vectors. At an extreme, we have the linearly interpolated midpoints halfway between any two of these vectors - let's see what the expected lengths of these are.

[histogram of interpolated midpoint lengths]

So now we have a gaussian with a mean length of 4.06 and a standard deviation of 0.24. Needless to say, these are not the same distribution; in fact they are effectively disjoint - the probability of a sample from the second distribution appearing in the first is vanishingly small. In other words, the points on the linearly interpolated path are many standard deviations away from the points expected under the prior distribution.
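
The same back-of-envelope arithmetic predicts the drop: the cross terms between two independent samples vanish in expectation, so E[||(z1 + z2)/2||^2] = (E[||z1||^2] + E[||z2||^2]) / 4 = 100/6, i.e. a typical midpoint length of sqrt(100/6) ≈ 4.08 - the endpoint length divided by sqrt(2). A quick check:

import numpy as np

# midpoint of two independent Uniform(-1, 1) samples: expected squared length is n/6
n = 100
print(np.sqrt(n / 3.0))  # ~5.77, typical endpoint length
print(np.sqrt(n / 6.0))  # ~4.08, typical midpoint length, shorter by a factor of sqrt(2)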

If my premise is correct that torch.FloatTensor(opt.nz):uniform(-1, 1) performs a uniform sampling across the latent space (a big if, and I'd like to verify this!), then most of the prior's mass lies in a thin shell, so the prior is shaped more like a hypersphere than a solid cube. In that case, spherical interpolation makes a lot more sense, and in my own experiments I've had good qualitative results with this approach. Curious what others think. Also note that this reasoning extends beyond interpolation, since it affects other interpretable operations as well - such as finding the average of a subset of labeled data (e.g. the average man or woman in a face dataset).

dribnet avatar Mar 07 '16 07:03 dribnet

I think it depends on the shape of your latent space. If your latent space is spherical already (i.e. you learnt to sample from a gaussian while training, rather than from uniform), linear interpolation seems okay. If you sample Z from a cube (uniform), a spherical interpolation seems like a much better idea.

https://github.com/soumith/dcgan.torch/blob/master/main.lua#L23

soumith avatar Mar 07 '16 15:03 soumith

Agreed it depends on the shape of the latent space. But at nz=100, switching from a prior of noise:uniform(-1, 1) to a prior of noise:normal(0, 1) yields the same result: points along the linear interpolation between two randomly selected points fall well outside the expected distribution.


To clarify my main point, calling the following code:

local noise1 = torch.Tensor(opt.batchSize, 100, 1, 1)
noise1:normal(0, 1)
local noise2 = torch.Tensor(opt.batchSize, 100, 1, 1)
noise2:normal(0, 1)

will always result in 100-dim noise vectors with length about 10. If you choose to linearly interpolate between two of them, you will invariably get a "tentpole" effect in which the length sags from 10 at the endpoints to about 7 (10/sqrt(2) ≈ 7.07) at the midpoint, which is over 4 standard deviations away from the expected length. Shouldn't the interpolated points instead come from the same distribution as the original samples?


Happy to uncover my own conceptual flaw in how latent spaces are sampled, which is certainly possible. My ipython notebook code to replicate this is below.

from matplotlib import pylab as plt
%matplotlib inline
import numpy as np
# random_points = np.random.uniform(low=-1, high=1, size=(1000, 100))
random_points = np.random.normal(loc=0, scale=1, size=(1000, 100))
lengths = [np.linalg.norm(p) for p in random_points]  # length of each 100-d sample
print("Mean length is {:3.2f} and std is {:3.2f}".format(np.mean(lengths), np.std(lengths)))
n, bins, patches = plt.hist(lengths, 50, density=True, facecolor='green', alpha=0.75)
plt.show()

[histogram of the sampled vector lengths]

# take the midpoint of the ix-th point and the next point in the array
def midpoint_length(points, ix):
    num_points = len(points)
    next_ix = (ix + 1) % num_points
    avg = (points[ix] + points[next_ix]) / 2.0
    return np.linalg.norm(avg)

mid_lengths = []
for i in range(len(random_points)):
    mid_lengths.append(midpoint_length(random_points, i))
print("Mean length is {:3.2f} and std is {:3.2f}".format(np.mean(mid_lengths), np.std(mid_lengths)))
n, bins, patches = plt.hist(mid_lengths, 50, density=True, facecolor='green', alpha=0.75)
plt.show()

[histogram of the midpoint lengths]

dribnet avatar Mar 09 '16 09:03 dribnet

To visually demonstrate the relevance to this codebase, I constructed 5 (uniform) random interpolations from the pre-trained bedrooms_4_net_G.t7 model:

[image: bedroom_pairs1 - five interpolations, each shown as a linear/spherical pair]

Each interpolation is presented as a pair: the first row is linear interpolation and the second is spherical interpolation. To my eye, the first row often suffers from blurring in the center or other visual washout, while the second row stays crisper and more visually consistent with the style of the endpoints. This is a pattern I've seen in other latent spaces as well.

We can also visualize the tentpole effect by graphing the lengths of all of the vectors across the interpolation. Here are the five linear interpolations:

[graph: tentpole1 - vector lengths along the five linear interpolations]

The lengths at each end are about 5.75, but in the center they sag down to just above 4. This is exactly what the distributions in the original comment predict. Arguably, this sag lines up with the visual artifacts in the rendered images above.

We can compare that to the lengths when using spherical interpolation, which of course won't sag:

[graph: tentpole2 - vector lengths along the five spherical interpolations]

Getting the shape of the latent space right is important, which is why I've spent time making this case here. If my argument is right, then this has implications for how to most accurately compute interpolations, extrapolations, flythroughs, averages, etc. in latent space. Alternatively, a different prior could perhaps be used so that these operations could remain linear.

dribnet avatar Mar 11 '16 09:03 dribnet

After reading this through, I am convinced that spherical interpolation is essential as well. The weak generations in the center are something I've noticed too. Thanks for pointing this out. I think the overall fix is to always keep the norm of the intermediate vectors constant.
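
One simple way to do that would be to lerp as usual and then rescale each intermediate vector so its norm interpolates between the endpoint norms - a minimal sketch of the idea (the norm_preserving_lerp name is just for illustration, not code from this repo), with the slerp discussed below being the better-motivated alternative:

import numpy as np

def norm_preserving_lerp(val, low, high):
    # linear interpolation followed by rescaling, so the norm of the
    # intermediate vectors does not sag toward the origin
    v = (1.0 - val) * low + val * high
    target_norm = (1.0 - val) * np.linalg.norm(low) + val * np.linalg.norm(high)
    return v * (target_norm / np.linalg.norm(v))  # degenerate if low is close to -high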

soumith avatar Mar 11 '16 16:03 soumith

I just spoke to Arthur Szlam, who smacked me on the head, because he's been telling me this for months and I conveniently ignored it.

He says the latent space should also be sampled from points on an n-dimensional hypersphere, and when doing interpolations, you just take the path on the great circle.
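
A minimal sketch of that suggestion (the sample_on_sphere helper is just for illustration, not code from this repo): normalize a Gaussian draw to get a uniform point on the unit n-sphere, optionally scaled to a fixed radius such as sqrt(nz) so it matches the shell a Gaussian prior concentrates on, and then interpolate along the great circle with the slerp below.

import numpy as np

def sample_on_sphere(nz, radius=1.0):
    # normalizing a standard Gaussian draw gives a uniform point on the (nz-1)-sphere
    z = np.random.normal(size=nz)
    return radius * z / np.linalg.norm(z)

# two latent points on a sphere of radius sqrt(100) ~ 10, to be joined by a great-circle path
z0 = sample_on_sphere(100, radius=np.sqrt(100))
z1 = sample_on_sphere(100, radius=np.sqrt(100))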

soumith avatar Mar 11 '16 19:03 soumith

Thanks for thinking this through with me. Having just revisited Domingos' classic A Few Useful Things to Know about Machine Learning, this quote now means more to me:

INTUITION FAILS IN HIGH DIMENSIONS After overfitting, the biggest problem in machine learning is the curse of dimensionality. [...] Our intuitions, which come from a three dimensional world, often do not apply in high-dimensional ones. In high dimensions, most of the mass of a multivariate Gaussian distribution is not near the mean, but in an increasingly distant “shell” around it; and most of the volume of a high-dimensional orange is in the skin, not the pulp.

So are you suggesting constraining all latent vectors to lie exactly on the unit n-sphere? That would be an interesting simplification if it worked. I instead wrote my own interpolator which does a great circle path with elevation changes. Feel free to adapt it to this codebase if you'd like.

def slerp(val, low, high):
    omega = np.arccos(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)))
    so = np.sin(omega)
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega)/so * high
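
For example (a quick check of my own, assuming numpy and the slerp above), comparing vector norms along the two paths shows the lerp sag and the nearly flat slerp profile:

import numpy as np

z0 = np.random.normal(size=100)
z1 = np.random.normal(size=100)
for t in np.linspace(0, 1, 11):
    lerp_norm = np.linalg.norm((1.0 - t) * z0 + t * z1)
    slerp_norm = np.linalg.norm(slerp(t, z0, z1))
    print("t={:.1f}  lerp |z|={:5.2f}  slerp |z|={:5.2f}".format(t, lerp_norm, slerp_norm))
# lerp norms sag toward ~7 in the middle; slerp norms stay close to the ~10 of the endpoints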

dribnet avatar Mar 21 '16 08:03 dribnet

Thanks for the slerp implementation, I started using it and thought I'd share some fixed edge cases. Not sure what the convention should be for the degenerate opposite vectors case, currently it just lerps them.

def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        return (1.0-val) * low + val * high # L'Hopital's rule/LERP
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high

print(slerp(0, np.array([1,0,0]), np.array([1,0,0])))
print(slerp(0.5, np.array([1,0,0]), np.array([1,0,0])))
print(slerp(0, np.array([1,0,0]), np.array([0.5,0,0])))
print(slerp(0.5, np.array([1,0,0]), np.array([0.5,0,0])))
print(slerp(0, np.array([1,0,0]), np.array([-1,0,0])))
print(slerp(0.5, np.array([1,0,0]), np.array([-1,0,0])))
# [ 1.  0.  0.]
# [ 1.  0.  0.]
# [ 1.  0.  0.]
# [ 0.75  0.    0.  ]
# [ 1.  0.  0.]
# [ 0.  0.  0.]

pqn avatar Mar 22 '16 21:03 pqn

@pqn As far as I understand, you're using slerp only for interpolating between two vectors (low and high). How do you sample z for the generator during GAN training?
Is it

z = np.random.normal(loc=0, scale=1, size=(1000,100))

as in the comment above by dribnet?

mgarbade avatar Dec 05 '17 10:12 mgarbade

@mgarbade I actually just stole the slerp for feature vector interpolation when playing with this paper (which is not a GAN) in MXNet. The equivalent input z was always just a one-hot vector for a specific 3D model and not randomly sampled.

pqn avatar Dec 05 '17 19:12 pqn

Hey, this is a super useful thread, thanks to all!

In case it's useful for anyone (this thread ranks high on SEO :), the length of a multivariate standard gaussian sample concentrates on a hypersphere with radius sqrt(dim-1) (which you can approximate as sqrt(dim) for large dim) and variance of about 0.5. (Technically the length follows a chi distribution, but for high dim it is well approximated by a gaussian with variance 0.5.)

See below for an empirical code example, and a nice brief explanation at https://www.johndcook.com/blog/2011/09/01/multivariate-normal-shell/

import math
import numpy as np
for i in range(1,15):
    n = 10000 # sample size
    dim = i*i*i*i
    z = np.random.normal(size=[n,dim])
    r = np.linalg.norm(z, axis=1)
    mean, var = np.mean(r), np.var(r)
    print('dim:{}, sqrt(dim-1):{}, mean:{}, var:{}'.format(dim, math.sqrt(dim-1), mean, var))
  
    
dim:1, sqrt(dim-1):0.0, mean:0.795150900024, var:0.366173691484
dim:16, sqrt(dim-1):3.87298334621, mean:3.93135579711, var:0.491231344077
dim:81, sqrt(dim-1):8.94427191, mean:8.96891458959, var:0.489195309416
dim:256, sqrt(dim-1):15.9687194227, mean:15.9929240032, var:0.498972907482
dim:625, sqrt(dim-1):24.9799919936, mean:24.9836431855, var:0.490930587752
dim:1296, sqrt(dim-1):35.9861084309, mean:35.9903244552, var:0.496412523249
dim:2401, sqrt(dim-1):48.9897948557, mean:49.0022012748, var:0.498904908632
dim:4096, sqrt(dim-1):63.9921870231, mean:64.0026298725, var:0.497066494871
dim:6561, sqrt(dim-1):80.9938269253, mean:80.9966366457, var:0.490071214555
dim:10000, sqrt(dim-1):99.994999875, mean:99.9974767903, var:0.516483725478
dim:14641, sqrt(dim-1):120.995867698, mean:120.997610236, var:0.493307838545
dim:20736, sqrt(dim-1):143.996527736, mean:143.998235492, var:0.500247771622
dim:28561, sqrt(dim-1):168.997041394, mean:168.998178752, var:0.509050583724
dim:38416, sqrt(dim-1):195.997448963, mean:196.013193242, var:0.497141009441

memo avatar May 18 '18 13:05 memo

Hey, super useful, thank you! I needed a batchwise, n-dimensional slerp that does a different range of interpolation steps for each sample, so I adapted your code to TensorFlow. Here it is in case anyone needs it:

import tensorflow as tf

def tf_slerp(val, low, high):
    # val must be (batch_size, n_timesteps)
    # low must be (batch_size, n_dimensions)
    # high must be (batch_size, n_dimensions)
    dim_size = low.shape[-1]
    time_steps = val.shape[-1]
    p1 = low / tf.tile(tf.expand_dims(tf.norm(low, axis=1), axis=1), [1, dim_size])
    p2 = high / tf.tile(tf.expand_dims(tf.norm(high, axis=1), axis=1), [1, dim_size])
    dot = tf.reduce_sum(p1 * p2, axis=-1)  # batchwise dot product of the unit vectors

    omega = tf.acos(tf.clip_by_value(dot, -1, 1))
    so = tf.sin(omega)
    # the so == 0 edge case (identical or opposite endpoints) is not handled batchwise here
    so = tf.tile(tf.expand_dims(tf.expand_dims(so, axis=1), axis=2), [1, time_steps, dim_size])
    omega = tf.tile(tf.expand_dims(tf.expand_dims(omega, axis=1), axis=2), [1, time_steps, dim_size])
    val = tf.tile(tf.expand_dims(val, axis=2), [1, 1, dim_size])
    low = tf.tile(tf.expand_dims(low, axis=1), [1, time_steps, 1])
    high = tf.tile(tf.expand_dims(high, axis=1), [1, time_steps, 1])
    return tf.sin((1.0 - val) * omega) / so * low + tf.sin(val * omega) / so * high

And usage to visualise it in the 2D case:

import numpy as np
from matplotlib import pylab as plt

low = np.array([[1.0, 0.0], [3.0, 0.0]])   # example start points (any batch of 2-d vectors works)
high = np.array([[0.0, 1.0], [0.0, 3.0]])
values = np.array([np.linspace(0, 1, 10), np.linspace(0, 0.5, 10)])

array = tf_slerp(values, low, high)

plt.scatter(array[0, :, 0], array[0, :, 1])
plt.show()
plt.scatter(array[1, :, 0], array[1, :, 1])
plt.show()

sholtodouglas avatar Aug 27 '19 06:08 sholtodouglas

(quoting pqn's slerp implementation and edge-case examples from above)

@pqn, thank you for this snippet. Could you explain how the case (so == 0) is related to "L'Hopital's rule"?

cocoaaa avatar Oct 28 '20 01:10 cocoaaa

@cocoaaa It's been a while since I thought about this, but IIRC, if omega (and thus np.sin(omega)) is 0, then the usual slerp formula becomes undefined since that term is in the denominator. But the slerp formula can still be smoothly extended at 0 by calculating the limit via L'Hopital's rule. This simplifies to linear interpolation, which probably also gives better numerical stability near 0. Seems the Wikipedia article corroborates my memory.
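
To spell out the limit: as omega -> 0, sin((1-t)*omega)/sin(omega) -> 1-t and sin(t*omega)/sin(omega) -> t, so the slerp weights collapse to the lerp weights. A quick numerical check:

import numpy as np

# as omega shrinks, the slerp weights approach the lerp weights (1-t) and t
t = 0.3
for omega in [1e-1, 1e-3, 1e-5]:
    w_low = np.sin((1.0 - t) * omega) / np.sin(omega)
    w_high = np.sin(t * omega) / np.sin(omega)
    print(omega, w_low, w_high)  # w_low -> 0.7, w_high -> 0.3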

pqn avatar Oct 29 '20 06:10 pqn