Wilson Yan
If I remove the model parallelism component in the `pjit` portion of the code, i.e. only `['dp']`:

```python
# Build a 1-D mesh with a single data-parallel axis over all local devices
mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(),), ['dp'])
# Override the thread-local resource env so pjit picks up this mesh
jax.experimental.maps.thread_resources.env = (
    jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
)
p_step...
```
I tried running similar code on some larger models and got similar results. This is also done **with only data parallelism (no model axis in the mesh)**. Code is run on...
Thanks for looking into it! Is there a good way to prevent this issue from happening code-wise? e.g., structuring the architecture differently, or enforcing certain constraints to help the partitioner...
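For context, one way to enforce such constraints is `jax.lax.with_sharding_constraint`, which pins an intermediate to a given layout so the partitioner can't choose a different one. A minimal sketch, assuming a `'dp'`-only mesh (the names and shapes here are illustrative, not from this repo):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ('dp',))

def loss_fn(params, batch):
    acts = batch @ params  # some intermediate activation
    # Pin the intermediate to a batch-sharded layout; the partitioner must
    # respect this instead of picking its own (possibly replicated) layout.
    acts = jax.lax.with_sharding_constraint(
        acts, NamedSharding(mesh, P('dp', None)))
    return acts.sum()

step = jax.jit(jax.grad(loss_fn))
```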
Sorry about that, I'll spend some time this coming weekend writing up more detailed documentation. I can also include the dataset generation script. In general, it's just downloading [pg19](https://huggingface.co/datasets/pg19) and...
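For reference, grabbing pg19 with the HuggingFace `datasets` library looks roughly like this (a sketch, not the actual generation script; the `"text"` field name comes from the dataset card):

```python
from datasets import load_dataset

# Stream the corpus rather than downloading all of it up front
pg19 = load_dataset("pg19", split="train", streaming=True)
for example in pg19:
    text = example["text"]  # raw book text; metadata fields also exist
    break
```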
Hi, thanks for your interest! A PyTorch version is on the roadmap, but it may take some time since both of us are rather occupied with other things at the...
If using vLLM for inference (PyTorch model, FP16), I believe we used:

- 1 80GB A100 for 32K
- 2 80GB A100s for 128K
- 4 80GB A100s for 256K...
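The scaling with context length is mostly down to the KV cache, which grows linearly in sequence length on top of the ~14GB of fp16 weights for a 7B model. A back-of-envelope sketch, assuming LLaMA-7B-like dimensions (the layer count and hidden size are my assumptions, not numbers from this repo):

```python
# Rough fp16 KV-cache sizing; 32 layers and hidden size 4096 are assumed
# LLaMA-7B-like values.
layers, hidden, bytes_per_val = 32, 4096, 2
kv_per_token = 2 * layers * hidden * bytes_per_val   # K and V: ~0.5 MB/token
for ctx in (32_000, 128_000, 256_000):
    print(f"{ctx} tokens -> ~{kv_per_token * ctx / 1e9:.0f} GB of KV cache")
# ~17 GB at 32K, ~67 GB at 128K, ~134 GB at 256K
```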
The `mesh_dim` argument depends on the number of devices you're using for inference. If you want to do tensor parallelism over 8 GPUs, then `mesh_dim` should be `1,1,8,1`. The default...
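Concretely, the four comma-separated entries describe a 4-axis device mesh. A sketch of the correspondence (the axis names are my guesses at what the four entries mean, not necessarily the repo's):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# mesh_dim='1,1,8,1' -> all 8 devices land on the tensor-parallel axis.
# Requires exactly 8 visible devices; axis names are assumptions.
mesh_dims = (1, 1, 8, 1)
devices = np.asarray(jax.devices()).reshape(mesh_dims)
mesh = Mesh(devices, ('dp', 'fsdp', 'tp', 'sp'))
print(mesh.shape)  # {'dp': 1, 'fsdp': 1, 'tp': 8, 'sp': 1}
```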
Thanks for your interest. We don't have plans to train a smaller model at the moment.
In general, we didn't run into too many memory bottlenecks for our needs, so we primarily just stuck with `fp32` to be safe (proper mixed precision training with `bf16` requires...
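For what it's worth, the generic bf16 mixed-precision recipe is: keep fp32 master weights, run the forward/backward math in bf16, and apply updates in fp32. A minimal sketch (illustrative, not this repo's training code):

```python
import jax
import jax.numpy as jnp

def train_step(params, batch, lr=1e-3):
    def loss_fn(p):
        w = p["w"].astype(jnp.bfloat16)        # bf16 copy for compute
        x = batch["x"].astype(jnp.bfloat16)
        pred = x @ w                            # bf16 matmul
        err = pred.astype(jnp.float32) - batch["y"]
        return jnp.mean(err ** 2)               # loss accumulated in fp32

    grads = jax.grad(loss_fn)(params)           # cotangents cast back to fp32
    return {"w": params["w"] - lr * grads["w"]} # update fp32 master weights
```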
I don't think your GPU has enough memory, as a 7B model in `fp32` (4 bytes per parameter) takes 7B × 4 ≈ 28GB for the weights alone.