
Pervasive reshape bugs in train_cg?

Open cooijmanstim opened this issue 3 years ago • 1 comment

Unless I misunderstand what the code is trying to do, the following pattern in train_cg.py is a bug:

trainBatch1 = [[], [], [], [], [], []]  # line 137
...
while j < max_epLength:  # line 149
  ...
  trainBatch1[3].append(s1)  # line 180
  ...
...
# line 236:
last_state = np.reshape(
  np.concatenate(trainBatch1[3], axis=0),
  [batch_size, trace_length, env.ob_space_shape[0],
   env.ob_space_shape[1], env.ob_space_shape[2]])[:,-1,:,:,:]

The issue is that np.concatenate(trainBatch1[3], axis=0) produces an array laid out in memory with the time (trace_length) axis first and the batch axis second; it should be reshaped to [trace_length, batch_size, ...] and then transposed to move the batch axis forward. Reshaping straight to [batch_size, trace_length, ...] silently misinterprets the order in which the elements are stored in memory.
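
To make the memory-order issue concrete, here is a minimal NumPy sketch with made-up toy shapes (the names and values below are illustrative, not taken from the repo): each appended step is a [batch_size] slice, so the concatenation is time-major, and only the reshape-then-transpose recovers per-batch trajectories.

import numpy as np

batch_size, trace_length = 2, 3
# one [batch_size] slice per timestep; value 10*t + b encodes (time, batch)
steps = [np.array([10 * t + b for b in range(batch_size)])
         for t in range(trace_length)]
flat = np.concatenate(steps, axis=0)  # time-major: [t0b0, t0b1, t1b0, t1b1, ...]

wrong = flat.reshape(batch_size, trace_length)
# wrong[0] == [0, 1, 10] -- mixes timesteps from different batch entries

right = flat.reshape(trace_length, batch_size).transpose(1, 0)
# right[0] == [0, 10, 20] -- the full trajectory of batch entry 0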

The same buggy append-reshape pattern occurs for basically everything stored in trainBatch0 and trainBatch1, with the offending reshapes scattered across other files that expect [batch, time] storage order. I think the easiest fix would be to establish the desired storage order of trainBatch1 right after the loop over j < max_epLength, e.g.

trainBatch1 = [np.stack(seq, axis=1) for seq in trainBatch1]

and similar for trainBatch0. Now trainBatch1[3] has exactly the shape you want it to have at line 236, so last_state = trainBatch1[3][:, -1, :, :, :] will do. You can still call trainBatch1[i].reshape([batch_size * trace_length, ...]) if you need the batch and time axes flattened, and this will correctly reshape back to [batch_size, trace_length, ...].
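
Continuing the toy sketch from above, here is why the stacked layout behaves: np.stack(..., axis=1) puts the batch axis first, so both the last-timestep slice and the flatten/unflatten round trip come out right.

stacked = np.stack(steps, axis=1)  # shape [batch_size, trace_length]
last = stacked[:, -1]              # last timestep of each batch entry
flat_bt = stacked.reshape(batch_size * trace_length)  # [batch, time] order
assert (flat_bt.reshape(batch_size, trace_length) == stacked).all()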

cooijmanstim · Apr 05 '21

Thank you for finding and filing this - looks like a bad bug :( Please submit a PR for the fix when you have a chance. Thanks a lot!

jakobnicolaus · Apr 13 '21