functorch
vmap does not warm up jitted function
Looking at a lot of JAX code, people seem to freely mix jit and vmap for efficient vectorization. So I tried to do the same in functorch, but found that using vmap on a function decorated with @torch.jit.script is about 20x slower than not using it. I have a simple repro below, and that 20x difference disappears if I manually call the torch.jit.script function first. I would expect vmap to do that for me instead of having me warm up all the jitted functions myself.

All the experiments below are on CPU.
```python
import torch
from functorch import vmap
import time

@torch.jit.script
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = torch.randn(100000)

# need to warm up the jitted function, otherwise there is a 20x performance degradation
y = double(torch.randn(1000))

# this warmup has no impact
y = vmap(double)(x)

tic = time.time()
y = vmap(double)(x)
toc = time.time()
duration = toc - tic
print(duration)

# warmup has no impact
y2 = vmap(non_jit_double)(x)

tic = time.time()
y2 = vmap(non_jit_double)(x)
toc = time.time()
duration = toc - tic
print(duration)
```
EDIT: I called vmap more than once in a for loop and this effect went away, but it's still less than ideal.

The interesting part is that jit gets optimized progressively (albeit in an unstable way) and only plateaus at about 10 iterations, not 1 :confounded: I don't know enough about JIT, but my best guess is that there's some heuristic deciding when a function is worth optimizing.
```python
import torch
from functorch import vmap
import time

@torch.jit.script
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = torch.randn(100000)

for i in range(20):
    tic = time.time()
    y = vmap(double)(x)
    toc = time.time()
    duration = toc - tic
    print(f"iteration: {i}, duration: {duration}")
```
```
iteration: 0, duration: 0.002549886703491211
iteration: 1, duration: 0.000408172607421875
iteration: 2, duration: 0.0001251697540283203
iteration: 3, duration: 0.00011110305786132812
iteration: 4, duration: 0.0001201629638671875
iteration: 5, duration: 0.00012683868408203125
iteration: 6, duration: 0.00011706352233886719
iteration: 7, duration: 0.00013756752014160156
iteration: 8, duration: 0.00010704994201660156
iteration: 9, duration: 0.00010800361633300781
iteration: 10, duration: 5.984306335449219e-05
iteration: 11, duration: 5.817413330078125e-05
iteration: 12, duration: 8.368492126464844e-05
iteration: 13, duration: 9.584426879882812e-05
iteration: 14, duration: 8.273124694824219e-05
iteration: 15, duration: 7.843971252441406e-05
iteration: 16, duration: 7.653236389160156e-05
iteration: 17, duration: 7.557868957519531e-05
iteration: 18, duration: 7.677078247070312e-05
iteration: 19, duration: 7.62939453125e-05
```
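The ~10-iteration ramp is consistent with TorchScript's profiling graph executor, which (as far as I understand) records a handful of profiled runs before specializing and optimizing the graph. A rough way to test that hypothesis is to lower the profiled-run count with an internal knob (assumption: `torch._C._jit_set_num_profiled_runs` — an undocumented internal API that may change between releases):

```python
import time
import torch
from functorch import vmap

# Internal/undocumented knob (assumption): ask the profiling executor to
# specialize after a single profiled run instead of the default handful.
torch._C._jit_set_num_profiled_runs(1)

@torch.jit.script
def double(x):
    return x * 2

x = torch.randn(100000)

# If the hypothesis holds, the speed-up should now appear after
# roughly one iteration instead of ~10.
for i in range(5):
    tic = time.time()
    y = vmap(double)(x)
    toc = time.time()
    print(f"iteration: {i}, duration: {toc - tic}")
```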
Whereas JAX optimizes the function after the first iteration, in a stable way:
```python
import torch
# from functorch import vmap
import time

import jax
from jax import vmap, jit

key = jax.random.PRNGKey(0)

@jit
def double(x):
    return x * 2

def non_jit_double(x):
    return x * 2

x = jax.random.normal(key=key, shape=[10000])

for i in range(20):
    tic = time.time()
    y = vmap(double)(x)
    toc = time.time()
    duration = toc - tic
    print(f"iteration: {i}, duration: {duration}")
```
```
iteration: 0, duration: 0.01475214958190918
iteration: 1, duration: 0.00033783912658691406
iteration: 2, duration: 0.0002703666687011719
iteration: 3, duration: 0.00025200843811035156
iteration: 4, duration: 0.0002467632293701172
iteration: 5, duration: 0.00024366378784179688
iteration: 6, duration: 0.0002410411834716797
iteration: 7, duration: 0.00023984909057617188
iteration: 8, duration: 0.00023937225341796875
iteration: 9, duration: 0.00023746490478515625
iteration: 10, duration: 0.00023555755615234375
iteration: 11, duration: 0.0002353191375732422
iteration: 12, duration: 0.00023317337036132812
iteration: 13, duration: 0.00023245811462402344
iteration: 14, duration: 0.00022983551025390625
iteration: 15, duration: 0.0002295970916748047
iteration: 16, duration: 0.0002300739288330078
iteration: 17, duration: 0.00022983551025390625
iteration: 18, duration: 0.00023436546325683594
iteration: 19, duration: 0.0002300739288330078
```
cc: @Chillee @eellison
Mmmm, so... this is probably the wrong way to compile through vmap :P It might technically work today, but it's very much not the intended usage.

In general we cannot (currently) freely mix vmap and jit (or AOTAutograd). It's a WIP and currently doesn't quite work. aot_function(vmap(f)) does (mostly) work, but not the other way around.
> aot_function(vmap(f)) does (mostly) work, but not the other way around.
I think this could work; I believe the JAX recommendation is to jit the vmap, not vmap the jit, if I'm not mistaken.

Regardless, maybe aot_function needs its own decorator, @aot or @compile or what have you. That is, of course, if the intent is to freely mix vmap and aot, or if vmap will be optimized by default.
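For reference, the "jit the vmap" ordering looks like this on the JAX side (a minimal sketch using standard `jax.jit`/`jax.vmap`; the `double` function is the one from the repro above):

```python
import jax.numpy as jnp
from jax import jit, vmap

def double(x):
    return x * 2

# jit wraps the vmapped function (the recommended composition),
# rather than vmapping an already-jitted function.
double_batched = jit(vmap(double))

x = jnp.arange(8.0)
y = double_batched(x)  # each element doubled
```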
> aot_function(vmap(f)) does (mostly) work, but not the other way around.
>
> I think this could work; I believe the JAX recommendation is to jit the vmap, not vmap the jit, if I'm not mistaken.
aot_function(vmap(f)) is jitting the vmap, right?

EDIT: Yes, that's what you were saying, Mark; I didn't read carefully enough.
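To make the working direction concrete, a minimal sketch of aot_function(vmap(f)) could look like the following. Assumptions: `aot_function` is imported from `functorch.compile`, its `fw_compiler` callback receives a captured FX graph plus example inputs, and returning the graph unchanged is a valid no-op "compiler"; a real backend would hand the graph to a fusion/codegen pass instead.

```python
import torch
from functorch import vmap
from functorch.compile import aot_function

def double(x):
    return x * 2

# Trivial "compiler": return the captured FX graph as-is (it is callable).
def identity_compiler(fx_graph, example_inputs):
    return fx_graph

# Compile the vmapped function -- the direction that (mostly) works
# per this thread -- not vmap(aot_function(f)).
batched_double = aot_function(vmap(double), fw_compiler=identity_compiler)

x = torch.randn(16)
y = batched_double(x)
```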