
vmap does not warmup jitted function

msaroufim opened this issue 2 years ago

Looking at a lot of JAX code, people seem to freely mix jit and vmap for efficient vectorization.

So I tried to do the same in functorch but found that using vmap on a function decorated with @torch.jit.script is about 20x slower than not using it. I have a simple repro below, and the 20x difference disappears if I manually call the torch.jit.script'd function first. I would expect vmap to do that for me instead of having me warm up all the jitted functions.

All the experiments below are on CPU

import torch
from functorch import vmap
import time 

@torch.jit.script
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = torch.randn(100000)

# need to warmup the jitted function, otherwise there's a 20x performance degradation
y = double(torch.randn(1000))

# this warmup has no impact
y = vmap(double)(x)


tic = time.time()
y = vmap(double)(x)
toc = time.time()
duration = toc - tic 
print(duration)

# warmup has no impact
y2 = vmap(non_jit_double)(x)

tic = time.time()
y2 = vmap(non_jit_double)(x)
toc = time.time()
duration = toc - tic 
print(duration)

EDIT: I called vmap more than once in a for loop and this effect went away, but it's still less than ideal.

msaroufim avatar Apr 16 '22 22:04 msaroufim

The interesting part is that jit gets optimized progressively (albeit in an unstable way) and tops out at about 10 iterations, not 1 :confounded: - I don't know enough about the JIT, but my best guess is there's some heuristic to decide when it's worth optimizing a function

import torch
from functorch import vmap
import time

@torch.jit.script
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = torch.randn(100000)

for i in range(20):
    tic = time.time()
    y = vmap(double)(x)
    toc = time.time()
    duration = toc - tic
    print(f"iteration: {i}, duration: {duration}")
iteration: 0, duration: 0.002549886703491211
iteration: 1, duration: 0.000408172607421875
iteration: 2, duration: 0.0001251697540283203
iteration: 3, duration: 0.00011110305786132812
iteration: 4, duration: 0.0001201629638671875
iteration: 5, duration: 0.00012683868408203125
iteration: 6, duration: 0.00011706352233886719
iteration: 7, duration: 0.00013756752014160156
iteration: 8, duration: 0.00010704994201660156
iteration: 9, duration: 0.00010800361633300781
iteration: 10, duration: 5.984306335449219e-05
iteration: 11, duration: 5.817413330078125e-05
iteration: 12, duration: 8.368492126464844e-05
iteration: 13, duration: 9.584426879882812e-05
iteration: 14, duration: 8.273124694824219e-05
iteration: 15, duration: 7.843971252441406e-05
iteration: 16, duration: 7.653236389160156e-05
iteration: 17, duration: 7.557868957519531e-05
iteration: 18, duration: 7.677078247070312e-05
iteration: 19, duration: 7.62939453125e-05
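The step change around iteration 10 is consistent with TorchScript's profiling executor, which runs a scripted function in profiling mode for several calls before specializing and optimizing the graph (the exact threshold is an internal detail, not a documented contract). A minimal sketch of the warmup pattern, using only torch so it stands alone:

```python
import torch

@torch.jit.script
def double(x):
    return x * 2

x = torch.randn(100000)

# Call the scripted function a handful of times so the profiling
# executor can collect shape/type profiles and specialize the graph;
# subsequent calls then run the optimized execution plan.
for _ in range(10):
    y = double(x)

assert torch.equal(y, x * 2)
```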

Whereas JAX optimizes a function after the first iteration in a stable way

import time
import jax
from jax import vmap, jit
key = jax.random.PRNGKey(0)

@jit
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = jax.random.normal(key=key, shape=[10000])

for i in range(20):
    tic = time.time()
    y = vmap(double)(x)
    toc = time.time()
    duration = toc - tic
    print(f"iteration: {i}, duration: {duration}")
iteration: 0, duration: 0.01475214958190918
iteration: 1, duration: 0.00033783912658691406
iteration: 2, duration: 0.0002703666687011719
iteration: 3, duration: 0.00025200843811035156
iteration: 4, duration: 0.0002467632293701172
iteration: 5, duration: 0.00024366378784179688
iteration: 6, duration: 0.0002410411834716797
iteration: 7, duration: 0.00023984909057617188
iteration: 8, duration: 0.00023937225341796875
iteration: 9, duration: 0.00023746490478515625
iteration: 10, duration: 0.00023555755615234375
iteration: 11, duration: 0.0002353191375732422
iteration: 12, duration: 0.00023317337036132812
iteration: 13, duration: 0.00023245811462402344
iteration: 14, duration: 0.00022983551025390625
iteration: 15, duration: 0.0002295970916748047
iteration: 16, duration: 0.0002300739288330078
iteration: 17, duration: 0.00022983551025390625
iteration: 18, duration: 0.00023436546325683594
iteration: 19, duration: 0.0002300739288330078

cc: @Chillee @eellison

msaroufim avatar Apr 16 '22 22:04 msaroufim

Mmmm, so... this is probably the wrong way to compile through vmap :P It might technically work today, but it's very much not intended usage.

In general we cannot (currently) freely mix vmap and jit (or AOTAutograd). It's a WIP and currently doesn't quite work. aot_function(vmap(f)) does (mostly) work, but not the other way around.
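For concreteness, the composition that mostly works today looks something like the sketch below. This assumes functorch's experimental functorch.compile.aot_function and its nop pass-through compiler; both are experimental and subject to change:

```python
import torch
from functorch import vmap
from functorch.compile import aot_function, nop  # nop: pass-through compiler

def double(x):
    return x * 2

# Trace and compile *around* the vmap: aot_function(vmap(f)).
# The reverse composition, vmap(aot_function(f)), is not expected to work yet.
doubled = aot_function(vmap(double), fw_compiler=nop)

x = torch.randn(8, 4)
y = doubled(x)
assert torch.allclose(y, x * 2)
```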

Chillee avatar Apr 18 '22 21:04 Chillee

aot_function(vmap(f)) does (mostly) work, but not the other way around.

I think this could work; I believe JAX's recommendation is to jit the vmap, not vmap the jit, if I'm not mistaken

Regardless, maybe aot_function needs its own decorator: @aot, or compile, or what have you

That is, of course, if the intent is for vmap and aot to be freely mixable, or if vmap will be optimized by default
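For reference, the JAX-side composition (jit around vmap) is just the standard JAX API, sketched here for comparison:

```python
import jax.numpy as jnp
from jax import jit, vmap

def double(x):
    return x * 2

# jit the vmap, rather than vmap the jit: the whole batched
# computation gets traced and compiled as a single XLA program.
double_all = jit(vmap(double))

x = jnp.arange(4.0)
y = double_all(x)
assert bool((y == x * 2).all())
```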

msaroufim avatar Apr 18 '22 21:04 msaroufim

aot_function(vmap(f)) does (mostly) work, but not the other way around.

I think this could work; I believe JAX's recommendation is to jit the vmap, not vmap the jit, if I'm not mistaken

aot_function(vmap(f)) is jitting the vmap, right?

EDIT: Yes, that's what you were saying Mark, I didn't read carefully enough

zou3519 avatar Apr 19 '22 14:04 zou3519