JAX remat on Colab TPU doesn't reduce memory usage
Steps to reproduce
https://colab.research.google.com/drive/1_XbQ_oq8APHmlJT-1oo_tU4CwCHj4BNZ?usp=sharing
JAX remat still uses a lot of memory when the gradient computation is compiled with jax.jit; without jax.jit the same computation works fine.
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class SimpleLayer(nn.Module):
    @nn.remat  # should decrease memory usage by rematerializing activations in the backward pass
    @nn.compact
    def __call__(self, x):
        residual = x
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(
            num_heads=8,
            dtype=jnp.float32,
            qkv_features=512,
            deterministic=True,
        )(x)
        return x + residual

class SimpleModel(nn.Module):
    depth: int = 24

    @nn.compact
    def __call__(self, x):
        for i in range(self.depth):
            x = SimpleLayer()(x)
        return x

def get_model(input_shape):
    key = random.PRNGKey(42)
    return SimpleModel().init(key, jnp.ones(input_shape, dtype=jnp.float32))["params"]

def loss_func(params, inputs, targets):
    preds = SimpleModel().apply({"params": params}, inputs)
    return jnp.mean((preds - targets) ** 2)

@jax.jit
def fake_train_step(params, inputs, targets):
    grad_fn = jax.grad(loss_func)
    return grad_fn(params, inputs, targets)

params = get_model((128, 128, 512))
grad_fn = jax.grad(loss_func)
inputs = random.normal(random.PRNGKey(42), (128, 128, 512), jnp.float32)
targets = random.normal(random.PRNGKey(43), (128, 128, 512), jnp.float32)

grads = grad_fn(params, inputs, targets)          # un-jitted: this works fine
grads = fake_train_step(params, inputs, targets)  # jitted: this results in OOM
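For comparison, here is a minimal sketch of the same pattern using plain jax.checkpoint (the alias of jax.remat) without Flax. The block function, shapes, and parameter list below are hypothetical stand-ins, not from the notebook; the sketch is only meant to show the jitted vs. un-jitted gradient paths with the core transform.

import jax
import jax.numpy as jnp

def block(x, w):
    # stand-in for one residual layer: two matmuls plus a skip connection
    return jnp.tanh(x @ w) @ w.T + x

def loss(params, x):
    for w in params:
        x = jax.checkpoint(block)(x, w)  # jax.checkpoint is the same transform as jax.remat
    return jnp.mean(x ** 2)

params = [jnp.ones((512, 512)) for _ in range(24)]
x = jnp.ones((128, 512))

grads_eager = jax.grad(loss)(params, x)         # un-jitted gradient
grads_jit = jax.jit(jax.grad(loss))(params, x)  # jitted gradient, analogous to fake_train_step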
Linked discussion in Flax repo https://github.com/google/flax/issues/1285
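As a possible diagnostic (my suggestion, not part of the notebook): printing the jaxpr of the gradient function should show whether the remat/checkpoint transform is still present after tracing.

# Hypothetical check: print the jaxpr of the gradient of the loss above and look for
# remat/checkpoint equations to confirm the transform is applied.
print(jax.make_jaxpr(jax.grad(loss_func))(params, inputs, targets))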