Chinese-CLIP
Chinese-CLIP copied to clipboard
load_from_name 加入 flash-attn 支持
感谢你如此好的代码实现，它对我的帮助很大。但是我在使用 load_from_name 函数时发现它并不支持 flash-attn，因此我自己实现了这一块的代码。尽管它可以正常运行，但我不确定实现是否正确。
以下是代码片段
###### ------- ps: add use_flash_attention keyword ------- ######
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
                   input_resolution: int = None, use_flash_attention: bool = False):
    """Load a Chinese-CLIP model by registered name or from a local checkpoint file.

    Args:
        name: Either a key of ``_MODELS`` (weights are downloaded) or a path to an
            existing checkpoint file.
        device: Target device. The exact string ``"cpu"`` casts the model with
            ``model.float()``; any other value moves it with ``model.to(device)``.
        download_root: Download directory for registered models; defaults to
            ``~/.cache/clip``.
        vision_model_name, text_model_name, input_resolution: Required when
            ``name`` is a local file, since the architecture cannot be looked up.
        use_flash_attention: Forwarded unchanged to ``create_model``.
            NOTE(review): flash-attention kernels are usually CUDA-only; using this
            together with ``device="cpu"`` is untested here — confirm.

    Returns:
        ``(model, transform)`` where ``transform`` is
        ``image_transform(<input resolution>)``.

    Raises:
        AssertionError: If ``name`` is a file but any of the three architecture
            arguments is missing.
        RuntimeError: If ``name`` is neither a registered model nor a file.
    """
    if name in _MODELS:
        cache_dir = download_root or os.path.expanduser("~/.cache/clip")
        model_path = _download(_MODELS[name], cache_dir)
        registered = _MODEL_INFO[name]
        model_name = registered['struct']
        model_input_resolution = registered['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name = f'{vision_model_name}@{text_model_name}'
        model_input_resolution = input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as ckpt_file:
        # Deserialize on CPU first; device placement happens after construction.
        checkpoint = torch.load(ckpt_file, map_location="cpu")

    model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
    running_on_cpu = str(device) == "cpu"
    if running_on_cpu:
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)
###### ------- ps: convert flash_attention weight ------- ######
def create_model(model_name, checkpoint=None, use_flash_attention=False):
    """Build a CLIP model from its JSON configs and optionally load checkpoint weights.

    Args:
        model_name: ``"<vision_model>@<text_model>"``. Each half, with '/' replaced
            by '-', names a JSON file in the ``model_configs`` directory next to
            this module.
        checkpoint: A dict holding a ``"state_dict"`` entry, or a falsy value
            (``None`` / empty) to skip weight loading.
        use_flash_attention: When true, ``use_flash_attention=True`` is passed to
            ``CLIP`` and the state dict goes through ``resize_pos_embed`` and
            ``convert_state_dict`` before loading.

    Returns:
        The constructed ``CLIP`` model, with checkpoint weights loaded if given.

    Raises:
        ValueError: If ``model_name`` does not split into exactly two parts on '@'.
        AssertionError: If either config file does not exist.
    """
    import ast

    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    config_dir = Path(__file__).parent / "model_configs"
    vision_model_config_file = config_dir / f"{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)
    text_model_config_file = config_dir / f"{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r', encoding='utf-8') as fv, \
            open(text_model_config_file, 'r', encoding='utf-8') as ft:
        model_info = json.load(fv)
        # Text-config entries override vision-config entries with the same key.
        model_info.update(json.load(ft))
    if isinstance(model_info['vision_layers'], str):
        # The layer spec may be stored as a string literal; parse it with
        # literal_eval instead of eval() so a config file cannot run code.
        model_info['vision_layers'] = ast.literal_eval(model_info['vision_layers'])
    print('Model info', model_info)
    if use_flash_attention:
        model_info['use_flash_attention'] = use_flash_attention
    model = CLIP(**model_info)
    convert_weights(model)

    if checkpoint:
        sd = checkpoint["state_dict"]
        # Checkpoints saved from a wrapped model prefix every key with 'module.';
        # strip it so the keys match the bare model. The '' default avoids
        # StopIteration on an empty state dict.
        if next(iter(sd), '').startswith('module.'):
            sd = {k[len('module.'):]: v for k, v in sd.items()}
        # BERT pooler weights were already dropped before the strict load on the
        # prefixed path; drop them on every path so loading is consistent.
        # Presumably CLIP's BERT has no pooler parameters — confirm.
        sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
        if use_flash_attention:
            # Keys no longer carry 'module.', so the lookup prefix must be empty.
            # Assuming resize_pos_embed looks up prefix + key, prefix="module."
            # could never match and resizing was silently skipped.
            resize_pos_embed(sd, model, prefix="")
            # Adapt the state dict to the flash-attention modules.
            sd = convert_state_dict(sd)
        # Load the state dict on both paths, so the flash-attention model
        # receives its weights too.
        model.load_state_dict(sd)
    return model
如果作者有空，能帮我检查一下这一实现是否正确就好了~
如果是正确的，作者可以将我的实现加入到仓库中~
不胜感谢