Alex McKinney

71 comments by Alex McKinney

> True, it'd make sense to allow this! Could you maybe open a separate issue for it? I'll try to have a look into it soon :-) See #1141 for...

A broad statement would be helpful; no need to go into great detail.

Makes sense; this is actually just a small reproducing example. I initially found the error while trying to use `jax.lax.associative_scan` within a kernel. Do you think `pl.load` could be used to...

Another big +1 here. I would really love to use Pallas, but I cannot find a set of commands that installs it correctly.

Thanks @sharadmv! I appreciate the quick workaround. Here is a copy-pasteable version (note: you need to be using Python 3.9 or 3.10): ```shell pip install --no-deps -IU --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.post20231216005823...

I get the following output when running the Hello World example in [the Quickstart](https://jax.readthedocs.io/en/latest/pallas/quickstart.html). This is in a fresh environment, set up with the block above, with CUDA 12.3.1-2 installed locally: ``` 2024-01-04 21:03:19.992205: W external/xla/xla/service/gpu/command_buffer_scheduling.cc:470] Removed command...

Sorry, I missed a warning during the install: ``` jax-triton 0.1.4 requires absl-py>=1.4.0, which is not installed. ``` Installing the specified package, then reinstalling the packages worked :)

I think it depends on your environment and whether you already have that package installed. Adding `absl-py` to the `jax-triton` install line should probably be sufficient, but I...
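A minimal sketch for checking whether the missing dependency is present before running the Quickstart, assuming `python3` is the interpreter pip installed into (override it with the `PYTHON` variable). It only checks that `absl-py` is importable (module name `absl`), not that it meets the `absl-py>=1.4.0` floor from the pip warning. One way to act on the suggestion above, as an assumption rather than the official instructions, is to name it on the same line, e.g. `pip install jax-triton 'absl-py>=1.4.0'`; explicitly named packages are installed even when the line also passes `--no-deps`.

```shell
#!/bin/sh
# Sketch: report whether absl-py is importable by the target interpreter.
# Assumption: the interpreter is `python3`; override with PYTHON=/path/to/python.
PYTHON="${PYTHON:-python3}"

# absl-py is imported under the module name `absl`.
# Note: this checks presence only, not the >=1.4.0 version floor.
if "$PYTHON" -c "import absl" >/dev/null 2>&1; then
  status="present"
else
  status="missing"
  # Floor taken from the pip warning: jax-triton 0.1.4 requires absl-py>=1.4.0.
  echo "absl-py not found; try: pip install 'absl-py>=1.4.0'" >&2
fi

echo "absl-py: $status"
```

If it reports `missing`, installing the package and then rerunning the install lines matches what worked in the comment above.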

Hi @gianlucadetommaso, I haven't had the time to work on this since this draft PR went live, but I am blocking time out this weekend to continue.

Thanks @sanchit-gandhi, I found a little time to continue today. One issue I am noticing is that the tolerance when comparing the ground-truth PyTorch implementation (in `modeling_llama.py`) and my...