
eqx.nn.Conv2d slower than torch.nn.Conv2d

Open • lowlorenz opened this issue 4 months ago • 1 comment

Hey there!

@patrick-kidger Thank you for providing this amazing library! Highly appreciated.

When comparing the speed of eqx.nn.Conv2d and torch.nn.Conv2d, I was surprised to find that the jitted Equinox layer is slower than the unjitted PyTorch layer. (The same holds when comparing a version of torch's conv layer jitted with torch.jit.script(t_conv) against an unjitted version of the Equinox layer.)

I provided a minimal example below.

Any thoughts on this? Is there an error in my measurement?

Thank you for your feedback!

import equinox as eqx
import torch
from torch import nn
import time
import jax

key = jax.random.PRNGKey(0)


# record time for the jax model
elapsed_time = 0

j_conv = eqx.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=1, padding=0, key=key)
j_conv = eqx.filter_jit(j_conv)

# warm-up call, so that compilation is not included in the timing;
# eqx layers take unbatched input of shape (channels, height, width)
inp = jax.random.uniform(key, (1, 28, 28))
j_conv(inp).block_until_ready()

for _ in range(10_000):
    inp = jax.random.uniform(key, (1, 28, 28))
    start = time.time()

    j_conv(inp).block_until_ready()  # wait for the asynchronous computation to finish

    elapsed_time += time.time() - start


print(f'{elapsed_time:.5f} seconds for 10_000 eqx.nn.conv2d forward passes')

# record time for the torch model
elapsed_time = 0

t_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=1, padding=0).cuda()

for _ in range(10_000):
    inp = torch.randn(1, 1, 28, 28).cuda()  # (batch, channels, height, width)
    start = time.time()

    torch.cuda.synchronize()  # make sure previously queued GPU work has finished
    t_conv(inp)
    torch.cuda.synchronize()  # wait for the convolution kernel to finish

    elapsed_time += time.time() - start

print(f'{elapsed_time:.5f} seconds for 10_000 torch.nn.conv2d forward passes')
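Both loops above follow the same pattern: build an input, start the clock, run the layer, wait for the device, and accumulate. Below is a minimal, stdlib-only sketch of that pattern as a reusable function. The names `time_forward`, `make_input` and `sync` are hypothetical (not part of Equinox, JAX or PyTorch). Two choices differ from the loops above and are assumptions about what one would want to measure: every layer gets one untimed warm-up call, and the input is synchronized before the clock starts so that only the forward pass is timed. It also uses `time.perf_counter`, which is intended for measuring short durations.

```python
import time
from typing import Any, Callable


def time_forward(
    forward: Callable[[Any], Any],
    make_input: Callable[[], Any],
    sync: Callable[[Any], Any],
    n_iters: int = 10_000,
) -> float:
    """Return the total wall-clock seconds spent in `n_iters` calls to `forward`.

    Hypothetical helper: `sync(x)` should block until `x` (or all pending
    device work) is ready, e.g. `lambda x: x.block_until_ready()` for JAX.
    """
    # Untimed warm-up call: excludes one-off costs such as JIT compilation.
    sync(forward(make_input()))

    elapsed = 0.0
    for _ in range(n_iters):
        x = make_input()
        sync(x)  # make sure the input is ready before the clock starts
        start = time.perf_counter()
        sync(forward(x))  # wait for the (possibly asynchronous) result
        elapsed += time.perf_counter() - start
    return elapsed


if __name__ == "__main__":
    # CPU-only usage with plain Python callables, just to exercise the helper.
    total = time_forward(lambda x: sum(x), lambda: list(range(100)), lambda x: x, n_iters=1_000)
    print(f"{total:.5f} seconds for 1_000 calls")
```

To mirror the original loops, one could pass `forward=j_conv`, `make_input=lambda: jax.random.uniform(key, (1, 28, 28))` and `sync=lambda x: x.block_until_ready()` for Equinox, and `forward=t_conv`, `make_input=lambda: torch.randn(1, 1, 28, 28).cuda()` and `sync=lambda x: torch.cuda.synchronize()` for PyTorch. These are untested sketches of how the helper would be wired up.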

This results in:

2.27325 seconds for 10_000 eqx.nn.conv2d forward passes
0.67587 seconds for 10_000 torch.nn.conv2d forward passes
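Converting the two totals above into per-call latency makes the gap easier to read. A short sketch using only the reported numbers:

```python
# Totals reported above, each for 10_000 forward passes.
N_CALLS = 10_000
eqx_total_s = 2.27325
torch_total_s = 0.67587

eqx_per_call_us = eqx_total_s / N_CALLS * 1e6      # ≈ 227.3 µs per call
torch_per_call_us = torch_total_s / N_CALLS * 1e6  # ≈ 67.6 µs per call
ratio = eqx_total_s / torch_total_s                # ≈ 3.36x

print(f"eqx:   {eqx_per_call_us:.1f} us/call")
print(f"torch: {torch_per_call_us:.1f} us/call")
print(f"eqx/torch: {ratio:.2f}x")
```

For a single 1×28×28 input, the convolution itself is very small, so fixed per-call costs (Python-side dispatch, kernel launches, synchronization) may make up much of each measurement. That is a hypothesis; these numbers alone do not establish it.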

Versions:

equinox: 0.11.3
torch: 2.2.1+cu118
jax: 0.4.24

GPU: NVIDIA TITAN RTX

lowlorenz • Feb 24 '24 10:02