Jay Mody
I would increase the number of shards from 18 to something like 32, then run `create_pretraining_data.py` on each shard individually, one at a time (a quick script can automate this).
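The per-shard approach above could be automated with something like the sketch below. It is only an illustration: the shard and output file names are hypothetical, and the `--input_file` / `--output_file` flags are assumptions about the script's command line, so check them against the actual `create_pretraining_data.py` you are using and add whatever other arguments it needs.

```python
import subprocess
from pathlib import Path


def build_commands(shard_dir, out_dir, num_shards=32):
    """Build one command per shard. File names and flags are assumed, not confirmed."""
    commands = []
    for i in range(num_shards):
        shard = Path(shard_dir) / f"shard_{i:02d}.txt"      # hypothetical naming scheme
        output = Path(out_dir) / f"shard_{i:02d}.tfrecord"  # hypothetical naming scheme
        commands.append([
            "python", "create_pretraining_data.py",
            f"--input_file={shard}",    # assumed flag name
            f"--output_file={output}",  # assumed flag name
        ])
    return commands


def run_all(commands, dry_run=True):
    """Run the commands strictly one after another; with dry_run, only print them."""
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing shard


if __name__ == "__main__":
    run_all(build_commands("shards", "tfrecords"), dry_run=True)
```

Running shards sequentially keeps only one shard's worth of data in memory at a time, which is the point of splitting into more, smaller shards. Switch `dry_run` to `False` once the printed commands look right.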
Odd, worked for me using python `3.7.5` with yapf `0.29.0` and `0.30.0`. Tested this with the following code and command:

### script.py

```python
SEARCH_FIELD="hello"
FIELD_PATTERN: str = fr'^({SEARCH_FIELD}|{SEARCH_FIELD}[ \t]+[a-z0-9])$'
```
...
Are there any plans to add support for named arguments in `vmap` when using `in_axes` and/or `out_axes`? Right now, in one way or another, I have to use a non-ideal...
Ah, so I'm realizing it's because `jnp.arange` by default returns an array of type int. If I change it to `print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0))` I no longer get an error. Wondering if...
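A minimal sketch of the dtype behaviour described in the comment above, assuming a default JAX configuration. The function `foo` from the comment is not shown in the source, so it is left out here; only the array dtypes are illustrated.

```python
import jax.numpy as jnp

ints = jnp.arange(10)                         # integer dtype by default
floats_a = jnp.arange(10) * 1.0               # workaround from the comment: promote to float
floats_b = jnp.arange(10, dtype=jnp.float32)  # alternative: request a float dtype up front

print(ints.dtype, floats_a.dtype, floats_b.dtype)
```

Passing `dtype` explicitly states the intent more directly than multiplying by `1.0`, though both give a floating-point array.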
Yeah, that's the workaround I'm using as well to check the shapes and types if an error comes up. Maybe it's worth documenting this in `API.md`? I missed #6 in...
Curious, I would've expected `jax` to be faster given that it executes asynchronously (which should effectively make this line `out_heads = [attention(q, k, v, causal_mask) for q, k, v in...
Thanks for the PR! I'm going to leave this unmerged, since I want to keep the repo as minimal as possible (while the Dockerfile is minimal, I'd need to document it in...
Thanks for the implementation! What kind of speedups did you get with this and did you get an identical output to the non-kv cache version? Just FYI, I'm going to...