
Slow training on TPU

Open mwitiderrick opened this issue 3 years ago • 4 comments

Does following https://flax.readthedocs.io/en/latest/howtos/ensembling.html train the model on a GPU/TPU if you are connected to one on Google Colab? I have the functions below, as per the docs linked above, but training seems to be very slow, just like on CPU.

jax.devices()

# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# ...
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 2)
    loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, size_image, size_image, 3]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set."""
    epoch_loss = []
    epoch_accuracy = []
    for images, labels in dataloader:
        images = images / 255.0
        images = jax_utils.replicate(images)
        labels = jax_utils.replicate(labels)
        grads, loss, accuracy = apply_model(state, images, labels)
        state = update_model(state, grads)
        # Accumulate per-batch metrics inside the loop.
        epoch_loss.append(jax_utils.unreplicate(loss))
        epoch_accuracy.append(jax_utils.unreplicate(accuracy))
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy



for epoch in range(1, num_epochs + 1):
    state, train_loss, train_accuracy = train_one_epoch(state, train_loader)
    training_loss.append(train_loss)
    training_accuracy.append(train_accuracy)
    print(f"Train epoch: {epoch}, loss: {train_loss}, accuracy: {train_accuracy * 100}")

    _, test_loss, test_accuracy = jax_utils.unreplicate(apply_model(state, test_images, test_labels))
    testing_accuracy.append(test_accuracy)
    testing_loss.append(test_loss)
    print(f"Test epoch: {epoch}, loss: {test_loss}, accuracy: {test_accuracy* 100}")


What am I missing?

mwitiderrick · Jul 02 '22 18:07

can you confirm that you called

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

?

(training models with lots of input data on Colab TPUs is unfortunately not as fast as it could be because Colab uses the legacy "TPU Node" setup, but it should for sure be a lot faster than running on CPU)
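
As a quick sanity check (a minimal sketch using standard JAX calls, not from the original issue), you can confirm that JAX is actually targeting the TPU backend after calling setup_tpu():

# Sketch: verify the TPU backend is active (assumes a Colab TPU runtime).
import jax
print(jax.default_backend())  # expected: "tpu"
print(jax.device_count())     # expected: 8 on a Colab TPU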

andsteing · Jul 04 '22 06:07

Yes I did

mwitiderrick · Jul 04 '22 07:07

When you say "training seems to be very slow, just like on CPU", can you quantify that? (and contrast it to GPU)
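
For example, a rough way to quantify it (a sketch that reuses apply_model, state, images and labels from the code above) is to time one pmapped training step, calling block_until_ready() so JAX's asynchronous dispatch doesn't hide the real cost:

import time

# Warm-up call: triggers compilation, which would otherwise dominate the timing.
grads, loss, accuracy = apply_model(state, images, labels)
loss.block_until_ready()

start = time.perf_counter()
grads, loss, accuracy = apply_model(state, images, labels)
loss.block_until_ready()
print(f"one training step took {time.perf_counter() - start:.3f}s")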

andsteing · Jul 04 '22 07:07

It means that I am not seeing any difference between the TPU setup and when I was using the CPU. I haven't done any benchmarks, so I'm happy to close this.

mwitiderrick · Jul 04 '22 07:07