Jack Gallagher

32 comments by Jack Gallagher

here's some example code doing this for `FrozenDict`s, extracted + tweaked starting from the flax implementation

```py
def _split_gdas(
    target: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[Tuple[Union[jax.Array, GlobalDeviceArray], str]]]:
    #...
```

yeah the only thing I don't like about the `tree_flatten` approach is that all the path information is lost, so if afterwards I want to eg just load the params...
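To illustrate the concern about lost paths, here is a minimal sketch in plain Python (no JAX dependency) of flattening a nested dict while keeping each leaf's key path, so a subset such as the params can be picked out afterwards. The helper names `flatten_with_paths` and `select_subtree` and the toy `checkpoint` dict are hypothetical, made up for this illustration; they are not part of flax or jax. Recent JAX releases also provide `jax.tree_util.tree_flatten_with_path`, which serves a similar purpose for general pytrees.

```python
from typing import Any, Dict, Iterator, Tuple

Path = Tuple[str, ...]


def flatten_with_paths(tree: Any, prefix: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs for a nested dict, keeping each leaf's key path.

    Hypothetical helper: anything that is not a dict is treated as a leaf.
    """
    if isinstance(tree, dict):
        # Sort keys so the traversal order is deterministic.
        for key in sorted(tree):
            yield from flatten_with_paths(tree[key], prefix + (key,))
    else:
        yield prefix, tree


def select_subtree(tree: Dict[str, Any], top_level_key: str) -> Dict[Path, Any]:
    """Keep only the leaves whose path starts with `top_level_key`."""
    return {
        path: leaf
        for path, leaf in flatten_with_paths(tree)
        if path and path[0] == top_level_key
    }


# Toy stand-in for a checkpoint; real leaves would be arrays.
checkpoint = {
    "params": {"dense": {"kernel": [1.0, 2.0], "bias": [0.0]}},
    "opt_state": {"count": 3},
}

# Because paths survive flattening, loading just the params is a simple filter.
params_only = select_subtree(checkpoint, "params")
```

With a plain `tree_flatten`, only the list of leaves and an opaque treedef come back, so this kind of filtering by name needs the paths to be tracked separately.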

> The key point here is that "pytree structure" is really a notion that only makes sense within the context of a particular Python runtime, and it's probably a mistake...

@harisraharjo jax does in fact import numpy

@David-OConnor where in the code would this feature go? I have enough issues with poetry et al that I'd be up to give it a try, but need some pointers...

can confirm that this error also appears under `jax.lax.scan`; example here:

```py
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v =...
```

```py
from flash_attention_jax import flash_attention
```

@dlwh looks like you also ran an autoformatter so there's a ton of other changes here - can you say a bit more about how you fixed it?

won't that make `flash_attention` always do causal masking? I'm using this in a context where that's not appropriate

so the relevant fix would be to replace https://github.com/lucidrains/flash-attention-jax/blob/e5efb902c89858c2fa15fce2781492a6b36ddeb4/flash_attention_jax/flash_attention.py#L97 with

```py
return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)
```

?