[BUG]: AOT example does not work
Here's a code snippet using the example from here:
import torch as th
import functorch as fth

# Simple compiler that just prints the traced FX graph and returns it
# unchanged (print_compile_fn as in the linked example).
def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module

def fn(input, bias, residual, p: float):
    a = th.add(input, bias)
    b = th.nn.functional.dropout(a, p, training=True)
    c = b + residual
    return c

aot_fn = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))
aot_fn(th.randn(4, ), th.randn(4,), th.randn(4,), p=0.3)
which produces the following error on functorch==0.1.1:
RuntimeError Traceback (most recent call last)
Input In [59], in <cell line: 10>()
5 return c
7 aot_fn = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))
---> 10 aot_fn(th.randn(4, ), th.randn(4,), th.randn(4,), p=0.3)
File ~/miniconda3/envs/charm/lib/python3.8/site-packages/functorch/_src/aot_autograd.py:389, in aot_function.<locals>.returned_function(*args, **kwargs)
387 num_tensor_args = len(flat_tensor_args)
388 flat_args_for_cache = flat_tensor_args + static_args_hashed
--> 389 cached_res = compile_cache.at(
390 fn_id,
391 fw_compiler_id,
392 bw_compiler_id,
393 num_tensor_args,
394 hasher_type,
395 *flat_args_for_cache,
396 )
398 # Compile the function and save it in the cache
399 if cached_res is None:
400 # Save the args_spec for flat_tensor_args to unflatten while tracing
RuntimeError: Found an argument of type float at index 3. Non-tensor arguments must be marked static. Please set the static_argnums correctly to mark the argument at index 3 static
Any help on how to proceed? Any quick fix that could work?
Thanks in advance for this awesome work!
cc @Chillee
Unfortunately, we don't support kwargs properly right now with AOTAutograd (particularly with static_argnums).
So, it works if you remove the p= keyword from the call and pass 0.3 positionally instead.
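For example, something along these lines should go through (a minimal sketch of the workaround, reusing the snippet above):

# Same setup as before, but p is passed positionally rather than as a kwarg,
# so the argument at index 3 is the one covered by static_argnums=(3,).
aot_fn = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))
aot_fn(th.randn(4,), th.randn(4,), th.randn(4,), 0.3)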
I see... what's a reasonable timeframe in which this might get implemented?
@lucidrains so I tried it for the memory-efficient attention but couldn't get it to work faster than the original (the difference was so small it wasn't justified...), so I guess we might want to try Triton? tbh, I think it would be worth looking at how the JAX compiler implements it under the hood, because it seems to take no hit in runtime compared to torch's einsum.