Running the Flax run_qa.py example fails after upgrading dependencies

Open jameszhouyi opened this issue 10 months ago • 4 comments

System Info

Ubuntu 20.04.1 LTS on an NVIDIA GPU

Who can help?

Hi @sanchit-gandhi, I recently ran the Flax question-answering example (https://github.com/huggingface/transformers/tree/main/examples/flax/question-answering) with jax 0.4.24, flax 0.8.1, transformers 4.38.0, etc., but it always fails with the following error:

    batch = {k: np.array(v) for k, v in batch.items()}
    ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (64, 384) + inhomogeneous part.

Could you please take a look at this issue? Thanks in advance.

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

  1. conda create -n test python=3.10
  2. pip install jax[cuda12_pip]==0.4.24 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  3. pip install datasets==2.14.6 flax==0.8.1 transformers==4.38.0 evaluate==0.4.1 optax==0.2.2 chex==0.1.86 numpy==1.26.4
  4. git clone https://github.com/huggingface/transformers.git (commit 9b0a8ea7d1d6226b76cfdc645ce65e21157e2b50)
  5. cd transformers/examples/flax/question-answering
  6. python run_qa.py \
    --model_name_or_path google-bert/bert-base-uncased \
    --dataset_name squad \
    --do_train \
    --do_eval \
    --max_seq_length 384 \
    --doc_stride 128 \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 12 \
    --output_dir ./bert-qa-squad \
    --eval_steps 1000

Expected behavior

Run successfully

jameszhouyi avatar Apr 02 '24 15:04 jameszhouyi

Hey! We don't maintain the examples for specific versions, so it's quite possible that the most recent jax lib or similar is not supported. Feel free to update it!

ArthurZucker avatar Apr 05 '24 06:04 ArthurZucker

Hi @ArthurZucker, I looked into the error and found that it is caused by a deprecated NumPy feature. The script used to run with older NumPy versions, but it already emitted this warning:

      VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  batch = {k: np.array(v) for k, v in batch.items()}

The script creates a numpy.array from ragged nested sequences. This is not a standard NumPy operation: it has been deprecated since NumPy v1.19 and raises an error as of NumPy v1.24.0.
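
To illustrate the behaviour described here, a minimal sketch with made-up data follows. The `rows` list is a hypothetical stand-in for one batch of offset mappings, not data taken from the script, and `try_plain_conversion` is an illustrative helper. On NumPy 1.24.0 or later the plain `np.array` call is expected to raise the `ValueError` from the report; on older releases it only warns.

```python
import warnings

import numpy as np

# Hypothetical stand-in for a batch of offset mappings: each position is
# either None or a [start, end] pair, so the nesting depth is not uniform.
rows = [
    [None, [0, 5], [6, 10]],
    [None, None, [0, 3]],
]


def try_plain_conversion(data):
    """Return the name of the exception raised by np.array(data), or None."""
    with warnings.catch_warnings():
        # Silence the deprecation warning that older NumPy releases emit.
        warnings.simplefilter("ignore")
        try:
            np.array(data)
        except ValueError as err:
            return type(err).__name__
    return None


# Expected to be "ValueError" on NumPy >= 1.24.0 and None on older releases.
outcome = try_plain_conversion(rows)
print("plain np.array:", outcome)

# Passing dtype=object, as the warning text suggests, is accepted on both.
as_objects = np.array(rows, dtype=object)
print("dtype=object:", as_objects.dtype)
```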

The issue happens in this way:

  1. The ragged nested sequences are generated here: https://github.com/huggingface/transformers/blob/v4.39.0/examples/flax/question-answering/run_qa.py#L733-L736. This code creates sequences that mix None with ordinary two-element lists, as shown below:
[None, None, None, None, None, None, None, None, None, None, None, None, None, [0, 5], [6, 10], [11, 13], ... 
  2. These sequences are then used to create a numpy.array here: https://github.com/huggingface/transformers/blob/v4.39.0/examples/flax/question-answering/run_qa.py#L436. For this operation, NumPy reports a warning before v1.24.0 and an error from v1.24.0 onward.
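
One possible way to keep the conversion step from failing is sketched below. It is an illustration only and not necessarily what the fix PR does: `to_numpy_batch` is a hypothetical helper, the sample `batch` (including the token ids) is made up, and falling back to `dtype=object` for ragged columns is an assumption.

```python
import warnings

import numpy as np


def to_numpy_batch(batch):
    """Convert each column to an ndarray, using dtype=object for ragged ones."""
    converted = {}
    for key, values in batch.items():
        try:
            with warnings.catch_warnings():
                # Older NumPy only warns on ragged input; turn that warning
                # into an error so old and new releases take the same path.
                warnings.simplefilter("error")
                converted[key] = np.array(values)
        except (ValueError, Warning):
            # Ragged column, e.g. None mixed with [start, end] pairs.
            converted[key] = np.array(values, dtype=object)
    return converted


# Made-up batch: one regular column and one ragged column.
batch = {
    "input_ids": [[101, 2054, 102], [101, 2129, 102]],
    "offset_mapping": [[None, [0, 4], None], [None, [0, 3], None]],
}

arrays = to_numpy_batch(batch)
print(arrays["input_ids"].shape)       # (2, 3)
print(arrays["offset_mapping"].dtype)  # object
```

Object arrays cannot be passed to JAX, so a column converted this way would only be usable for host-side processing.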

Zantares avatar Apr 18 '24 06:04 Zantares

A fix PR has been submitted: https://github.com/huggingface/transformers/pull/30434

Zantares avatar Apr 23 '24 16:04 Zantares

Hi @ArthurZucker, could you help review the fix PR?

Zantares avatar May 07 '24 02:05 Zantares

Sorry for the delay @Zantares, and thanks for your contribution!

ArthurZucker avatar May 23 '24 14:05 ArthurZucker