
hk.Conv2DTranspose takes FOREVER to initialize and compile

Open · sokrypton opened this issue 10 months ago · 1 comment

Not sure if this is a JAX thing or a dm-haiku thing, but I've recently been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.

here is an example:

import haiku as hk
import jax
from jax import random
import time

def toy_model(x):
  x = hk.Conv2DTranspose(32, 32, stride=16, padding="VALID")(x)
  return x

# Transform the model into a pure init/apply pair
toy_model = hk.transform(toy_model)
toy_model_init = toy_model.init
toy_model_apply = toy_model.apply

# Generate random input and params
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))

# Time the model initialization
start_time = time.time()
params = toy_model_init(key, x)
end_time = time.time()
print(f"initialization Time: {end_time - start_time:.6f} seconds")

# Time the model compilation
start_time = time.time()
compiled_apply = jax.jit(toy_model_apply)
# Warm-up call (this compiles the function); block on the result so the
# timing is accurate despite JAX's asynchronous dispatch
_ = compiled_apply(params, None, x).block_until_ready()
end_time = time.time()
print(f"Compilation Time: {end_time - start_time:.6f} seconds")

# Time the compiled model run (block on the result; print outside the timed region)
start_time = time.time()
o = compiled_apply(params, None, x).block_until_ready()
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"Run Time: {end_time - start_time:.6f} seconds")

Output:

initialization Time: 251.865844 seconds
Compilation Time: 255.010969 seconds
input_shape (1, 8, 8, 128)
output_shape (1, 144, 144, 32)
Run Time: 0.000671 seconds
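To help narrow down whether the slowdown comes from Haiku or from the underlying XLA primitive, here is a minimal sketch (my addition, not from the original report) that calls jax.lax.conv_transpose directly with the same shapes, kernel size, and stride:

```python
import jax
from jax import random

key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))            # NHWC input, as in the Haiku example
w = random.normal(key, (32, 32, 128, 32)) * 0.01  # HWIO kernel: 32x32, 128 in, 32 out

@jax.jit
def transpose_conv(x, w):
    # The transposed convolution Conv2DTranspose performs: stride 16, VALID padding
    return jax.lax.conv_transpose(
        x, w, strides=(16, 16), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

o = transpose_conv(x, w)  # first call triggers compilation; time it to isolate XLA
print(o.shape)            # (1, 144, 144, 32), matching the Haiku output above
```

If this call is similarly slow to compile, the issue is in JAX/XLA rather than Haiku's wrapper.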

For comparison, in PyTorch:

Initialization Time: 0.033582 seconds
input_shape torch.Size([1, 128, 8, 8])
output_shape torch.Size([1, 32, 144, 144])
Run Time: 0.047478 seconds
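For reference, the PyTorch side of the comparison presumably looks something like this (a reconstruction, not the notebook's exact code; note that PyTorch uses NCHW layout, hence the transposed shapes in the output above):

```python
import time
import torch

# Time layer construction
start_time = time.time()
# ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
conv = torch.nn.ConvTranspose2d(128, 32, kernel_size=32, stride=16)
end_time = time.time()
print(f"Initialization Time: {end_time - start_time:.6f} seconds")

# Time a forward pass
x = torch.randn(1, 128, 8, 8)  # NCHW input
start_time = time.time()
o = conv(x)
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)  # torch.Size([1, 32, 144, 144])
print(f"Run Time: {end_time - start_time:.6f} seconds")
```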

Google colab notebook replicating the test: https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr

sokrypton — Sep 06 '23 14:09