[Question] Are there some train replication results?
Hi,
Thanks for the library! I'm new to the JAX+LLM ecosystem and trying to understand which library I should be using.
I see a lot of (very impressive) computational efficiency benchmarks for MaxText, but I can't find any benchmarks of model quality. Do you have perplexity/eval numbers for a model trained with MaxText in standard settings? E.g. nanoGPT on WikiText (evaluated with perplexity or MMLU), or Llama finetuning on Vicuna or Alpaca data? I think it would be very useful for deciding which JAX library to use for training LLMs!
Thank you for your help!
@YannDubs Sorry for the late response here; I was traveling.
MaxText has historically been focused on the largest customers, who were training custom models of their own design, so we've prioritized correctness, performance, and scalability. We've assumed pretraining customers would have their own secret sauce regarding convergence, etc. To demonstrate correctness, we verify that we can directly reproduce Chinchilla results: https://github.com/google/maxtext/blob/main/end_to_end/test_convergence_1b_params.sh
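For intuition, here is a rough sketch of what that verification amounts to. The token budget follows the ~20-tokens-per-parameter rule of thumb from the Chinchilla paper (Hoffmann et al., 2022); the tolerance and the convergence check below are placeholders for illustration, not the values or logic the script actually uses.

```python
# Illustrative only: the real check lives in end_to_end/test_convergence_1b_params.sh.
# This just shows what "reproducing Chinchilla" means operationally for a 1B model.

PARAMS = 1_000_000_000    # the 1B-parameter test model
TOKENS_PER_PARAM = 20     # ~20 tokens per parameter (Hoffmann et al., 2022)

optimal_tokens = PARAMS * TOKENS_PER_PARAM
print(f"Chinchilla-optimal budget: {optimal_tokens:,} tokens")  # 20,000,000,000

def loss_matches_reference(final_loss: float, reference_loss: float,
                           tolerance: float = 0.05) -> bool:
    """A run 'converges' if its final loss lands within a tolerance of the
    reference curve's final value. The tolerance here is a placeholder."""
    return final_loss <= reference_loss + tolerance
```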
We've gotten a lot of interest in off-the-shelf models that appeal to different folks, so we've been adding support for more models (currently Gemma, Llama, and Mistral).
We also have high-performance inference coming soon.
But I think you're asking for something more. Happy to talk live as well (rwitten at google.com).
Thanks @rwitten, something like the Chinchilla results is what I was asking about, but I was hoping to see the actual training curves and final evaluation results so I could (1) compare against the original results, and (2) have a reference curve to compare with when modifying the configs/model.
Thanks!
@gobbleturk can you provide?
I've uploaded loss curve data from a run of test_convergence_1b_params.sh here
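In case it helps anyone replicating: here is a minimal sketch for overlaying your own run against the reference curve. The file names and the headerless (step, loss) CSV layout are assumptions about the upload format, so adapt as needed.

```python
import csv
import math

import matplotlib.pyplot as plt

def load_curve(path):
    """Read (step, loss) pairs from a headerless two-column CSV (assumed format)."""
    steps, losses = [], []
    with open(path) as f:
        for row in csv.reader(f):
            steps.append(int(row[0]))
            losses.append(float(row[1]))
    return steps, losses

ref_steps, ref_losses = load_curve("reference_loss.csv")  # hypothetical filename
my_steps, my_losses = load_curve("my_run_loss.csv")       # hypothetical filename

plt.plot(ref_steps, ref_losses, label="MaxText 1B reference")
plt.plot(my_steps, my_losses, label="my run")
plt.xlabel("step")
plt.ylabel("train loss")
plt.legend()
plt.savefig("loss_comparison.png")

# Perplexity is exp(mean cross-entropy in nats), so the same curves can be
# compared as perplexities too:
print(f"reference final perplexity: {math.exp(ref_losses[-1]):.2f}")
print(f"my final perplexity:        {math.exp(my_losses[-1]):.2f}")
```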
Here is a screenshot of some learning metrics from that run that we display via TensorBoard:
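If you want the raw numbers behind the screenshot rather than the image, TensorBoard event files can be read programmatically. A minimal sketch; the log directory and the scalar tag name below are assumptions about this particular run, so list the available tags first:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("path/to/run/tensorboard")  # hypothetical logdir
acc.Reload()
print(acc.Tags()["scalars"])           # discover which scalar tags were logged

events = acc.Scalars("learning/loss")  # tag name is an assumption
steps = [e.step for e in events]
losses = [e.value for e in events]
print(f"{len(steps)} points, final loss {losses[-1]:.4f}")
```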
Great, thanks!