eqx.nn.Conv2d slower than torch.nn.Conv2d
Hey there!
@patrick-kidger Thank you for providing this amazing library! Highly appreciated.
When comparing the speed of eqx.nn.Conv2d and torch.nn.Conv2d, I was surprised to find that the jitted version of eqx is slower than the unjitted version of torch. (This also holds for a version of torch's conv filter jitted with torch.jit.script(t_conv), and for an unjitted version of equinox.)
I provided a minimal example below.
Any thoughts on this? Is there an error in my measurement?
Thank you for your feedback!
import equinox as eqx
import torch
from torch import nn
import time
import jax
key = jax.random.PRNGKey(0)
# record time for the jax model
elapsed_time = 0
j_conv = eqx.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=1, padding=0, key=key)
j_conv = eqx.filter_jit(j_conv)
inp = jax.random.uniform(key, (1,28,28))
j_conv(inp).block_until_ready()
for _ in range(10_000):
    inp = jax.random.uniform(key, (1,28,28))
    start = time.time()
    j_conv(inp).block_until_ready()
    elapsed_time += time.time() - start
print(f'{elapsed_time:.5f} seconds for 10_000 eqx.nn.conv2d forward passes')
# record time for the torch model
elapsed_time = 0
t_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=1, padding=0).cuda()
for _ in range(10_000):
    inp = torch.randn(1,1,28,28).cuda()
    start = time.time()
    torch.cuda.synchronize()
    t_conv(inp)
    torch.cuda.synchronize()
    elapsed_time += time.time() - start
print(f'{elapsed_time:.5f} seconds for 10_000 torch.nn.conv2d forward passes')
This results in:
2.27325 seconds for 10_000 eqx.nn.conv2d forward passes
0.67587 seconds for 10_000 torch.nn.conv2d forward passes
Versions:
equinox: 0.11.3
torch: 2.2.1+cu118
jax: 0.4.24
GPU: NVIDIA TITAN RTX
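In case the difference comes from how I measure rather than from the layers themselves, here is a rough sketch of a framework-agnostic timing helper that I could run both models through. It creates all inputs before the clock starts, runs a few warmup calls, and waits for the result inside the timed region. The names benchmark, make_input, and sync are hypothetical and only for illustration; they are not part of equinox, jax, or torch.

```python
import time


def benchmark(fn, make_input, sync=lambda out: None, n_iters=1000, warmup=10):
    """Return the mean wall-clock seconds per call of fn.

    fn:         callable under test (hypothetical name, e.g. a conv layer)
    make_input: zero-argument callable that builds one input
    sync:       callable that blocks until the output is actually computed
    """
    # Build all inputs up front so input creation is never timed.
    inputs = [make_input() for _ in range(n_iters)]

    # Warmup calls absorb one-off costs such as compilation.
    for x in inputs[:warmup]:
        sync(fn(x))

    start = time.perf_counter()
    for x in inputs:
        sync(fn(x))  # wait for the result before the next call
    return (time.perf_counter() - start) / n_iters
```

For the jax model I would pass sync=lambda out: out.block_until_ready(), and for the torch model sync=lambda out: torch.cuda.synchronize(), so that both sides are timed the same way. This is only meant as a cross-check of the numbers above, not as a definitive benchmark.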