
How to quantize open_flamingo?

Open YerongLi opened this issue 1 year ago • 0 comments

How to quantize open-Flamingo?

Expected Behavior

The open_flamingo v1 builder in https://github.com/open-mmlab/Multimodal-GPT (https://github.com/open-mmlab/Multimodal-GPT/blob/main/mmgpt/models/open_flamingo/builder.py) fine-tunes successfully with LoRA.

I have a similar builder.py in which I am trying to do QLoRA:

"""Modified from https://github.com/mlfoundations/open_flamingo"""
import open_clip
import torch
import torch.nn as nn
from bigmodelvis import Visualization
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
# from transformers import AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance

DEBUG = 1
def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    decoder_layers_attr_name: str = None,
    pretrained_model_path: str = None,
    tuning_config=None,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
    """
    print("init clip vision encoder")
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True
    print("init tokenizer")
    text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    text_tokenizer.bos_token_id = 1
    text_tokenizer.eos_token_id = 2

    print("init llama")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        load_in_8bit=False,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    # lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)
    if DEBUG == 1:
        lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path, quantization_config=quantization_config)
        # lang_encoder = prepare_model_for_kbit_training(lang_encoder)
    else: 
        lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)

    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=4,
        **flamingo_kwargs,
    )

    if pretrained_model_path is not None:
        print(f"loading pretrained model from {pretrained_model_path}")
        model.load_state_dict(torch.load(pretrained_model_path), strict=False)

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    if tuning_config is not None:
        if DEBUG == 1: 
            model = prepare_model_for_kbit_training(model)
            model.restart(
                vision_encoder,
                lang_encoder,
                text_tokenizer.encode("<|endofchunk|>")[-1],
                text_tokenizer.encode("<image>")[-1],
                vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
                cross_attn_every_n_layers=4,
                **flamingo_kwargs,
            )

        model = prepare_model_for_tuning(model, tuning_config)
    else:
        raise ValueError("tuning_config must be provided")

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )

    # if DEBUG == 1: model = prepare_model_for_kbit_training(model)

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Example usage
    num_params = count_parameters(model)
    print(f" ======= Number of trainable parameters: {num_params}")


    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
}


def prepare_model_for_tuning(model: nn.Module, config):
    if config.lora:
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",  # won't use bias currently
            modules_to_save=[],  # TODO: might be helpful if save partial model
            task_type="VL",
        )
        model.lang_encoder = get_peft_model(model.lang_encoder, peft_config=lora_config)

    # manually unfreeze modules; we use substring matching on parameter names
    for name, param in model.named_parameters():
        if any(substr in name for substr in config.unfrozen):
            param.requires_grad = True

    if config.vis and is_rank0():
        Visualization(model).structure_graph()
    return model


# temporary workaround, should use a common utils in the future
def is_rank0():
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0

Current Behavior

After prepare_model_for_kbit_training, the model no longer has self.media_locations.
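
A quick way to make this concrete (a sketch only; it assumes the _get_decoder_layers accessor that FlamingoLMMixin provides, and model is the Flamingo instance built in builder.py above):

# Sketch: check whether the Flamingo-specific layer state survives the k-bit preparation.
layer = model.lang_encoder._get_decoder_layers()[0]
print(hasattr(layer, "media_locations"))  # set up by the Flamingo wrapper

model = prepare_model_for_kbit_training(model)

layer = model.lang_encoder._get_decoder_layers()[0]
print(hasattr(layer, "media_locations"))  # False, as described above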

Steps to Reproduce

  • Clone https://github.com/open-mmlab/Multimodal-GPT.git
  • Simplify the data configuration (keep only the LLaVA visual dataset): replace configs/dataset_config.py with the following:
visual_datasets = [
    dict(
        type="llava",
        vis_root="data/coco/train2017",
        ann_paths=[
            "data/llava/detail_23k.json",
            "data/llava/complex_reasoning_77k.json",
        ],
    ),
]

language_datasets = [
    dict(
        type="alpaca_gpt4",
        ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
    ),
]

  • Download the LLaVA and Alpaca GPT4 data according to the official readme:
  1. LLaVA

    Download from liuhaotian/LLaVA-Instruct-150K and place in data/llava/.

  2. Alpaca GPT4

    Download it from this link and place it in data/alpaca_gpt4/alpaca_gpt4_data.json.

  • Replace the following file with the builder.py code provided above: https://github.com/open-mmlab/Multimodal-GPT/blob/main/mmgpt/models/open_flamingo/builder.py
  • Run
torchrun --nproc_per_node=1 mmgpt/train/instruction_finetune.py \
  --lm_path /scratch/yerong/.cache/pyllama/hf/7B \
  --tokenizer_path /scratch/yerong/.cache/pyllama/hf/7B \
  --pretrained_path OpenFlamingo-9B/checkpoint.pt \
  --run_name train-my-gpt4 \
  --learning_rate 1e-5 \
  --lr_scheduler cosine \
  --batch_size 1 \
  --tuning_config configs/lora_config.py \
  --dataset_config configs/dataset_config.py


Actual output

Total training steps: 99883
Found no checkpoints for run train-my-gpt4.
  0%|          | 0/99883 [00:00<?, ?it/s]
/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
  0%|          | 0/99883 [00:01<?, ?it/s]
Traceback (most recent call last):                                                                                                                            
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 470, in <module>                                            
    main()                                                                                                                                    
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 302, in main                                                
    train_one_epoch(                                                                                                                          
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 390, in train_one_epoch                                     
    loss_batch = model(                                                                                                                       
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                   
    return forward_call(*args, **kwargs)                                                                                                      
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward                
    output = self._run_ddp_forward(*inputs, **kwargs)                                                                                         
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward                                                                                                                                                      
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]                                                                      
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                   
    return forward_call(*args, **kwargs)                                                                                                                      
  File "/scratch/yerong/Multimodal-GPT/mmgpt/models/open_flamingo/flamingo.py", line 104, in forward                                          
    output = self.lang_encoder(                                                                                                               
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                   
    return forward_call(*args, **kwargs)                                                                                         
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/peft/peft_model.py", line 416, in forward                               
    return self.get_base_model()(*args, **kwargs)                                                                                                             
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl      
    return forward_call(*args, **kwargs)                                                                                         
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward  
    output = old_forward(*args, **kwargs)                                                                                        
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward                                                                                                                                                     
    outputs = self.model(                                                                                                        
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl      
    return forward_call(*args, **kwargs)                                                                                                                      
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward                          
    output = old_forward(*args, **kwargs)                                                                                                                     
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 570, in forward                                                                                                                                                     
    layer_outputs = torch.utils.checkpoint.checkpoint(                                                                                                        
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint                                     
    return CheckpointFunction.apply(function, preserve, *args)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply                                                             
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 107, in forward                                                            
    outputs = run_function(*args)           
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 566, in custom_forward                                                            
    return module(*inputs, output_attentions, None)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                                                                                
    return forward_call(*args, **kwargs)           
TypeError: forward() takes from 2 to 3 positional arguments but 7 were given                            
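
The failure may be coming from gradient checkpointing rather than from quantization itself: prepare_model_for_kbit_training enables gradient checkpointing by default, and the checkpointed custom_forward in modeling_llama.py then calls each decoder layer with several positional arguments, while the FlamingoLayer wrapper in open_flamingo only takes (lang_x, attention_mask, **decoder_layer_kwargs), which would explain "forward() takes from 2 to 3 positional arguments but 7 were given". A minimal thing to try (a sketch, assuming the use_gradient_checkpointing flag of peft's prepare_model_for_kbit_training):

# Sketch: keep the k-bit preparation but leave HF gradient checkpointing off, so the
# Flamingo-wrapped decoder layers are called with the signature they expect.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)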

Environment

python 3.10

Detailed Description

(OPTIONAL) Possible Implementation
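
One restructuring I would like feedback on (a sketch only, not verified): quantize and k-bit-prepare the bare language model before it is wrapped into Flamingo, so the FlamingoLMMixin / FlamingoLayer state (e.g. media_locations) is attached afterwards and is never touched by the PEFT preparation step. All names below come from the builder.py above.

# Sketch (unverified): prepare only the language backbone for k-bit training,
# then build the Flamingo wrapper on top of it.
lang_encoder = LlamaForCausalLM.from_pretrained(
    lang_encoder_path, quantization_config=quantization_config
)
lang_encoder = prepare_model_for_kbit_training(
    lang_encoder, use_gradient_checkpointing=False
)

extend_instance(lang_encoder, FlamingoLMMixin)
# decoder_layers_attr_name is inferred exactly as in create_model_and_transforms above
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(len(text_tokenizer))

model = Flamingo(
    vision_encoder,
    lang_encoder,
    text_tokenizer.encode("<|endofchunk|>")[-1],
    text_tokenizer.encode("<image>")[-1],
    vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
    cross_attn_every_n_layers=4,
    **flamingo_kwargs,
)
# LoRA is then applied to model.lang_encoder via prepare_model_for_tuning as before.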

YerongLi · Jul 18 '23 17:07