Slow training on TPU
Does following https://flax.readthedocs.io/en/latest/howtos/ensembling.html train the model on a GPU/TPU if you are connected to one on Google Colab? I have the functions below, written as per that guide, but training seems very slow, just like on CPU.
jax.devices()
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# ...
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
    def loss_fn(params):
        logits = CNN().apply({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
    accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
    return grads, loss, accuracy


@jax.pmap
def update_model(state, grads):
    return state.apply_gradients(grads=grads)


@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, size_image, size_image, 3]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
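The snippet does not show how `state` is built before the training loop; a minimal sketch, assuming placeholder SGD hyperparameters (0.01 learning rate, 0.9 momentum), would pass one RNG key per device to the pmapped `create_train_state`:

# Hypothetical setup, not shown in the snippet above: create_train_state is pmapped,
# so it expects one RNG key per device; the hyperparameter values are placeholders.
rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, jax.device_count())  # one key per TPU core
state = create_train_state(rngs, 0.01, 0.9)       # replicated TrainState, one copy per core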
def train_one_epoch(state, dataloader):
    """Train for 1 epoch on the training set."""
    epoch_loss = []
    epoch_accuracy = []
    for cnt, (images, labels) in enumerate(dataloader):
        images = images / 255.0
        images = jax_utils.replicate(images)
        labels = jax_utils.replicate(labels)
        grads, loss, accuracy = apply_model(state, images, labels)
        state = update_model(state, grads)
        epoch_loss.append(jax_utils.unreplicate(loss))
        epoch_accuracy.append(jax_utils.unreplicate(accuracy))
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy
for epoch in range(1, num_epochs + 1):
    state, train_loss, train_accuracy = train_one_epoch(state, train_loader)
    training_loss.append(train_loss)
    training_accuracy.append(train_accuracy)
    print(f"Train epoch: {epoch}, loss: {train_loss}, accuracy: {train_accuracy * 100}")

    _, test_loss, test_accuracy = jax_utils.unreplicate(apply_model(state, test_images, test_labels))
    testing_accuracy.append(test_accuracy)
    testing_loss.append(test_loss)
    print(f"Test epoch: {epoch}, loss: {test_loss}, accuracy: {test_accuracy * 100}")
What am I missing?
Can you confirm that you called
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
?
(training models with lots of input data on Colab TPUs is unfortunately not as fast as it could be because Colab uses the legacy "TPU Node" setup, but it should for sure be a lot faster than running on CPU)
Yes, I did.
When you say "training seems very slow, just like on CPU", can you quantify that? (And contrast it with GPU.)
It means that I am not seeing any difference between the TPU setup and when I was using the CPU. I haven't done any benchmarks; happy to close this.
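One way to put a number on the per-step time is an illustrative sketch like the following (assuming the `apply_model`, `state`, and `train_loader` defined in the snippet above), which times a single pmapped training step and blocks on the result so JAX's asynchronous dispatch does not hide the device time:

import time

# Illustrative benchmark, assuming the names from the snippet above.
images, labels = next(iter(train_loader))
images = jax_utils.replicate(images / 255.0)
labels = jax_utils.replicate(labels)

# The first call includes XLA compilation; time the second call for steady-state speed.
apply_model(state, images, labels)
start = time.perf_counter()
grads, loss, accuracy = apply_model(state, images, labels)
loss.block_until_ready()
print(f"step time: {time.perf_counter() - start:.4f}s")

Comparing this number between the CPU and TPU runtimes (and, separately, timing how long the dataloader takes to produce a batch) would show whether the step itself is slow or the input pipeline is the bottleneck, which is the concern raised about the legacy "TPU Node" setup above.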