Multi-GPU inference with a custom model: one machine, multiple GPUs, one model per GPU, evaluating in parallel
I have a custom model and want to run an evaluation in parallel on a single machine with multiple GPUs, loading one copy of the model on each GPU, similar to this issue: https://github.com/open-compass/VLMEvalKit/issues/479. How should I modify the code to achieve this? Thanks.
Inference already runs in parallel, but every model copy gets loaded across all 8 GPUs. Here is my model code:
import torch
from PIL import Image
from ...smp import *
from ...dataset import DATASET_TYPE
from ..base import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from .LLaVAMoE import *
from abc import abstractproperty
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
    # NOTE: device_map="auto" lets HF accelerate shard the model across all visible GPUs;
    # an explicit single-device map (e.g. {"": "cuda:0"}) keeps it on one card.
    kwargs = {"device_map": device_map, **kwargs}
    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16
    # kwargs['attn_implementation'] = 'flash_attention_2'

    if 'llava' in model_name.lower():
        # Load LLaVA model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in model_name.lower() and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')

                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            if 'MOE' not in model_name:
                raise AssertionError('LLaVA model name should contain `MOE`.')
            if 'MOE' in model_name:
                from .CoIN.peft import PeftModel, TaskType, get_peft_model, CoINMOELoraConfig, WEIGHTS_NAME, set_peft_model_state_dict
                print('Loading MoE LoRA weights...')
            else:
                from peft import PeftModel
                print('Loading LoRA weights...')
            # model.load_adapter()
            # load_lora_model()
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'llava' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
class LLaVA_MoE(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path, **kwargs):
        assert model_path is not None, 'model_path is required'
        try:
            from llava.mm_utils import get_model_name_from_path
        except Exception as err:
            logging.critical(
                "Please install llava from https://github.com/haotian-liu/LLaVA"
            )
            raise err

        assert osp.exists(model_path) or splitlen(model_path) == 2
        self.system_prompt = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions. "
        )
        self.stop_str = "</s>"
        self.model_path = os.path.expanduser(model_path)
        model_name = get_model_name_from_path(model_path)
        # TODO: need to modify this
        model_base = '/data3/rundongwang/CoIN/checkpoints/LLaVA/Vicuna/llava-7b-v1.5'
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(
                model_path=model_path,
                model_base=model_base,
                model_name=model_name,
                device="cuda",
                device_map="auto",
            )
        )
        # self.model = self.model.cuda()
        self.conv_mode = "llava_v1"

        kwargs_default = dict(
            do_sample=False,
            temperature=0,
            max_new_tokens=512,
            top_p=None,
            num_beams=1,
            use_cache=True,
        )  # noqa E501
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(
            f"Following kwargs received: {self.kwargs}, will use as generation config. "
        )

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == "MCQ":
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line["question"]
        hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
        if hint is not None:
            question = hint + "\n" + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f"\n{key}. {item}"
        prompt = question

        if len(options):
            prompt += (
                "\n请直接回答选项字母。"
                if cn_string(prompt)
                else "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += (
                "\n请直接回答问题。"
                if cn_string(prompt)
                else "\nAnswer the question directly."
            )
        message = [dict(type="image", value=s) for s in tgt_path]
        message.append(dict(type="text", value=prompt))
        return message

    def concat_tilist(self, message):
        text, images = "", []
        for item in message:
            if item["type"] == "text":
                text += item["value"]
            elif item["type"] == "image":
                text += " <image> "
                images.append(item["value"])
        return text, images

    def chat_inner(self, message, dataset=None):
        from llava.mm_utils import (
            process_images,
            tokenizer_image_token,
            KeywordsStoppingCriteria,
        )
        from llava.constants import IMAGE_TOKEN_INDEX

        prompt = self.system_prompt
        images = []
        for utter in message:
            prompt += "USER: " if utter["role"] == "user" else "ASSISTANT: "
            content, images_sub = self.concat_tilist(utter["content"])
            prompt += content
            images.extend(images_sub)
            prompt += " " if utter["role"] == "user" else self.stop_str
        assert message[-1]["role"] == "user", message
        prompt += "ASSISTANT: "

        images = [Image.open(s).convert("RGB") for s in images]
        args = abstractproperty()
        args.image_aspect_ratio = "pad"
        image_tensor = process_images(images, self.image_processor, args).to(
            "cuda", dtype=torch.float16
        )
        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .cuda()
        )
        keywords = [self.stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, self.tokenizer, input_ids
        )
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                stopping_criteria=[stopping_criteria],
                **self.kwargs,
            )
        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
            0
        ].strip()
        return output

    def generate_inner(self, message, dataset=None):
        from llava.mm_utils import (
            process_images,
            tokenizer_image_token,
            KeywordsStoppingCriteria,
        )
        from llava.constants import IMAGE_TOKEN_INDEX

        # Support interleave text and image
        content, images = self.concat_tilist(message)

        images = [Image.open(s).convert("RGB") for s in images]
        args = abstractproperty()
        args.image_aspect_ratio = "pad"
        if images:
            image_tensor = process_images(images, self.image_processor, args).to(
                "cuda", dtype=torch.float16
            )
        else:
            image_tensor = None

        prompt = self.system_prompt + "USER: " + content + " ASSISTANT: "
        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .cuda()
        )
        keywords = [self.stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, self.tokenizer, input_ids
        )
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                stopping_criteria=[stopping_criteria],
                **self.kwargs,
            )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        output = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[
            0
        ].strip()
        if output.endswith(self.stop_str):
            output = output[:-len(self.stop_str)]
        output = output.strip()
        return output
I'm using 48 GB A6000 GPUs.
Hi, you can restrict the visible CUDA devices before running the torchrun command, for example:
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 run.py --model xxx --data xxx
Hi, thanks for the reply, but I don't think that is the solution. I have 8 GPUs and want each GPU to load its own copy of the model so that all 8 run in parallel. If I don't set cuda_device, the default is effectively CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, and the 8 workers do run in parallel; the problem is that at load time each model gets spread across all 8 GPUs instead of sitting on a single card. I want each model to be loaded onto exactly one GPU.
Which model are you using, and could you share the command? If the model does not call the built-in split_model function, it should not get sharded.
If each of your models is being split across eight GPUs, something during the model's init, i.e. inside the load_pretrained_model function, must be doing it. Please take a look at that function.
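As a quick check (a debugging sketch, not part of VLMEvalKit; the helper name report_device_placement is only illustrative), you could print where the weights actually ended up right after load_pretrained_model returns in __init__:

def report_device_placement(model):
    # hf_device_map is set by transformers/accelerate only when the model was loaded with a device_map
    print("hf_device_map:", getattr(model, "hf_device_map", None))
    # distinct devices that actually hold parameters
    print("param devices:", {p.device for p in model.parameters()})

If the model is being sharded, the printed map and device set will list several cuda devices instead of one.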
It is a LLaVA model, modified into an MoE LoRA structure. The load_pretrained_model function is already included in the code pasted above; can you spot the problem? The launch command is just that torchrun command.
This is probably related to the custom init in your model. I suggest first debugging with --nproc_per_node=1 and checking at which step of init the model gets split; that should help you locate the problem quickly.
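For reference, here is a minimal sketch of one possible workaround (not an official VLMEvalKit recipe; the checkpoint paths and model name below are placeholders). torchrun exports LOCAL_RANK for every worker process, so inside LLaVA_MoE.__init__ you can derive the device from it and pass an explicit single-device device_map instead of "auto"; with device != "cuda", the load_pretrained_model function pasted above also rewrites device_map to {"": device}, so the whole model stays on that rank's GPU:

import os
import torch

# torchrun sets LOCAL_RANK for each worker; fall back to 0 when debugging with --nproc_per_node=1
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
device = f"cuda:{local_rank}"

# In LLaVA_MoE.__init__, call load_pretrained_model with this device instead of
# device="cuda", device_map="auto":
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/path/to/moe_lora_checkpoint",   # placeholder
    model_base="/path/to/llava-7b-v1.5",         # placeholder
    model_name="llava-MOE-lora",                 # placeholder; must still contain `llava`, `lora`, and `MOE`
    device=device,
    device_map={"": device},
)

Launching with torchrun --nproc_per_node 8 run.py --model xxx --data xxx should then give eight workers, each holding its own full copy of the model on a different GPU.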