[BUG]: AOT example does not work

Open hypnopump opened this issue 3 years ago • 4 comments

Here's a code snippet using the example from here:

import torch as th
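import functorch as fth

# Assumption: print_compile_fn is not shown in the original post. This is
# assumed to mirror the helper from the aot_function docstring example:
# it prints the traced graph and returns it unchanged.
def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module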

def fn(input, bias, residual, p: float):
    a = th.add(input, bias)
    b = th.nn.functional.dropout(a, p, training=True)
    c = b + residual
    return c

aot_fn = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))


aot_fn(th.randn(4, ), th.randn(4,), th.randn(4,), p=0.3)

This produces the following error on functorch==0.1.1:

RuntimeError                              Traceback (most recent call last)
Input In [59], in <cell line: 10>()
      5     return c
      7 aot_fn = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))
---> 10 aot_fn(th.randn(4, ), th.randn(4,), th.randn(4,), p=0.3)

File ~/miniconda3/envs/charm/lib/python3.8/site-packages/functorch/_src/aot_autograd.py:389, in aot_function.<locals>.returned_function(*args, **kwargs)
    387 num_tensor_args = len(flat_tensor_args)
    388 flat_args_for_cache = flat_tensor_args + static_args_hashed
--> 389 cached_res = compile_cache.at(
    390     fn_id,
    391     fw_compiler_id,
    392     bw_compiler_id,
    393     num_tensor_args,
    394     hasher_type,
    395     *flat_args_for_cache,
    396 )
    398 # Compile the function and save it in the cache
    399 if cached_res is None:
    400     # Save the args_spec for flat_tensor_args to unflatten while tracing

RuntimeError: Found an argument of type float at index 3. Non-tensor arguments must be marked static. Please set the static_argnums correctly to mark the argument at index 3 static

Any help on how to proceed? Any quick fix that could work?

Thanks in advance for this awesome work!

hypnopump, Apr 25 '22 12:04

cc @Chillee

vfdev-5, Apr 25 '22 12:04

Unfortunately, we don't support kwargs properly right now with AOTAutograd (particularly with static_argnums).

So it works if you drop the p=0.3 keyword from the call and pass 0.3 positionally instead.

Chillee, Apr 27 '22 08:04
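
A minimal sketch of one way to keep calling the compiled function with keyword arguments despite the limitation described above. The helper name as_positional is hypothetical and not part of functorch; it only assumes that the compiled function accepts the same positional arguments as the original one, and it has not been checked against functorch 0.1.1 itself. It rebinds keyword arguments to their positional slots, so an index listed in static_argnums lines up with the value actually passed.

```python
import functools
import inspect


def as_positional(original_fn, compiled_fn):
    """Wrap compiled_fn so keyword arguments are passed positionally.

    Hypothetical helper (not a functorch API). It uses the signature of
    original_fn to turn a call such as f(a, b, c, p=0.3) into
    compiled_fn(a, b, c, 0.3).
    """
    sig = inspect.signature(original_fn)

    @functools.wraps(original_fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Keyword-only parameters cannot be made positional; this sketch
        # rejects them instead of guessing.
        if bound.kwargs:
            raise TypeError("keyword-only parameters are not handled by this sketch")
        # bound.args lists every positional-or-keyword value in signature order.
        return compiled_fn(*bound.args)

    return wrapper


# Usage with the snippet from the issue (assumes functorch is installed):
# compiled = fth.compile.aot_function(fn, print_compile_fn, static_argnums=(3,))
# aot_fn = as_positional(fn, compiled)
# aot_fn(th.randn(4), th.randn(4), th.randn(4), p=0.3)
```

The simpler alternative is to write the call positionally by hand; the wrapper only helps when existing call sites already use keywords.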

I see... what's a reasonable timeframe in which this might get implemented?

hypnopump, May 03 '22 15:05

@lucidrains So I tried it for the memory-efficient attention but couldn't get it to run faster than the original (the difference was so small it wasn't justified...), so I guess we might want to try Triton? TBH, I think it would be worth looking at how the JAX compiler implements it under the hood, because it seems to take no runtime hit compared to torch's einsum.

hypnopump, May 04 '22 10:05