axlearn

Use Accuracy from cross_entropy in causal_lm.py::CrossEntropyLossMetrics

Open apivovarov opened this issue 10 months ago • 2 comments

PR Description:

This PR updates the accuracy calculation in causal_lm.py::CrossEntropyLossMetrics to use the value that loss.py::cross_entropy() already returns in loss_dict["accuracy"], instead of recomputing it. The key changes are:

  • Removing redundant accuracy computation in causal_lm.py.
  • Ensuring loss_dict["accuracy"] is used directly in CrossEntropyLossMetrics.
  • Updating the documentation in loss.py to clearly describe the accuracy computation method.

Together, these changes improve code clarity and keep the accuracy computation consistent across the loss computation functions.
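For reference, a masked token accuracy of the kind cross_entropy() reports can be written as below. This is a sketch under assumptions, not the text of the loss.py docstring: z_{i,c} denotes the logits, y_i the target labels, and w_i ∈ {0, 1} a live-target mask that excludes padding. The max(1, ·) guard against a batch with no live targets is also an assumption.

```latex
\mathrm{accuracy} = \frac{\sum_{i} w_i \, \mathbb{1}\left[ \arg\max_{c} z_{i,c} = y_i \right]}{\max\left(1, \sum_{i} w_i\right)}
```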

The following models already use loss_dict["accuracy"] from the loss.py::cross_entropy() function:

  • axlearn/audio/decoder_asr.py::LASDecoderModel::forward
  • axlearn/common/encoder_decoder.py::EncoderDecoderModel::forward (via _metrics)
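These models follow the pattern that CrossEntropyLossMetrics adopts with this PR: call cross_entropy() once and read the accuracy from the returned dictionary. The following is a minimal NumPy sketch of that pattern, not axlearn code (axlearn itself uses JAX). `cross_entropy` and `cross_entropy_loss_metrics` are hypothetical stand-ins whose signatures are assumptions, and padding targets are assumed to carry out-of-range labels such as -1.

```python
import numpy as np


def cross_entropy(logits, target_labels, live_targets=None):
    """Hypothetical stand-in for loss.py::cross_entropy (not the axlearn signature).

    Returns (loss, loss_dict); loss_dict["accuracy"] is the fraction of live
    targets whose argmax prediction equals the target label.
    """
    num_classes = logits.shape[-1]
    if live_targets is None:
        # Assumption: labels outside [0, num_classes) mark padding positions.
        live_targets = (target_labels >= 0) & (target_labels < num_classes)
    live = live_targets.astype(np.float32)
    num_targets = max(live.sum(), 1.0)  # Avoid dividing by zero on empty batches.

    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Replace padding labels with a valid index so the gather is in range.
    safe_labels = np.where(live_targets, target_labels, 0)
    per_target_loss = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    loss = (per_target_loss * live).sum() / num_targets

    # Accuracy is computed once, here, with the same mask as the loss.
    correct = (np.argmax(logits, axis=-1) == target_labels).astype(np.float32)
    accuracy = (correct * live).sum() / num_targets
    return loss, {"cross_entropy_loss": loss, "accuracy": accuracy}


def cross_entropy_loss_metrics(logits, target_labels, live_targets=None):
    """Hypothetical metrics layer: reuse loss_dict["accuracy"] instead of recomputing it."""
    loss, loss_dict = cross_entropy(logits, target_labels, live_targets)
    return {"loss": loss, "accuracy": loss_dict["accuracy"]}


if __name__ == "__main__":
    # One sequence of three tokens over three classes; the last target is padding.
    logits = np.array([[[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.3, 0.2, 0.9]]])
    labels = np.array([[0, 2, -1]])
    print(float(cross_entropy_loss_metrics(logits, labels)["accuracy"]))  # 0.5
```

Keeping the accuracy computation in a single function means the metrics layer cannot drift from the loss on details such as masking, which is the consistency argument the PR description makes.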

Testing

The test command uses the fuji-test-v3 config with the GPU JAX backend (--jax_backend=gpu):

mkdir -p /tmp/gpt_c4_test
python3 -m axlearn.common.launch_trainer_main \
  --module=text.gpt.c4_trainer \
  --config=fuji-test-v3 \
  --trainer_dir=/tmp/gpt_c4_test \
  --data_dir=gs://axlearn-public/tensorflow_datasets \
  --jax_backend=gpu

tensorboard --logdir=train_train

  • Accuracy before: https://ibb.co/kVj6ss9b
  • Accuracy after: https://ibb.co/GQn2wfH4

pre-commit

$ pre-commit run --files $(git diff --name-only HEAD~1)
Check Yaml...........................................(no files to check)Skipped
Fix End of Files.........................................................Passed
Trim Trailing Whitespace.................................................Passed
black....................................................................Passed
isort....................................................................Passed
pylint...................................................................Passed

pytype

$ pytype -j auto $(git diff --name-only HEAD~1)
Success: no errors found

apivovarov • Feb 12 '25

Hey Mark, thank you for reviewing and approving this PR! Just wanted to check if there's anything else needed to merge it. @markblee

apivovarov • Feb 14 '25

I'll need to run some internal tests before I can merge it -- hopefully soon!

markblee • Feb 14 '25

Should we merge this PR? Thanks.

ruomingp • May 21 '25

@markblee @ruomingp Thank you for your help with this PR!

apivovarov • May 22 '25

@apoorvtintin FYI

apivovarov • May 22 '25