
RNN classifies MNIST

never-to-never opened this issue 1 year ago • 3 comments

I use an LSTM to classify MNIST data and find that the network's loss does not converge at all. Is the RNN provided by the framework correct? I have attached the script I run (test.txt).

never-to-never · May 04 '23 08:05

  1. Please paste your code as opposed to attaching it as a file, especially if the code is short.

  2. Why are you using an RNN?

  3. From looking at your code, it doesn't seem like you're really using the time component. Are you sure that in your preprocessing you're replicating the data over the time axis? A sketch of one possible time-major layout is below.
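
For concreteness, a minimal, untested sketch of one possible time-major preprocessing (each image row as one LSTM timestep), reusing the train_loader from the attached script:

import jax
import jax.numpy as jnp

x, y = next(iter(train_loader))                 # PyTorch tensors; x has shape (batch, 1, 28, 28)
x = jnp.asarray(x.numpy()).squeeze(1)           # drop the channel dim -> (batch, 28, 28)
x = jnp.transpose(x, (1, 0, 2))                 # -> (28, batch, 28): time axis first, one image row per step
y = jax.nn.one_hot(jnp.asarray(y.numpy()), 10)  # (batch, 10) one-hot targets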

IanQS · May 04 '23 16:05

import haiku as hk
import jax
import jax.numpy as jnp
import optax
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

sequence_length = 28
input_size = 28
hidden_size = 128
num_classes = 10
batch_size = 128
num_epochs = 30
learning_rate = 0.001

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

def unroll_net(seqs: jax.Array):
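    # hk.dynamic_unroll steps the LSTM over axis 0 of `seqs` (time-major); the final output feeds a Linear classifier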
    core = hk.LSTM(128)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    return hk.Linear(10)(outs[-1]), state

model = hk.transform(unroll_net)

rng = jax.random.PRNGKey(428)
opt = optax.adam(1e-3)

@jax.jit
def loss(params, x, y):
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

@jax.jit
def accuracy(predict, target):
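    # counts (rather than averages) correct predictions in a batch; not used in the training loop below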
    return jnp.sum(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))

@jax.jit
def update(step, params, opt_state, x, y):
    l, grads = jax.value_and_grad(loss)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return l, params, opt_state

train_ds = iter(train_loader)
valid_ds = iter(test_loader)
sample_x, _ = next(train_ds)
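# note: this reshapes the (batch, 1, 28, 28) tensor straight to (28, batch, 28) in memory order; it is a reshape, not a transpose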
sample_x = sample_x.reshape(sequence_length, -1, input_size)
sample_x = jnp.asarray(sample_x)
params = model.init(rng, sample_x)
opt_state = opt.init(params)
length = len(train_ds)

for step in range(length-1):
    if step % 10 == 0:
        x, y = next(valid_ds)
        x = x.reshape(sequence_length, -1, input_size)
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)
        print("Step {}: valid loss {}".format(step, loss(params, x, y)))
    x, y = next(train_ds)
    x = x.reshape(sequence_length, -1, input_size)
    x = jnp.asarray(x)
    y = jnp.asarray(y)
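    # one-hot encode the integer labels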
    y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)
    train_loss, params, opt_state = update(step, params, opt_state, x, y)
    if step % 10 == 0:
        print("Step {}: train loss {}".format(step, train_loss))

Here is the full code.

never-to-never · May 05 '23 00:05


The error in the code is likely due to the mismatch between PyTorch tensors and JAX arrays.

The train_loader and test_loader provide PyTorch tensors, while the model and loss functions expect JAX arrays. You need to convert the PyTorch tensors to JAX arrays with jnp.array() before passing them to the model and loss functions.
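
For example, a rough sketch of such a conversion (reusing the train_loader defined in the script above):

x_torch, y_torch = next(iter(train_loader))  # PyTorch tensors from the DataLoader
x = jnp.array(x_torch.numpy())               # JAX array copy of the image batch
y = jnp.array(y_torch.numpy())               # JAX array copy of the labels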

Ekundayo39283 · Apr 12 '24 17:04