JAX remat on Colab TPU doesn't reduce memory usage

igrekun opened this issue on Apr 29, 2021 · 2 comments

Steps to reproduce

https://colab.research.google.com/drive/1_XbQ_oq8APHmlJT-1oo_tU4CwCHj4BNZ?usp=sharing

jax remat does not reduce memory usage when the gradient computation is compiled with jax.jit: the jitted step runs out of memory on a Colab TPU, while the same un-jitted computation works fine.

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class SimpleLayer(nn.Module):
    @nn.remat  # gradient checkpointing; should reduce peak memory during the backward pass
    @nn.compact
    def __call__(self, x):
        residual = x
        x = nn.LayerNorm()(x)
        x = nn.SelfAttention(
            num_heads=8,
            dtype=jnp.float32,
            qkv_features=512,
            deterministic=True,
        )(x)
        return x + residual

class SimpleModel(nn.Module):
    depth: int = 24
    @nn.compact
    def __call__(self, x):
        for i in range(self.depth):
            x = SimpleLayer()(x)
        return x

def get_model(input_shape):
    key = random.PRNGKey(42)
    return SimpleModel().init(key, jnp.ones(input_shape, dtype=jnp.float32))["params"]

def loss_func(params, inputs, targets):
    preds = SimpleModel().apply({"params": params}, inputs)
    return jnp.mean((preds - targets)**2)

@jax.jit
def fake_train_step(params, inputs, targets):
    grad_fn = jax.grad(loss_func)
    return grad_fn(params, inputs, targets)

params = get_model((128, 128, 512))
grad_fn = jax.grad(loss_func)

inputs = random.normal(random.PRNGKey(42), (128, 128, 512), jnp.float32)
targets = random.normal(random.PRNGKey(43), (128, 128, 512), jnp.float32)

grads = grad_fn(params, inputs, targets)          # un-jitted: works fine
grads = fake_train_step(params, inputs, targets)  # jitted: OOMs on the Colab TPU
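
For reference, here is the same pattern in pure JAX with Flax stripped out, to isolate the jax.remat / jax.jit interaction. This is only a sketch: the layer body, names, and initialization below are placeholders rather than the LayerNorm + SelfAttention block above, and I have not confirmed that this reduced version reproduces the OOM.

import jax
import jax.numpy as jnp

@jax.remat  # jax.remat (a.k.a. jax.checkpoint), which nn.remat lifts to Modules
def layer(w, x):
    # Placeholder body standing in for LayerNorm + SelfAttention + residual.
    return x + jnp.tanh(x @ w)

def model(weights, x):
    for w in weights:
        x = layer(w, x)
    return x

def loss(weights, x, y):
    return jnp.mean((model(weights, x) - y) ** 2)

grad_fn = jax.grad(loss)               # un-jitted: remat keeps activation memory low
jit_grad_fn = jax.jit(jax.grad(loss))  # jitted: the configuration reported to blow up

key = jax.random.PRNGKey(0)
weights = [0.02 * jax.random.normal(key, (512, 512)) for _ in range(24)]
x = jax.random.normal(key, (128, 128, 512))
y = jax.random.normal(key, (128, 128, 512))

grads_eager = grad_fn(weights, x, y)
grads_jit = jit_grad_fn(weights, x, y)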

Related discussion in the Flax repo: https://github.com/google/flax/issues/1285
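
An aside on measuring this: the sketch below compares peak device memory for one path at a time, on a slice small enough to fit either way. It assumes the backend exposes Device.memory_stats() with a peak_bytes_in_use entry (present on recent GPU/TPU runtimes, possibly absent or named differently on older ones), and the peak counter is cumulative for the process, so a clean comparison needs one path per run.

import jax

def peak_mb():
    # memory_stats() may return None or use different keys on some backends.
    stats = jax.devices()[0].memory_stats() or {}
    return stats.get("peak_bytes_in_use", 0) / 2**20

# Run ONE of the two paths per process, on a slice small enough to fit either way.
small_inputs, small_targets = inputs[:8], targets[:8]
grads = fake_train_step(params, small_inputs, small_targets)  # or grad_fn(...)
jax.tree_util.tree_map(lambda g: g.block_until_ready(), grads)
print(f"peak device memory: {peak_mb():.1f} MiB")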
