transformers
Accelerate support for GLM
Feature request
Accelerate support for GLM.
Motivation
GLM is a SOTA Chinese LLM. However, running the following code...
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True, device_map="auto", load_in_8bit=True)
gives this error...
ValueError: GLMForConditionalGeneration does not support `device_map='auto'` yet.
Your contribution
I would be happy to contribute. However, I can't find a guide on adding Accelerate support for other models.
You just need to add the proper attribute to GLMPreTrainedModel so that Accelerate knows which layers should not be split across GPUs, and then test that it works properly. Since this model uses the code-on-the-Hub feature, the model code needs to be changed there to add something like what T5 does here (the model seems to be structured like T5). Maybe you can open a PR on their repo with this?
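For reference, a minimal sketch of what such a change could look like in the GLM model code on the Hub (the block class name "GLMBlock" is an assumption here; it should be whichever module wraps one full transformer layer):
from transformers import PreTrainedModel

class GLMPreTrainedModel(PreTrainedModel):
    # Listing the transformer block class tells Accelerate's device_map logic
    # never to split one of these blocks across devices, mirroring
    # T5PreTrainedModel's `_no_split_modules = ["T5Block"]`.
    # "GLMBlock" is an assumed name for GLM's block class.
    _no_split_modules = ["GLMBlock"]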
Thanks @sgugger for the advice! I've added the _no_split_modules attribute in this PR.
However, when I tried using device_map with the following code...
from transformers import AutoModelForSeq2SeqLM
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    revision="6adb492",
    device_map="auto",
    load_in_8bit=True,
)
model.eval()
I faced the error...
Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[7], line 2
1 # ours
----> 2 model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
3 trust_remote_code=True,
4 cache_dir=SHARED_MODEL_DIR,
5 revision="6adb492",
6 device_map="auto",
7 load_in_8bit=True,
8 )
9 model.eval()
File ~/ln/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:466, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
462 model_class = get_class_from_dynamic_module(
463 pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
464 )
465 model_class.register_for_auto_class(cls.__name__)
--> 466 return model_class.from_pretrained(
467 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
468 )
469 elif type(config) in cls._model_mapping.keys():
470 model_class = _get_model_class(config, cls._model_mapping)
File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:2648, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
2638 if dtype_orig is not None:
2639 torch.set_default_dtype(dtype_orig)
2641 (
2642 model,
2643 missing_keys,
2644 unexpected_keys,
2645 mismatched_keys,
2646 offload_index,
2647 error_msgs,
-> 2648 ) = cls._load_pretrained_model(
2649 model,
2650 state_dict,
2651 loaded_state_dict_keys, # XXX: rename?
2652 resolved_archive_file,
2653 pretrained_model_name_or_path,
2654 ignore_mismatched_sizes=ignore_mismatched_sizes,
2655 sharded_metadata=sharded_metadata,
2656 _fast_init=_fast_init,
2657 low_cpu_mem_usage=low_cpu_mem_usage,
2658 device_map=device_map,
2659 offload_folder=offload_folder,
2660 offload_state_dict=offload_state_dict,
2661 dtype=torch_dtype,
2662 load_in_8bit=load_in_8bit,
2663 keep_in_fp32_modules=keep_in_fp32_modules,
2664 )
2666 model.is_loaded_in_8bit = load_in_8bit
2668 # make sure token embedding weights are still tied if needed
File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:2971, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, load_in_8bit, keep_in_fp32_modules)
2961 mismatched_keys += _find_mismatched_keys(
2962 state_dict,
2963 model_state_dict,
(...)
2967 ignore_mismatched_sizes,
2968 )
2970 if low_cpu_mem_usage:
-> 2971 new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
2972 model_to_load,
2973 state_dict,
2974 loaded_keys,
2975 start_prefix,
2976 expected_keys,
2977 device_map=device_map,
2978 offload_folder=offload_folder,
2979 offload_index=offload_index,
2980 state_dict_folder=state_dict_folder,
2981 state_dict_index=state_dict_index,
2982 dtype=dtype,
2983 load_in_8bit=load_in_8bit,
2984 is_safetensors=is_safetensors,
2985 keep_in_fp32_modules=keep_in_fp32_modules,
2986 )
2987 error_msgs += new_error_msgs
2988 else:
File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:665, in _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix, expected_keys, device_map, offload_folder, offload_index, state_dict_folder, state_dict_index, dtype, load_in_8bit, is_safetensors, keep_in_fp32_modules)
662 module_name = ".".join(module_name.split(".")[:-1])
663 if module_name == "" and "" not in device_map:
664 # TODO: group all errors and raise at the end.
--> 665 raise ValueError(f"{param_name} doesn't have any device set.")
666 param_device = device_map[module_name]
667 if param_device == "disk":
ValueError: word_embeddings.weight doesn't have any device set.
I managed to fix this by specifying a custom device_map (code below). However, device_map='auto' should work without the user passing a specific device_map, right? Is my PR missing something?
device_map={'glm.word_embeddings': 0,
'glm.transformer.embedding_dropout': 0,
'glm.transformer.position_embeddings': 0,
'glm.transformer.block_position_embeddings': 0,
'glm.transformer.layers.0': 0,
'glm.transformer.layers.1': 0,
'glm.transformer.layers.2': 0,
'glm.transformer.layers.3': 0,
'glm.transformer.layers.4': 0,
'glm.transformer.layers.5': 0,
'glm.transformer.layers.6': 0,
'glm.transformer.layers.7': 0,
'glm.transformer.layers.8': 0,
'glm.transformer.layers.9': 0,
'glm.transformer.layers.10': 0,
'glm.transformer.layers.11': 0,
'glm.transformer.layers.12': 0,
'glm.transformer.layers.13': 0,
'glm.transformer.layers.14': 0,
'glm.transformer.layers.15': 0,
'glm.transformer.layers.16': 0,
'glm.transformer.layers.17': 0,
'glm.transformer.layers.18': 0,
'glm.transformer.layers.19': 0,
'glm.transformer.layers.20': 0,
'glm.transformer.layers.21': 0,
'glm.transformer.layers.22': 0,
'glm.transformer.layers.23': 0,
'glm.transformer.layers.24': 0,
'glm.transformer.layers.25': 0,
'glm.transformer.layers.26': 0,
'glm.transformer.layers.27': 0,
'glm.transformer.layers.28': 0,
'glm.transformer.layers.29': 0,
'glm.transformer.layers.30': 0,
'glm.transformer.layers.31': 0,
'glm.transformer.layers.32': 0,
'glm.transformer.layers.33': 0,
'glm.transformer.layers.34': 0,
'glm.transformer.layers.35': 0,
'glm.transformer.layers.36': 0,
'glm.transformer.layers.37': 0,
'glm.transformer.layers.38': 0,
'glm.transformer.layers.39': 0,
'glm.transformer.layers.40': 0,
'glm.transformer.layers.41': 0,
'glm.transformer.layers.42': 0,
'glm.transformer.layers.43': 0,
'glm.transformer.layers.44': 0,
'glm.transformer.layers.45': 0,
'glm.transformer.layers.46': 0,
'glm.transformer.layers.47': 0,
'glm.transformer.final_layernorm': 0}
# ours
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    revision="6adb492",
    device_map=device_map,
    load_in_8bit=True,
)
model.eval()
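As an aside, a map like the one above can also be built programmatically instead of writing out all 48 layers by hand; a minimal sketch assuming the same module names and everything on GPU 0:
# Rebuild the hand-written single-GPU map above; glm-10b-chinese has 48 layers.
device_map = {
    "glm.word_embeddings": 0,
    "glm.transformer.embedding_dropout": 0,
    "glm.transformer.position_embeddings": 0,
    "glm.transformer.block_position_embeddings": 0,
    "glm.transformer.final_layernorm": 0,
}
device_map.update({f"glm.transformer.layers.{i}": 0 for i in range(48)})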
Does it work without the load_in_8bit part? Also, what is your version of Accelerate?
Nope, same error. Here are my dependencies:
accelerate==0.18.0
aiohttp==3.8.4
aiosignal==1.3.1
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==22.2.0
backcall==0.2.0
beautifulsoup4==4.12.0
bitsandbytes==0.37.2
bleach==6.0.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.1.0
cmake==3.26.1
comm==0.1.3
datasets==2.11.0
debugpy==1.6.6
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
evaluate==0.4.0
executing==1.2.0
fastjsonschema==2.16.3
filelock==3.10.7
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.3.0
huggingface-hub==0.13.3
idna==3.4
importlib-metadata==6.1.0
importlib-resources==5.12.0
ipykernel==6.22.0
ipython==8.12.0
ipython-genutils==0.2.0
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter_client==8.1.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
-e git+https://github.com/larrylawl/prompt-infill-prompt.git@aefd41e421cf30485b2e14b13877cdf1232335c7#egg=lexnorm
lit==16.0.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==2.0.5
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
nbclassic==0.5.4
nbclient==0.7.3
nbconvert==7.3.0
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
notebook==6.5.3
notebook_shim==0.2.2
numpy==1.24.2
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
packaging==23.0
pandas==2.0.0
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
pkgutil_resolve_name==1.3.10
platformdirs==3.2.0
prometheus-client==0.16.0
prompt-toolkit==3.0.38
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==11.0.0
pycparser==2.21
Pygments==2.14.0
pyrsistent==0.19.3
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3
PyYAML==6.0
pyzmq==25.0.2
regex==2023.3.23
requests==2.28.2
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
Send2Trash==1.8.0
sentencepiece==0.1.97
six==1.16.0
sniffio==1.3.0
soupsieve==2.4
stack-data==0.6.2
sympy==1.11.1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.2
torch==2.0.0
tornado==6.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.27.4
triton==2.0.0
typing_extensions==4.5.0
tzdata==2023.3
uri-template==1.2.0
urllib3==1.26.15
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
xxhash==3.2.0
yarl==1.8.2
zhon==1.1.5
zipp==3.15.0
I just tried
from transformers import AutoModelForSeq2SeqLM
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True, revision="6adb492", device_map="auto")
and it worked without any issue.
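If it helps narrow this down, after a successful device_map="auto" load you can inspect the placement that Accelerate computed; a quick check, assuming the model loaded as above:
# hf_device_map records the module-to-device assignment Accelerate chose;
# a module missing from it would explain a "doesn't have any device set" error.
print(model.hf_device_map)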
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.