
Notes on FP16 inference

152334H opened this issue 2 years ago

For no reason at all, I decided to try running the model on my RTX 3090. This turned out to be surprisingly difficult, so I am documenting my process here so other people can find it.

The README provides some instructions on how to load the model, but:

  • it loads the model onto CPU by default,
  • it loads the model as fp32 by default,
  • requires_grad is enabled for the new layers by default,
  • the precision=... parameter from open_clip is partially broken due to a forced fp32 conversion in LayerNormFp32, and
  • attempting to load LLaMA in 8bit breaks the model, because accelerate discards the injected FlamingoLMMixin when device_map is used.

I have created a simple fork of this repo here that is slightly easier to use for people with consumer GPUs. Unfortunately, it still consumes a large amount of VRAM (because of the lack of 8bit), and it takes a long time to load the LLaMA weights (because accelerate is broken).
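
For reference, a rough loading sketch (not the fork's exact code), assuming the create_model_and_transforms call and the strict=False checkpoint load from the README; the LLaMA and checkpoint paths are placeholders. Even with this, .half() alone is not enough upstream because of the forced fp32 cast in LayerNormFp32 noted above, which is what the fork patches:

    import torch
    from open_flamingo import create_model_and_transforms

    # Build the model as in the README; the paths below are placeholders.
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="<path-to-llama-7b-hf>",
        tokenizer_path="<path-to-llama-7b-hf>",
        cross_attn_every_n_layers=4,
    )

    # Load the OpenFlamingo checkpoint (strict=False: it only contains the new layers).
    model.load_state_dict(torch.load("<path-to-checkpoint.pt>", map_location="cpu"), strict=False)

    # Address the defaults listed above: no gradients, fp16, GPU, eval mode.
    model.requires_grad_(False)
    model = model.half().cuda().eval()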

152334H avatar Mar 29 '23 11:03 152334H

I followed the fork https://github.com/mlfoundations/open_flamingo/commit/25b17319723b41f900cd52a389466b97c053695d but am unable to load the LLaMA weights on the 3090; it crashes when I try to load them in fp16.

NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7 

ericjang avatar Mar 29 '23 21:03 ericjang

Never mind, I was able to get things working. Had to decrease batch size to 2 to have sufficient memory.

Note also that one has to cast the inputs to half-precision, e.g. batch_images = batch_images.half().
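
Concretely, something like the following (a sketch following the README's generate() example, not the eval script's actual code; lang_x is assumed to be the tokenizer output and max_new_tokens is illustrative):

    # With an fp16 model, the image tensor must also be fp16 and on the GPU.
    batch_images = batch_images.half().cuda()

    generated = model.generate(
        vision_x=batch_images,
        lang_x=lang_x["input_ids"].cuda(),
        attention_mask=lang_x["attention_mask"].cuda(),
        max_new_tokens=20,
    )
    print(tokenizer.decode(generated[0]))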

Performance on the OK-VQA benchmark with the following settings:

python open_flamingo/eval/evaluate.py \
    --lm_path $LM_PATH \
    --lm_tokenizer_path $LM_TOKENIZER_PATH \
    --checkpoint_path $CKPT_PATH \
    --device $DEVICE \
    --cross_attn_every_n_layers 4 \
    --eval_ok_vqa \
    --ok_vqa_image_dir_path $VQAV2_IMG_PATH \
    --ok_vqa_annotations_json_path $VQAV2_ANNO_PATH \
    --ok_vqa_questions_json_path $VQAV2_QUESTION_PATH \
    --results_file $RESULTS_FILE \
    --num_samples 5000 --shots 2 --num_trials 1 \
    --batch_size 2
Shots 2 Trial 0 OK-VQA score: 35.15
Shots 2 Mean OK-VQA score: 35.15

ericjang avatar Mar 29 '23 22:03 ericjang