algorithmic-efficiency

Speech workloads speed regression on JAX

Open priyakasimbeg opened this issue 7 months ago • 0 comments

Speech workloads appear to be ~5x slower in the update step than before. This happens with both pmap and jit.
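To quantify a per-step slowdown like this, one can time the jitted update in isolation, compiling once up front and blocking on the result so dispatch is not measured asynchronously. This is a minimal sketch with a stand-in update function, not the workload's actual training loop:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def update(params, x):
    # Stand-in for the real update step: one SGD step on a toy loss.
    grads = jax.grad(lambda p: jnp.sum((p * x) ** 2))(params)
    return params - 0.01 * grads


params = jnp.ones((1024,))
x = jnp.ones((1024,))

# Warm up: the first call triggers XLA compilation.
params = update(params, x)
params.block_until_ready()

start = time.perf_counter()
for _ in range(10):
    params = update(params, x)
params.block_until_ready()  # wait for async dispatch to finish
step_time = (time.perf_counter() - start) / 10
print(f"mean step time: {step_time:.6f}s")
```

Comparing this number across package versions (old vs. new JAX/jaxlib) on the same hardware isolates whether the regression is in compiled step time rather than input pipeline or compilation overhead.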

Steps to Reproduce

In the container, run:

 python submission_runner.py --framework=jax --workload=librispeech_deepspeech --submission_path=reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py --data_dir=/data/librispeech --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=tests/regression_tests/adamw --overwrite=True --save_checkpoints=False --max_global_steps=10 --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab --tuning_ruleset=external --tuning_search_space=reference_algorithms/qualification_baselines/external_tuning/tuning_search_space.json

Source or Possible Fix

Maybe a package update is resulting in a compilation difference?

Suspicious message in the logs:

2025-07-15 03:06:29.349878: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %reduce-window.10 = f32[256,500]{1,0} reduce-window(%broadcast.248, %constant.119), window={size=1x500 pad=0_0x499_0}, to_apply=%region_35.2517.clone, metadata={op_name="jit(_eval_step)/jit(cumsum)/LibriSpeechConformerWorkload.sequence_mask/reduce_window_sum" source_file="/algorithmic-efficiency/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py" source_line=244}
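The metadata in the alarm points at `jit(cumsum)` inside `sequence_mask`: `jnp.cumsum` lowers to an XLA reduce-window (hence the `reduce_window_sum` with a size-500 window), and XLA is spending more than a second constant-folding it. A minimal sketch of a cumsum-based sequence mask that lowers this way (a hypothetical illustration, not the actual code at `workload.py:244`):

```python
import jax
import jax.numpy as jnp


def sequence_mask(lengths, max_len):
    """Build a [batch, max_len] mask with 1.0 at valid positions.

    Hypothetical sketch: a one-hot of each sequence length followed by a
    cumulative sum. jnp.cumsum lowers to an XLA reduce-window over the
    time axis, matching the reduce_window_sum in the alarm above.
    """
    one_hot = jax.nn.one_hot(lengths, max_len)   # [batch, max_len]
    # cumsum marks every position at or past the length with 1.
    past_end = jnp.cumsum(one_hot, axis=-1)
    return 1.0 - past_end                        # 1.0 for valid positions


mask = sequence_mask(jnp.array([3]), max_len=5)
print(mask)  # [[1. 1. 1. 0. 0.]]
```

If a newer XLA started constant-folding this reduce-window per compilation (or folding it more slowly), that would fit both the alarm and a regression that appears only after a package update.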

priyakasimbeg · Jul 15 '25 03:07