
load_from_name: add flash-attn support

ZechengLi19 opened this issue on May 11, 2024 · 4 comments

Thank you for this excellent codebase; it has helped me a lot. However, I found that the load_from_name function does not support flash-attn, so I implemented this part myself. I am not sure whether the implementation is correct, although it runs without errors.

Here is the code snippet. It patches load_from_name and create_model, and relies on helpers that already exist in the repository (_MODELS, _download, _MODEL_INFO, image_transform, resize_pos_embed, convert_state_dict):

###### ------- ps: add use_flash_attention keyword ------- ######
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None, 
                   input_resolution: int = None, use_flash_attention: bool = False):
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # loading saved checkpoint
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)

###### ------- ps: convert weights for flash attention ------- ######
def create_model(model_name, checkpoint=None, use_flash_attention=False):
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
    if isinstance(model_info['vision_layers'], str):
        model_info['vision_layers'] = eval(model_info['vision_layers'])
    print('Model info', model_info)
    if use_flash_attention:
        model_info['use_flash_attention'] = use_flash_attention
    model = CLIP(**model_info)
    convert_weights(model)
            
    if checkpoint:
        sd = checkpoint["state_dict"]
        # Drop the unused bert.pooler weights so they never reach load_state_dict
        sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
        # Strip the 'module.' prefix left over from DistributedDataParallel checkpoints
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items()}
        if use_flash_attention:
            # Resize the positional embedding by interpolation, if needed;
            # the 'module.' prefix was already stripped above, so pass an empty prefix
            resize_pos_embed(sd, model, prefix="")
            # Remap attention weights to the flash-attention parameter layout
            sd = convert_state_dict(sd)
        # Load the (possibly converted) state dict
        model.load_state_dict(sd)
    return model
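
For reference, this is how I call the patched loader; a minimal sketch following the repository's quick-start example, with "example.jpg" as a placeholder image path (flash-attn only runs on CUDA GPUs in half precision):

import torch
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name

device = "cuda"  # flash-attn kernels require a CUDA device and fp16/bf16 tensors
# use_flash_attention is the keyword added by the patch above
model, preprocess = load_from_name("ViT-B-16", device=device, download_root="./",
                                   use_flash_attention=True)
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
text = clip.tokenize(["一只猫", "一只狗"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)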

It would be great if you could spare some time to check whether this implementation is correct~
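
In case it helps, here is the sanity check I would run; a sketch assuming the patched loader above, a CUDA device, and flash-attn installed. If the weight conversion is correct, the two code paths should produce matching features up to fp16 numerical noise:

import torch
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name

device = "cuda"  # flash-attn kernels require a CUDA device
model_fa, preprocess = load_from_name("ViT-B-16", device=device, use_flash_attention=True)
model_ref, _ = load_from_name("ViT-B-16", device=device)
model_fa.eval()
model_ref.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
text = clip.tokenize(["一只猫"]).to(device)

with torch.no_grad():
    # Both prints should be True if the conversion preserves the computation
    print(torch.allclose(model_fa.encode_image(image), model_ref.encode_image(image), atol=1e-2))
    print(torch.allclose(model_fa.encode_text(text), model_ref.encode_text(text), atol=1e-2))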

If it is, feel free to add my implementation to the repository~

Many thanks!

ZechengLi19 · May 11 '24 08:05