Pervasive reshape bugs in train_cg?
Unless I misunderstand what the code is trying to do, the following pattern in `train_cg.py` is a bug:
```python
trainBatch1 = [[], [], [], [], [], []]  # line 137
...
while j < max_epLength:  # line 149
    ...
    trainBatch1[3].append(s1)  # line 180
    ...
...
# line 236:
last_state = np.reshape(
    np.concatenate(trainBatch1[3], axis=0),
    [batch_size, trace_length, env.ob_space_shape[0],
     env.ob_space_shape[1], env.ob_space_shape[2]])[:, -1, :, :, :]
```
The issue is that `np.concatenate(trainBatch1[3], axis=0)` is stored in memory with the time (`trace_length`) axis first and the batch axis second, so it should be reshaped to `[trace_length, batch_size, ...]` and then transposed to move the batch axis forward. Reshaping straight to `[batch_size, trace_length, ...]` will silently misinterpret the order in which the elements are stored in memory.
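
To make the ordering concrete, here is a minimal self-contained sketch. The sizes, the scalar observations, and the `per_step` list (standing in for `trainBatch1[3]`) are all hypothetical toy stand-ins, not the actual `train_cg.py` values:

```python
import numpy as np

# Toy setup: 2 trajectories (batch) of length 3 (time). At each timestep the
# env returns a [batch_size, ...] observation, which gets appended, so the
# concatenated buffer is time-major.
batch_size, trace_length = 2, 3
per_step = [np.array([[b * 10 + t] for b in range(batch_size)])  # shape [batch_size, 1]
            for t in range(trace_length)]

flat = np.concatenate(per_step, axis=0)  # shape [trace_length * batch_size, 1], time-major

wrong = np.reshape(flat, [batch_size, trace_length, 1])
# wrong[0] interleaves timesteps from both trajectories:
# [[ 0], [10], [ 1]]  -- not trajectory 0's observations

right = np.reshape(flat, [trace_length, batch_size, 1]).transpose(1, 0, 2)
# right[0] is trajectory 0's observations in time order:
# [[0], [1], [2]]
```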
The same buggy append-reshape pattern occurs for essentially everything stored in `trainBatch0` and `trainBatch1`, with the offending reshapes happening in various places in other files that expect `[batch, time]` storage order. I think the easiest fix would be to establish the desired storage order of `trainBatch1` right after the loop over `j < max_epLength`, e.g.
```python
trainBatch1 = [np.stack(seq, axis=1) for seq in trainBatch1]
```
and similarly for `trainBatch0`. Now `trainBatch1[3]` has exactly the shape you want it to have at line 236, so `last_state = trainBatch1[3][:, -1, :, :, :]` will do. You can still use `trainBatch1[i].reshape([batch_size * trace_length, ...])` if you need the batch and time axes flattened, and this will correctly reshape back to `[batch_size, trace_length, ...]`.
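
As a sanity check, here is the same toy sketch with the proposed `np.stack(seq, axis=1)` fix applied (again using hypothetical names and sizes rather than the actual `train_cg.py` variables):

```python
import numpy as np

# Same toy data as above, now stacked along a new time axis (axis=1).
batch_size, trace_length = 2, 3
per_step = [np.array([[b * 10 + t] for b in range(batch_size)])
            for t in range(trace_length)]
stacked = np.stack(per_step, axis=1)  # shape [batch_size, trace_length, 1]

last_state = stacked[:, -1]  # [[2], [12]] -- last observation of each trajectory

# Flattening and un-flattening the batch and time axes now round-trips correctly.
flat = stacked.reshape([batch_size * trace_length, 1])  # batch-major flattening
assert np.array_equal(flat.reshape([batch_size, trace_length, 1]), stacked)
```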
Thank you for finding and filing this - looks like a bad bug :( Please submit a PR for the fix when you have a chance. Thanks a lot!