Saurabh Shah
Trying torch scripting and applying the rotations in the complex plane instead of R²
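As a concrete sketch of the complex-plane formulation (my own illustrative code, not the actual change): a rotation by angle θ in R² is just multiplication by e^{iθ} in ℂ, so the per-position rotations can be applied with `torch.view_as_complex`. The function names and shapes below are assumptions.

```python
import torch

def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE frequencies theta^(-2i/d), turned into unit complex rotations e^{i * angle}.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, n_heads, head_dim); view consecutive pairs of dims as complex numbers.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast the per-position rotations over batch and heads, then rotate by multiplying.
    freqs_cis = freqs_cis.view(1, x.shape[1], 1, -1)
    x_rotated = torch.view_as_real(x_complex * freqs_cis).flatten(-2)
    return x_rotated.type_as(x)
```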
Activation logging for all modules. Updates (@epwalsh): For each module, we log the activation L2 norm, average, absolute min, and absolute max. These are reduced over all ranks. Note...
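A minimal sketch of how this kind of per-module logging could be wired up with forward hooks plus a cross-rank reduction; the hook registration, metric names, and reduction choices here are my assumptions, not the actual logging code.

```python
import torch
import torch.distributed as dist

def register_activation_stats_hooks(model: torch.nn.Module, metrics: dict) -> list:
    """Attach forward hooks that record per-module activation stats into `metrics`."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                x = output.detach().float()
                metrics[f"{name}/l2_norm"] = x.norm(p=2)
                metrics[f"{name}/mean"] = x.mean()
                metrics[f"{name}/abs_min"] = x.abs().min()
                metrics[f"{name}/abs_max"] = x.abs().max()
        return hook
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]

def reduce_metrics_over_ranks(metrics: dict) -> None:
    """Reduce stats over all ranks: extremes via MIN/MAX, the rest averaged."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for key, value in metrics.items():
        if key.endswith("abs_min"):
            dist.all_reduce(value, op=dist.ReduceOp.MIN)
        elif key.endswith("abs_max"):
            dist.all_reduce(value, op=dist.ReduceOp.MAX)
        else:
            dist.all_reduce(value, op=dist.ReduceOp.SUM)
            metrics[key] = value / world_size
```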
### ❓ The question

I tried implementing [grouped query attention](https://arxiv.org/pdf/2305.13245.pdf) in [this pull request](https://github.com/allenai/LLM/pull/241), but it seems that PyTorch's `scaled_dot_product_attention` doesn't support the kind of broadcasting we'd need for this. Revisit if/when this gets fixed on...
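Until that support lands, one common workaround (sketched below with assumed shapes; not the code from the PR) is to materialize the broadcast by repeating each K/V head before calling `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """GQA via explicit K/V head repetition.

    q:    (batch, n_q_heads, seq_len, head_dim)
    k, v: (batch, n_kv_heads, seq_len, head_dim), with n_q_heads % n_kv_heads == 0
    """
    n_rep = q.shape[1] // k.shape[1]
    if n_rep > 1:
        # SDPA wants q/k/v to agree on the head dimension, so expand K/V instead of broadcasting.
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

The repetition costs memory proportional to `n_q_heads / n_kv_heads`, which is exactly what native broadcasting support would avoid. (If I recall correctly, newer PyTorch releases add an `enable_gqa` flag to `scaled_dot_product_attention` for this.)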
Refactor to allow for an LM-judge based "verifier": a general `VeriferConfig` so that we can configure verifiers per training run (builder pattern). Not much in there now but easily extendable to...
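A rough sketch of what a per-run verifier config with builder-style helpers could look like; every name and field below is illustrative, not the actual contents of the config added in the refactor.

```python
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class VerifierConfig:  # hypothetical stand-in for the config added in the PR
    verifier_type: str = "rule_based"           # e.g. "rule_based" or "llm_judge"
    judge_model: Optional[str] = None           # only used by an LM-judge verifier
    judge_prompt_template: Optional[str] = None
    max_judge_tokens: int = 512

    # Builder-style helper: returns a new, updated config rather than mutating in place.
    def with_llm_judge(self, model: str, prompt_template: str) -> "VerifierConfig":
        return replace(
            self,
            verifier_type="llm_judge",
            judge_model=model,
            judge_prompt_template=prompt_template,
        )

# Example: configure an LM-judge verifier for a single training run.
config = VerifierConfig(max_judge_tokens=256).with_llm_judge(
    model="my-judge-model",  # hypothetical model name
    prompt_template="Question: {prompt}\nAnswer: {response}\nIs the answer correct?",
)
```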
> [!NOTE]
> Ensure `checkpoint_state_dir` is rewritten under `/filestore` when `gs_bucket_path` is set and the path isn’t already in filestore.
>
> - **Checkpointing**:
>   - When `gs_bucket_path` is provided...
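A minimal sketch of the rewrite rule being described (the constant and function name are assumptions about the intent, not the actual code):

```python
import os
from typing import Optional

FILESTORE_ROOT = "/filestore"  # assumed mount point for the shared filestore

def resolve_checkpoint_state_dir(checkpoint_state_dir: str, gs_bucket_path: Optional[str]) -> str:
    """When a GCS bucket path is configured, keep local checkpoint state under /filestore,
    leaving paths alone if they already live there."""
    if gs_bucket_path is None or checkpoint_state_dir.startswith(FILESTORE_ROOT):
        return checkpoint_state_dir
    # Re-root the directory under /filestore, preserving its relative layout.
    return os.path.join(FILESTORE_ROOT, checkpoint_state_dir.lstrip("/"))
```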
When the ref policy is corrupted or fails to materialize, resuming a job crashes. This fixes that.
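One way the guard could look (illustrative only; the actual fix may recover differently): wrap the ref-policy load in checks so a missing or corrupted checkpoint no longer aborts the resume.

```python
import logging
import os
import torch

logger = logging.getLogger(__name__)

def load_ref_policy_state(ref_policy: torch.nn.Module, ckpt_path: str) -> bool:
    """Try to load the reference-policy checkpoint; return False (instead of crashing)
    if it is missing or unreadable so the caller can re-create it."""
    if not os.path.exists(ckpt_path):
        logger.warning("Ref policy checkpoint %s not found; it will need to be re-created.", ckpt_path)
        return False
    try:
        state_dict = torch.load(ckpt_path, map_location="cpu")
        ref_policy.load_state_dict(state_dict)
        return True
    except (RuntimeError, EOFError, OSError, KeyError) as exc:
        logger.warning("Ref policy checkpoint %s could not be loaded (%s); it will need to be re-created.", ckpt_path, exc)
        return False
```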
If you don't have this, you'll get an index error during this line (2006 of grpo_fast.py): `collated_vllm_logprobs.append( collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0) )`. This can only occur when you...