n3net

Aggregate selects same element multiple times

Open LemonPi opened this issue 6 years ago • 8 comments

I'm trying to evaluate NNN against conventional KNN on a simple test case: finding the 5 nearest neighbours in a permutation of indices (so the result is easy to verify intuitively). The problem is that the aggregate output contains the same value for all 5 neighbours.

Problem setup:

```
import torch
import non_local
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N = 50
nn = non_local.N3AggregationBase(5, temp_opt={"external_temp": False})
nn.to(device)  # keep the module on the same device as its inputs

# values 0..N-1 in random order, shaped (batch=1, N, features=1)
x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = x.reshape(1, N, 1).to(device)
xe = x
ye = xe
# every query point may select from all N candidates
I = torch.tensor(list(range(N)), dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)
```

The aggregate output z is:

tensor([[[[10.0001, 10.0001, 10.0001, 10.0001, 10.0001]],
         [[42.0001, 42.0001, 42.0001, 42.0001, 42.0001]],
         [[22.0001, 22.0001, 22.0001, 22.0001, 22.0001]],
...

Is this expected, and am I interpreting the result wrong? If so, what is the aggregate output z supposed to represent?

LemonPi, Feb 12 '19 20:02

Hi @LemonPi,

I think this is a problem of symmetry. In a nutshell, our continuous relaxation of hard kNN selection is not good at breaking ties. Take the index 10: the indices 9/11, 8/12, 7/13, ... are pairwise equally distant from 10 and hence contribute with equal weight to the neighbor selection. The logits are also updated equally (Eq. 9), so in the next round of neighbor selection they again have equal weights, and the weighted average stays at 10.
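To make the tie argument concrete, here is a minimal pure-Python sketch of an N3-style soft kNN selection. It assumes (our reading of Eq. 9, not the repository's code) that the first round is a softmax over -d/t with squared distance d, and that each later round adds log(1 - w) of the previous round's weights to the logits; where exactly the temperature enters in the real implementation may differ. The names `soft_knn_rounds` and `aggregate` are hypothetical.

```python
import math


def soft_knn_rounds(values, query, k, temp):
    """Hypothetical sketch of a continuous (soft) kNN relaxation.

    Round 1 is a softmax over -squared_distance / temp. After each round the
    logits are lowered by log(1 - w), so items that already received weight
    are less likely to be picked again. Returns one weight list per round.
    """
    logits = [-((v - query) ** 2) / temp for v in values]
    rounds = []
    for _ in range(k):
        m = max(logits)
        exps = [math.exp(a - m) for a in logits]  # subtract the max for stability
        total = sum(exps)
        w = [e / total for e in exps]
        rounds.append(w)
        # log1p(-w) computes log(1 - w); an item with w == 1 is excluded outright
        logits = [a + math.log1p(-wi) if wi < 1.0 else -math.inf
                  for a, wi in zip(logits, w)]
    return rounds


def aggregate(values, weights):
    """Expected value of `values` under one round's weights."""
    return sum(w * v for w, v in zip(weights, values))


values = list(range(50))  # the values 0..49 from the example above
rounds = soft_knn_rounds(values, query=10, k=5, temp=1.0)
z = [aggregate(values, w) for w in rounds]
print(["{:.4f}".format(zk) for zk in z])  # all five entries are 10.0000
```

In this sketch, 9 and 11 (and 8 and 12, and so on) start with identical logits and receive identical log(1 - w) updates, so every round stays symmetric around 10 and each expected value collapses to 10. Hard kNN would instead return 10, 9, 11, 8, 12 (tied pairs in arbitrary order).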

I think there is also some numerical instability involved in how Eq. 9 is implemented right now ... :)

tobiasploetz, Feb 13 '19 08:02

I see. This seems to be a relatively big problem, since even adding noise didn't break the ties (I tried rand noise with magnitudes up to 1).

LemonPi, Feb 13 '19 20:02

Hi @LemonPi,

there was a numerical-stability issue when computing log(1 - exp(x)). This should be fixed in the latest update.
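The thread doesn't show the actual fix. For reference, a common stable formulation of log(1 - exp(x)) for x < 0 (often called log1mexp) switches at x = -log 2 between log(-expm1(x)) and log1p(-exp(x)). The pure-Python sketch below illustrates that technique under our own assumptions; it is not the repository's code. With x a log-weight from log_softmax, it yields log(1 - w) without first rounding w = exp(x) to exactly 1.

```python
import math


def log1mexp(x):
    """Numerically stable log(1 - exp(x)) for x < 0."""
    if x >= 0:
        raise ValueError("log1mexp requires x < 0")
    if x > -math.log(2.0):
        # exp(x) is close to 1: expm1 keeps the small difference 1 - exp(x) accurate
        return math.log(-math.expm1(x))
    # exp(x) is small: log1p keeps log(1 - tiny) accurate
    return math.log1p(-math.exp(x))


def naive(x):
    """Direct formula, for comparison only."""
    return math.log(1.0 - math.exp(x))


print(log1mexp(-1e-20))  # about -46.05, i.e. log(1e-20)
print(log1mexp(-50.0))   # about -1.9e-22, i.e. -exp(-50)
print(naive(-50.0))      # 0.0: the tiny term is lost when forming 1 - exp(-50)
# naive(-1e-20) raises ValueError: 1 - exp(-1e-20) rounds to 0 and log(0) fails
```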

I modified your example a bit (no permutation, lower temperature, a little noise on x):

```
import torch
import models.non_local as non_local
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

N = 50
nn = non_local.N3AggregationBase(5, temp_opt={"external_temp": False})
nn.to(device)  # move the module to the same device as the inputs (works without a GPU too)
nn.nnn.log_temp_bias = -50 # decrease temperature -> NNN acts more like hard kNN

# x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = torch.tensor(list(range(N)), dtype=torch.float, requires_grad=True)
n = torch.zeros_like(x).normal_() * 0.0001  # tiny noise so exact ties are broken
x = x+n
x = x.reshape(1, N, 1).to(device)
xe = x
ye = xe
I = torch.tensor(list(range(N)), dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)

# one row per query point, one column per selected neighbor
for i in range(N):
    print("\t".join(["{:.2f}"]*5).format(*(z[0,i,0,:].tolist())))
```

This gives the following output:

0.00	1.00	2.00	3.00	4.00
1.00	2.00	0.00	3.00	4.00
2.00	1.00	3.00	0.00	4.00
3.00	4.00	2.00	5.00	1.00
4.00	3.00	5.00	2.00	6.00
5.00	6.00	4.00	7.00	3.00
6.00	7.00	5.00	8.00	4.00
7.00	6.00	8.00	5.00	9.00
8.00	7.00	9.00	6.00	10.00
9.00	10.00	8.00	11.00	7.00
10.00	11.00	9.00	12.00	8.00
11.00	10.00	12.00	9.00	13.00
12.00	13.00	11.00	14.00	10.00
13.00	12.00	14.00	15.00	11.00
14.00	15.00	13.00	12.00	16.00
15.00	14.00	16.00	13.00	17.00
16.00	17.00	15.00	18.00	14.00
17.00	16.00	18.00	15.00	19.00
18.00	19.00	17.00	20.00	16.00
19.00	20.00	18.00	21.00	17.00
20.00	19.00	21.00	18.00	22.00
21.00	22.00	20.00	19.00	23.00
22.00	21.00	23.00	20.00	24.00
23.00	24.00	22.00	25.00	21.00
24.00	25.00	23.00	26.00	22.00
25.00	24.00	26.00	23.00	27.00
26.00	25.00	27.00	24.00	28.00
27.00	28.00	26.00	29.00	25.00
28.00	27.00	29.00	26.00	30.00
29.00	30.00	28.00	31.00	27.00
30.00	31.00	29.00	32.00	28.00
31.00	30.00	32.00	29.00	33.00
32.00	31.00	33.00	34.00	30.00
33.00	34.00	32.00	35.00	31.00
34.00	33.00	35.00	32.00	36.00
35.00	36.00	34.00	33.00	37.00
36.00	37.00	35.00	38.00	34.00
37.00	38.00	36.00	39.00	35.00
38.00	37.00	39.00	36.00	40.00
39.00	40.00	38.00	41.00	37.00
40.00	41.00	39.00	42.00	38.00
41.00	40.00	42.00	39.00	43.00
42.00	43.00	41.00	40.00	44.00
43.00	42.00	44.00	41.00	45.00
44.00	45.00	43.00	46.00	42.00
45.00	44.00	46.00	43.00	47.00
46.00	47.00	45.00	92.00	49.00
47.00	94.00	49.00	45.00	44.00
48.00	49.00	47.00	46.00	45.00
49.00	48.00	47.00	46.00	45.00

I hope this solves your problem.

tobiasploetz, Feb 22 '19 10:02

I see. I tried this again, and the critical line is the one that decreases the temperature. However, there is a new issue with some hallucinated values, such as:

42.00003 , 84.00015 , 43.999893, 39.999966, 39.000072

Where did the 84 come from?!?

This seems to be a phenomenon that occurs when the temperature is too low:

[image]

From what I understand from the paper, lowering the temperature approximates hard kNN more closely and results in sharper distributions. What practical issues does a low temperature cause?

LemonPi, Feb 25 '19 21:02

I think this is related to PyTorch's implementation of log_softmax, which seemingly does not work correctly if the maximal value of the argument has a large absolute value and appears multiple times:

```
import torch
import torch.nn.functional as F

# Outputs below are as reported in this thread (PyTorch version of early 2019)
F.log_softmax(torch.tensor([-1e2, -1e2], dtype=torch.float32), dim=0).exp()
# tensor([0.5000, 0.5000])
F.log_softmax(torch.tensor([-1e20, -1e20], dtype=torch.float32), dim=0).exp()
# tensor([1., 1.])
F.log_softmax(torch.tensor([1e20, 1e20], dtype=torch.float32), dim=0).exp()
# tensor([1., 1.])
```

This causes the weights of the weighted averages to sum to something greater than one.
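One plausible mechanism for these outputs (an assumption; we have not checked it against PyTorch's actual kernel) is that the log-sum-exp is folded into the large maximum before subtracting. For x = [-1e20, -1e20], max + log 2 rounds back to -1e20, so every log-probability comes out as 0 and every weight as 1. The pure-Python sketch below compares that folded formula with one that subtracts the maximum first; both helper names are hypothetical. If two symmetric neighbours, say 41 and 43, each received weight 1 this way, their aggregate would be 41 + 43 = 84, which would be consistent with the value reported above.

```python
import math


def log_softmax_folded(xs):
    """x - (max + log(sum(exp(x - max)))): log(n) is absorbed when |max| is huge."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def log_softmax_shifted(xs):
    """(x - max) - log(sum(exp(x - max))): the log(n) term survives."""
    m = max(xs)
    shifted = [x - m for x in xs]
    lse = math.log(sum(math.exp(s) for s in shifted))
    return [s - lse for s in shifted]


for xs in ([-1e2, -1e2], [-1e20, -1e20], [1e20, 1e20]):
    folded = [math.exp(v) for v in log_softmax_folded(xs)]
    shifted = [math.exp(v) for v in log_softmax_shifted(xs)]
    print(xs, folded, shifted)
# For +-1e20, -1e20 + log(2) == -1e20 in float64, so the folded weights are [1.0, 1.0]
```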

From a practical point of view, this should be of minor relevance if you train the N3 block within your network, since the gradient signal vanishes anyway as the temperature gets lower (in the limit t -> 0, i.e. log t -> -inf, N3 selection is just kNN selection and hence has zero gradient almost everywhere). Hence your network will probably never reach the situation above.
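The vanishing-gradient argument can be checked on a toy two-candidate case (our own illustration, not code from the repository). With distances d0 = 0 and d1 = 1 and weights w = softmax([-d0/t, -d1/t]), the derivative of w0 with respect to d1 is (1/t)·w0·w1. As t decreases, w1 ≈ exp(-1/t) decays much faster than 1/t grows, so the gradient goes to zero. The function name `grad_w0_wrt_d1` is hypothetical.

```python
import math


def grad_w0_wrt_d1(t, d0=0.0, d1=1.0):
    """Analytic d w0 / d d1 for w = softmax([-d0 / t, -d1 / t]), i.e. (1/t) * w0 * w1."""
    z0, z1 = -d0 / t, -d1 / t
    m = max(z0, z1)
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)  # shifted for numerical stability
    w0, w1 = e0 / (e0 + e1), e1 / (e0 + e1)
    return w0 * w1 / t


for t in (1.0, 0.1, 0.01):
    print(t, grad_w0_wrt_d1(t))  # roughly 0.197, 4.5e-4 and 3.7e-42
```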

tobiasploetz, Feb 26 '19 10:02

Thanks, will keep in mind!

LemonPi, Feb 27 '19 21:02

The spurious values seem to come from IndexedMatmul2Efficient when aggregating the output (producing z from W) rather than from log_softmax. For example, querying the k = 3 nearest neighbours with N = 30 gives the following W:

array([[0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.36787945, 0.36787945],
       [1.        , 0.        , 0.        ],
       [0.        , 0.36787945, 0.36787945],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ]], dtype=float32)

This looks correct: the first distribution is concentrated on 22, and the other two are split between 21 and 23. However, the output z is

array([21.999949, 16.18666 , 16.18666 ], dtype=float32)

This doesn't correspond to the indices selected by W. The index tensor I is just arange(0, N) for each query point.

Everything in IndexedMatmul2Efficient looks fine until

            z_interm = torch.cat([torch.matmul(y_full[:,i_k:i_k+1,:,:], x_interm) for i_k in range(k)], 1)

This results in:

z_interm[0,:,22]
Out[24]: 
tensor([[21.9999],
        [16.1867],
        [16.1867]], device='cuda:0')

Update: maybe the problem is in how W is computed in the first place. Its columns represent probability distributions, so each should sum to 1, but here the second and third columns sum to only about 0.74 (2 × 0.36787945).
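The spurious 16.18666 can be reproduced by hand from the W dump: the second and third columns put weight 0.36787945 (= e^-1) on both 21 and 23, so the expected value is about 0.736 × 22 ≈ 16.19 instead of 22. A minimal pure-Python check, assuming the values are exactly 21 and 23 (the small mismatch with 16.18666 presumably comes from float32 arithmetic or the noise added to x):

```python
import math

w = {21: math.exp(-1.0), 23: math.exp(-1.0)}       # one column of W, from the dump above
mass = sum(w.values())                             # about 0.7358, not 1
z_raw = sum(wi * v for v, wi in w.items())         # about 16.187: the spurious value
z_norm = sum(wi / mass * v for v, wi in w.items()) # renormalised column gives 22.0
print(mass, z_raw, z_norm)
```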

LemonPi, Apr 22 '19 23:04

Normalizing each output distribution removes the spurious values, but the center value still ends up selected multiple times, because each neighbour is aggregated as an expected value. The fundamental cause seems to be that the k distributions are ordered, so the method does not really relax kNN: in kNN we don't care about the order of the neighbours. A more direct relaxation would produce one distribution per query point instead of k distributions.

We can normalize by adding the following at the end of NeuralNearestNeighbors.forward:

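```
        # Normalize so that each output column of W is a distribution.
        # Dividing by the sum over dim 2 (keepdim=True) is the vectorized form of
        # normalizing W[bb, mm] over its dim 0 for every (bb, mm); doing it out of
        # place avoids in-place edits, which can break autograd if W is needed in backward.
        W = W / torch.sum(W, dim=2, keepdim=True)
```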

Example of the problem before normalization, with N = 3000 data points and k = 3:
[image: t_too_low_3000]
Average error relative to the kNN neighbourhoods (0 is exact):
[image: t_too_low_3000b]

The same problem after normalization:
[image: fix_t_too_low_3000]
[image: fix_t_too_low_3000b]

LemonPi, Apr 24 '19 20:04