neural-combinatorial-rl-pytorch

mask

Open ricgama opened this issue 7 years ago • 35 comments

Hello @pemami4911,

The problem really was with the mask. I've fixed it and the network started to learn. My Decoder now is:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Decoder(nn.Module):
    def __init__(self, features_dim, hidden_size, n_layers=1):
        super(Decoder, self).__init__()

        # Attention parameters (Var is assumed to be a small helper that
        # returns a trainable parameter of the given shape)
        self.W1 = Var(hidden_size, hidden_size)
        self.W2 = Var(hidden_size, hidden_size)
        self.b2 = Var(hidden_size)
        self.V = Var(hidden_size)

        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, num_layers=n_layers)

    def forward(self, input, hidden, enc_outputs, mask, prev_idxs):
        # Additive attention over the encoder outputs
        w1e = torch.matmul(enc_outputs, self.W1)
        w2h = (torch.matmul(hidden[0][-1], self.W2) + self.b2).unsqueeze(1)

        u = F.tanh(w1e + w2h)
        a = torch.matmul(u, self.V)

        # Zero out the probabilities of already-visited cities before the softmax
        a, mask = self.apply_mask(a, mask, prev_idxs)
        a = F.softmax(a)
        res, hidden = self.lstm(input, hidden)
        return a, hidden, mask

    def apply_mask(self, attentions, mask, prev_idxs):
        if mask is None:
            mask = Variable(torch.ones(attentions.size())).cuda()

        # Clone the mask so in-place updates don't corrupt the previous
        # decode step's graph (this clone was the fix)
        maskk = mask.clone()

        if prev_idxs is not None:
            # Mark the city selected at the previous decode step as visited
            for i, j in zip(range(attentions.size(0)), prev_idxs.data):
                maskk[i, j[0]] = 0

            # maskk.log() is -inf where maskk is 0, so the softmax assigns
            # exactly zero probability to visited cities
            masked = maskk * attentions + maskk.log()
        else:
            masked = attentions

        return masked, maskk
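
For reference, this is roughly how such a decoder gets driven during decoding. A minimal, hypothetical sketch (enc_outputs, dec_input, hidden, seq_len, and embed_city are assumed names, not from the actual code):

decoder = Decoder(features_dim=2, hidden_size=128).cuda()
mask, prev_idxs = None, None
for step in range(seq_len):
    # probs is the masked softmax over cities; mask carries the visited state forward
    probs, hidden, mask = decoder(dec_input, hidden, enc_outputs, mask, prev_idxs)
    prev_idxs = probs.multinomial()  # sample the next city
    dec_input = embed_city(enc_outputs, prev_idxs)  # hypothetical embedding lookup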

For n=10, I'm obtaining the following while training the AC version:


Step 0
Average train model loss:  -81.07959747314453
Average train critic loss:  29.443866729736328
Average train pred-reward:  -0.08028505742549896
Average train reward:  5.2866997718811035
Average loss:  -15.106804847717285
------------------------
Step  1000
Average train model loss:  -0.7814755792869255
Average train critic loss:  0.7740849611759186
Average train pred-reward:  4.219553744537756
Average train reward:  4.272005982398987
Average loss:  -6.201663847446442
------------------------

(...)

Step  19000
Average train model loss:  -0.06441724334075116
Average train critic loss:  0.1361817416474223
Average train pred-reward:  3.0573679950237276
Average train reward:  3.0583059163093567
Average loss:  -1.5689961900115013

I've checked and the returned solutions are all feasible so it seems that it is really converging. I will clean my code and hope that by the end of the week I will have some training history plots and test set validation. If you like, I can share my notebook.

Best regards

ricgama avatar Nov 06 '17 10:11 ricgama

I think sharing notebooks is always a good idea, so please share it. Would love to check it.

unnir avatar Nov 06 '17 10:11 unnir

What exactly was the issue with the mask? It looks like you’re taking mask.log(); won’t that return NaN where the mask is 0?

pemami4911 avatar Nov 06 '17 11:11 pemami4911

It returns -inf, and the softmax function will return 0.
mask test .ipynb.zip

I think the main problem was numerical instability when I forced the probabilities to 0. Also, I think that cloning the probabilities rather than the mask caused some problems in the backpropagation (not so sure about this... https://discuss.pytorch.org/t/how-to-copy-a-variable-in-a-network-graph/1603/7).
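
As a quick standalone illustration of the log-mask trick (a minimal sketch using plain tensors, not code from the repo):

import torch
import torch.nn.functional as F

# Adding log(mask) to the attention logits sends masked entries to -inf,
# so the softmax assigns them exactly zero probability.
logits = torch.Tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.Tensor([[1.0, 0.0, 1.0, 1.0]])  # 0 marks an already-visited city

masked_logits = mask * logits + mask.log()  # the masked entry becomes -inf
probs = F.softmax(masked_logits, dim=1)
# probs[0, 1] is exactly 0.0 and the remaining entries sum to 1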

ricgama avatar Nov 06 '17 12:11 ricgama

What was the problem with the masking that you fixed?

pemami4911 avatar Nov 06 '17 12:11 pemami4911

@pemami4911 Sorry, I've updated the answer above.

ricgama avatar Nov 06 '17 12:11 ricgama

Certainly, there must be a more elegant way of doing this :)

ricgama avatar Nov 06 '17 16:11 ricgama

Can you tell me which hyperparameters you're using? Are you using the exponential moving average critic? Did you try it on TSP_20?

pemami4911 avatar Nov 06 '17 18:11 pemami4911

I added the maskk = mask.clone() line and it seems to be learning something!! will update soon..

pemami4911 avatar Nov 06 '17 18:11 pemami4911

@ricgama After 1 epoch (10,000 minibatches of size 128), on my held-out validation set of 1000 graphs I got:

Validation overall avg_reward: 3.0088058819770813
Validation overall reward var: 0.1305618987729639

I saw some tours as low as 2.4! It's learning! THANK YOU! haha

pemami4911 avatar Nov 06 '17 18:11 pemami4911

After 1 epoch of TSP 20, I got:

Validation overall avg_reward: 4.1758026211261745
Validation overall reward var: 0.14051547629226666

With some tours as low as 3.6. According to the paper, I should be aiming for an average reward of 3.89.

Sweet.

pemami4911 avatar Nov 06 '17 19:11 pemami4911

Great!!

For now, I'm just using a simple version with hidden_dim = 128 and no glimpses in the Pointer Net. I'm training with the actor-critic (AC) network from the paper.

PointerNet_TSP.zip

I'm posting my test sets for n=10 and n=20 here so we can compare best results. For n=10 I do:

import numpy as np

# Load the n=10 test set: 'p' holds the optimal tours, 'x' the city coordinates
tmp = np.load(test_path)
p = list(tmp['p'])
x = list(tmp['x'])
test = [[p[i], x[i]] for i in range(len(x))]

# Pad the label array with zeros (np.lib.pad with (0, 1) pads every axis),
# then drop the extra row it adds, leaving just a trailing zero column
labels_te = np.array([x[0] for x in test], dtype=np.int64)
labels_te = np.lib.pad(labels_te, (0, 1), mode='constant', constant_values=0)
labels_te = np.delete(labels_te, (labels_te.shape[0] - 1), axis=0)

inp_enc_te = np.array([x[1] for x in test])

where the labels are the optimal tours. For n=20:

tmp = np.load(test_path)
inp_enc_tr = np.array(tmp['x'])

For n=10 with supervised learning I obtain: true: 2.86695502642, predicted: 2.89031121256.

I will try to post the RL results, for n=10 and 20, by the end of the week.

ricgama avatar Nov 06 '17 22:11 ricgama

Cool, I'll update my repo with some results by the end of the week too. For TSP 20 RL, the validation average reward is down to 4.02 and still dropping little by little after 3 hours!

Can't believe it was just a mask.clone() that was breaking everything.. usually small bugs just act as regularizers.. I guess not in Deep RL :(

pemami4911 avatar Nov 06 '17 22:11 pemami4911

After 50 epochs for TSP 20, I got 3.95 average val reward! Fairly close to 3.89, it probably would have gotten there eventually if I had let it keep training.

After two epochs on TSP 50, I'm seeing:

Validation overall avg_reward: 6.54308279466629
Validation overall reward var: 0.15608837694146416

not bad!

pemami4911 avatar Nov 07 '17 17:11 pemami4911

Do you have the training history plots?

ricgama avatar Nov 07 '17 22:11 ricgama

I haven't made any nice plots yet, but these are quick screenshots from Tensorboard for TSP 50. Been running for ~21 hours, looks like the validation avg reward has just about reached 6.05-6.10.

tsp_50_avg_reward: zoomed in on the average training reward (stochastic decoding) during the first few hours of training.
tsp_50_avg_reward_2: average training reward (stochastic decoding) so far.
tsp_50_val_avg_reward: validation reward (greedy decoding). The plot shows each reward (length of tour) for the set of 1000 val graphs. After every epoch (10,000 random training graphs), I evaluate by shuffling the 1000 held-out graphs and running on each one of them. So this isn't showing an "average"; the average is just computed at the end of running over all 1000 graphs each time I do a validation pass.
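
In pseudocode, that validation pass amounts to something like this hypothetical sketch (greedy_tour_length is an assumed helper, not a function from the repo):

import random

def validate(model, val_graphs):
    # Shuffle the 1000 held-out graphs, greedily decode each one,
    # and only compute the average at the very end
    random.shuffle(val_graphs)
    lengths = [greedy_tour_length(model, g) for g in val_graphs]  # assumed helper
    return sum(lengths) / len(lengths)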

pemami4911 avatar Nov 08 '17 14:11 pemami4911

After 2 epochs for TSP 20, I got

Step  20000
Average train model loss:  -0.4126803118472656
Average train critic loss:  0.23534413764963638
Average train pred-reward:  4.466843187868655
Average train reward:  4.4662888267257435
Average loss:  -6.895590880492309
------------------------

worse than your Validation overall avg_reward: 4.1758026211261745. Now I'm trying with one attention glimpse. Are you using a decaying learning rate?

ricgama avatar Nov 08 '17 21:11 ricgama

Is this SL or RL? And is your train reward with greedy decoding, or stochastic decoding?

I am using the lr schedule from the paper: starting at 1e-3 and decaying by a factor of 0.96 every 5k steps. I'm using the exponential moving average critic, not the critic network.
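
For reference, a minimal sketch of both pieces; the optimizer choice and the 0.8 EMA decay are assumptions, not details from this thread:

import torch

# (1) The lr schedule described above: start at 1e-3 and multiply the
# learning rate by 0.96 every 5,000 steps (call scheduler.step() once per
# training step).
actor = torch.nn.Linear(2, 2)  # stand-in for the actor network
optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.96)

# (2) An exponential moving average of the reward, used as the REINFORCE
# baseline in place of a learned critic network.
ema_baseline, ema_decay = None, 0.8

def update_baseline(batch_mean_reward):
    global ema_baseline
    if ema_baseline is None:
        ema_baseline = batch_mean_reward
    else:
        ema_baseline = ema_decay * ema_baseline + (1 - ema_decay) * batch_mean_reward
    return ema_baseline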

pemami4911 avatar Nov 08 '17 22:11 pemami4911

It's RL and stochastic decoding. With one attention glimpse it appears a bit better, so I will train for 2 epochs and compare greedy decoding and beam search. My hardware is slower than yours, so I want to check n=20 before moving to n=50...

ricgama avatar Nov 08 '17 22:11 ricgama

OK, yeah, you'll want to compare with greedy decoding, not stochastic. Just FYI, the beam search in my codebase isn't functional yet; it's only coded to propagate a single decoding hypothesis forward at each time step, which is equivalent to greedy decoding. Fixing that is on my to-do list :D

pemami4911 avatar Nov 08 '17 22:11 pemami4911

I've implemented my beam search. It works very well but for now only for batch=1, so it's a bit slower. I can send it to you if you like...
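
For illustration, a minimal generic beam search for a pointer decoder at batch size 1; step_probs is a hypothetical helper returning the next-step probability vector for a partial tour, so this is a sketch of the idea rather than the actual implementation:

import math

def beam_search(step_probs, n_cities, beam_width=3):
    beams = [([], 0.0)]  # (partial tour, cumulative log-probability)
    for _ in range(n_cities):
        candidates = []
        for tour, logp in beams:
            probs = step_probs(tour)  # hypothetical: next-step probs, length n_cities
            for city in range(n_cities):
                p = float(probs[city])
                if city in tour or p <= 0.0:
                    continue  # skip visited or zero-probability cities
                candidates.append((tour + [city], logp + math.log(p)))
        # Keep only the beam_width most probable partial tours
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0]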

ricgama avatar Nov 08 '17 22:11 ricgama

Meanwhile, I think I will implement the Active Search from the paper. Have you looked into it?

ricgama avatar Nov 08 '17 23:11 ricgama

Yeah you can send it to me if you'd like!

I haven't looked into implementing that yet. Not sure when I'll get to that part; got some other things I'm working on in the meantime.

pemami4911 avatar Nov 09 '17 01:11 pemami4911

@pemami4911 When I was adapting my beam search to handle the RL-trained Pointer Network, I found some inconsistencies that I have to look into before I share the code.

After 2 epochs, the validation average reward is 4.262 for n=20. I'm guessing it's about the same as yours, within random fluctuations.

ricgama avatar Nov 10 '17 22:11 ricgama

Hello @pemami4911,

During my n=50 training, the loss became NaN. It's strange, because for n=20 it trained perfectly. I'm trying to debug my code to fix this.

While you were training for n=50, did the message [!] resampling due to race condition appear often? Do you have any suggestions?

ricgama avatar Nov 20 '17 10:11 ricgama

Yes. You can see in my stochastic decoding function that I check whether any cities were sampled that shouldn't have been, and if so, I resample all cities at that decode step. When that occurs (I think it's a race condition..?) I print out [!] resampling due to race condition. You should probably add that check too.

pemami4911 avatar Nov 20 '17 18:11 pemami4911

So there must be a bug in the .multinomial() function, because actions with probability 0 should not be sampled. I'm running a script that simulates the masking/sampling loop to estimate the probability of this "bad" sampling; it should be very small, but not 0. I saw your workaround for this problem. Since the probability is very small, resampling once does the job, but in theory we should loop until the condition is satisfied, to guarantee we never sample a zero-probability action (see the sketch below). I think it's worth reporting this issue, ".multinomial() sampling 0 prob. actions", on the PyTorch forum. What do you think? After proper masking, we should be able to just sit back and relax while the model is training...
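
A minimal sketch of that guard (probs and visited are assumed to come from the decoding loop; illustrative only):

# Redraw from the multinomial until no previously visited city is selected
idxs = probs.multinomial()
while any(idxs.eq(old).data.any() for old in visited):
    print(' [!] resampling due to race condition')
    idxs = probs.multinomial()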

ricgama avatar Nov 21 '17 00:11 ricgama

Were you able to replicate the "bad sampling" with your script?

pemami4911 avatar Nov 21 '17 01:11 pemami4911

Yes. It's just:


import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def apply_mask(attentions, mask, prev_idxs):
    if mask is None:
        mask = torch.zeros(attentions.size()).byte().cuda()

    maskk = mask.clone()

    if prev_idxs is not None:
        # Mark the previously sampled city as visited...
        for i, j in zip(range(attentions.size(0)), prev_idxs.data):
            maskk[i, j[0]] = 1

        # ...and force its logit to -inf so the softmax gives it zero probability
        attentions[maskk] = -np.inf

    return attentions, maskk

def count(n):
    """Simulate n batches of the masking/sampling loop and count how often
    a zero-probability (already-visited) city gets sampled anyway."""
    k = 0
    for j in range(n):
        # Random logits for a batch of 128 instances with 50 cities each
        attentions = Variable(torch.Tensor(128, 50).uniform_(-10, 10).cuda())
        prev_actions = None
        mask = None
        actions = []
        for di in range(50):
            attentions, mask = apply_mask(attentions, mask, prev_actions)
            probs = F.softmax(attentions).cuda()
            prev_actions = probs.multinomial()
            for old_idxs in actions:
                # A repeated index means a zero-probability action was sampled
                if old_idxs.eq(prev_actions).data.any():
                    k += 1
                    print(' [!] resampling')

            actions.append(prev_actions)
    return k

By the end of the day I will have an estimate.

ricgama avatar Nov 21 '17 08:11 ricgama

The relative frequency over a run of 100,000 batches (batch size 128, n=50) was 0.00043.

Do you mind if I post a question on the PyTorch forum using the code above?

ricgama avatar Nov 21 '17 20:11 ricgama

Yeah, go for it

pemami4911 avatar Nov 21 '17 21:11 pemami4911