pytorch-CortexNet

Does PredNet accept batches?

Open · Sahaj09 opened this issue 4 years ago · 1 comment

Does PredNet accept batches during training/testing?

The input is given as:

```python
input_sequence = Variable(torch.rand(T, 1, 1, 4 * 2 ** L, 6 * 2 ** L))
```

I assumed the input format is (time steps, batch size, channels, height, width). Am I wrong?
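Under the layout assumed in the question, the quoted `torch.rand(T, 1, 1, 4 * 2 ** L, 6 * 2 ** L)` call fixes both the batch size and the channel count to 1, and both spatial sizes scale with `2 ** L`. The helper below is hypothetical (not part of pytorch-CortexNet). It just spells out that arithmetic so the five dimensions can be read off for a given `T` and `L`. Treating the last two dimensions as height and width is the asker's assumption.

```python
def prednet_input_shape(T, batch_size, channels, L):
    """Hypothetical helper: shape of the input under the assumed
    (time steps, batch size, channels, height, width) layout.

    The quoted line corresponds to batch_size=1 and channels=1.
    """
    height = 4 * 2 ** L  # both spatial sizes scale with 2 ** L
    width = 6 * 2 ** L
    return (T, batch_size, channels, height, width)


print(prednet_input_shape(T=10, batch_size=1, channels=1, L=3))
# → (10, 1, 1, 32, 48)
```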

Sahaj09, Aug 27 '20 23:08

Both MatchNet and TempoNet expect one element at a time:

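```python
# Walk the time axis: each call gets x[t] and the next frame x[t + 1],
# and the returned state is fed back in on the following step
for t in range(0, min(args.big_t, x.size(0)) - 1):
    ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)

def compute_loss(x_, next_x, y_, state_):
    # V wraps the input (presumably torch.autograd.Variable)
    (x_hat, state_), (_, idx) = model(V(x_), state_)
    ...
    return ce_loss_, mse_loss_, state_, x_hat.data
```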
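The loop above indexes the first (time) dimension, so the model sees one time step per call. Each call pairs the current frame with the next one as the target. A minimal sketch of just that indexing, using a hypothetical helper `next_frame_pairs` and strings standing in for frame tensors, shows which pairs are visited. It does not model the recurrent `state` or the losses.

```python
def next_frame_pairs(x, big_t):
    """Hypothetical helper: the (t, x[t], x[t + 1]) triples visited by the
    quoted loop. x is any sequence indexed by time first; big_t plays the
    role of args.big_t.
    """
    steps = min(big_t, len(x)) - 1  # k frames yield k - 1 (current, next) pairs
    return [(t, x[t], x[t + 1]) for t in range(0, steps)]


frames = ["f0", "f1", "f2", "f3", "f4"]  # stand-ins for x[0] ... x[4]
for t, current, target in next_frame_pairs(frames, big_t=10):
    print(t, current, "->", target)
```

With a real tensor `x`, indexing `x[t]` drops only the leading time dimension. Whatever dimensions follow it are what the model receives at each step.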

Atcold, Aug 30 '20 08:08