
Accelerate support for GLM

Open larrylawl opened this issue 2 years ago • 6 comments

Feature request

Accelerate support for GLM.

Motivation

GLM is a state-of-the-art Chinese LLM. However, running the following code...

from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True, device_map="auto", load_in_8bit=True)

gives error...

ValueError: GLMForConditionalGeneration does not support `device_map='auto'` yet.

Your contribution

I would be happy to contribute. However, I can't find a guide on adding support for other models to Accelerate.

larrylawl avatar Mar 31 '23 09:03 larrylawl

You just need to add the proper attribute to GLMPreTrainedModel so that it knows which layers should not be split across GPUs, and then test that it works properly. Since this model uses the code-on-the-Hub feature, the model code needs to be changed there, adding something like what T5 has here (the model seems to look like T5). Maybe you can open a PR on their repo with this?
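As a rough illustration of why that attribute matters, here is a toy placement routine in plain Python. It is not Accelerate's actual algorithm: the Node class, the place function, the sizes, and the GLMBlock class name are all assumptions made up for this sketch. The idea it shows is that a module whose class is listed as "no split" is always placed on one device as a whole, while any other module may have its children spread across devices.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence


@dataclass
class Node:
    """Hypothetical stand-in for a module: a name, a class name, and a size."""
    name: str
    cls: str
    size: int = 0  # only leaf nodes carry a size in this toy model
    children: List["Node"] = field(default_factory=list)

    def total(self) -> int:
        return self.size + sum(child.total() for child in self.children)


def place(root: Node, capacities: Sequence[int], no_split: Sequence[str]) -> Dict[str, int]:
    """Greedily assign modules to devices, never splitting classes in no_split."""
    device_map: Dict[str, int] = {}
    device, free = 0, capacities[0]
    queue = [root]
    while queue:
        node = queue.pop(0)
        need = node.total()
        if need <= free:
            # The whole module fits on the current device.
            device_map[node.name] = device
            free -= need
        elif node.children and node.cls not in no_split:
            # Allowed to split: place the children one by one instead.
            queue = node.children + queue
        else:
            # Atomic module that does not fit: move on to the next device.
            device += 1
            if device >= len(capacities):
                raise ValueError(f"{node.name} does not fit on any device.")
            free = capacities[device]
            queue.insert(0, node)
    return device_map


def make_block(index: int) -> Node:
    prefix = f"glm.transformer.layers.{index}"
    return Node(prefix, "GLMBlock", children=[
        Node(f"{prefix}.attention", "SelfAttention", size=3),
        Node(f"{prefix}.mlp", "MLP", size=3),
    ])


def make_model() -> Node:
    return Node("glm", "GLMModel", children=[
        Node("glm.word_embeddings", "Embedding", size=2),
        make_block(0),
        make_block(1),
    ])


# With the block class marked as atomic, layer 0 stays whole on device 1.
with_attr = place(make_model(), capacities=[6, 12], no_split=["GLMBlock"])
# Without it, layer 0 ends up with attention on device 0 and mlp on device 1.
without_attr = place(make_model(), capacities=[6, 12], no_split=[])
```

In the real model the equivalent declaration would be a class attribute on GLMPreTrainedModel listing the block class name, in the same way T5 lists its block class.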

sgugger avatar Mar 31 '23 13:03 sgugger

Thanks @sgugger for the advice! I've added the _no_split_modules attribute in this PR.

However, when I tried using device_map with the following code...

from transformers import AutoModelForSeq2SeqLM
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
                                    trust_remote_code=True, 
                                    revision="6adb492",
                                    device_map="auto",
                                    load_in_8bit=True,
                                    )
model.eval()

I faced the error...

Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 2
      1 # ours
----> 2 model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
      3                                     trust_remote_code=True, 
      4                                     cache_dir=SHARED_MODEL_DIR,
      5                                     revision="6adb492",
      6                                     device_map="auto",
      7                                     load_in_8bit=True,
      8                                     )
      9 model.eval()

File ~/ln/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:466, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    462     model_class = get_class_from_dynamic_module(
    463         pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
    464     )
    465     model_class.register_for_auto_class(cls.__name__)
--> 466     return model_class.from_pretrained(
    467         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    468     )
    469 elif type(config) in cls._model_mapping.keys():
    470     model_class = _get_model_class(config, cls._model_mapping)

File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:2648, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2638     if dtype_orig is not None:
   2639         torch.set_default_dtype(dtype_orig)
   2641     (
   2642         model,
   2643         missing_keys,
   2644         unexpected_keys,
   2645         mismatched_keys,
   2646         offload_index,
   2647         error_msgs,
-> 2648     ) = cls._load_pretrained_model(
   2649         model,
   2650         state_dict,
   2651         loaded_state_dict_keys,  # XXX: rename?
   2652         resolved_archive_file,
   2653         pretrained_model_name_or_path,
   2654         ignore_mismatched_sizes=ignore_mismatched_sizes,
   2655         sharded_metadata=sharded_metadata,
   2656         _fast_init=_fast_init,
   2657         low_cpu_mem_usage=low_cpu_mem_usage,
   2658         device_map=device_map,
   2659         offload_folder=offload_folder,
   2660         offload_state_dict=offload_state_dict,
   2661         dtype=torch_dtype,
   2662         load_in_8bit=load_in_8bit,
   2663         keep_in_fp32_modules=keep_in_fp32_modules,
   2664     )
   2666 model.is_loaded_in_8bit = load_in_8bit
   2668 # make sure token embedding weights are still tied if needed

File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:2971, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, load_in_8bit, keep_in_fp32_modules)
   2961 mismatched_keys += _find_mismatched_keys(
   2962     state_dict,
   2963     model_state_dict,
   (...)
   2967     ignore_mismatched_sizes,
   2968 )
   2970 if low_cpu_mem_usage:
-> 2971     new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
   2972         model_to_load,
   2973         state_dict,
   2974         loaded_keys,
   2975         start_prefix,
   2976         expected_keys,
   2977         device_map=device_map,
   2978         offload_folder=offload_folder,
   2979         offload_index=offload_index,
   2980         state_dict_folder=state_dict_folder,
   2981         state_dict_index=state_dict_index,
   2982         dtype=dtype,
   2983         load_in_8bit=load_in_8bit,
   2984         is_safetensors=is_safetensors,
   2985         keep_in_fp32_modules=keep_in_fp32_modules,
   2986     )
   2987     error_msgs += new_error_msgs
   2988 else:

File ~/ln/lib/python3.8/site-packages/transformers/modeling_utils.py:665, in _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix, expected_keys, device_map, offload_folder, offload_index, state_dict_folder, state_dict_index, dtype, load_in_8bit, is_safetensors, keep_in_fp32_modules)
    662         module_name = ".".join(module_name.split(".")[:-1])
    663     if module_name == "" and "" not in device_map:
    664         # TODO: group all errors and raise at the end.
--> 665         raise ValueError(f"{param_name} doesn't have any device set.")
    666     param_device = device_map[module_name]
    667 if param_device == "disk":

ValueError: word_embeddings.weight doesn't have any device set.

I managed to fix this by specifying a custom device_map (code below). However, device_map='auto' should work without the user passing a specific device_map, right? Is my PR missing something?
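For readers trying to follow the last frame of the traceback, here is a minimal sketch of the lookup it performs. Only the body lines (662 to 666 of modeling_utils.py in that frame) are visible above; the find_device wrapper and the while condition are assumptions filled in around them, not the library's code. It shows one way this message can arise: no prefix of the parameter name is a key in the device map.

```python
from typing import Dict, Union

Device = Union[int, str]


def find_device(param_name: str, device_map: Dict[str, Device]) -> Device:
    """Walk up the dotted parameter name until a key of device_map matches."""
    module_name = param_name
    # Assumed loop condition: keep stripping the last component while no match.
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[module_name]


example_map = {"glm.word_embeddings": 0, "glm.transformer.layers.0": 0}

# A name that starts with a key of the map resolves to that key's device.
prefixed = find_device("glm.word_embeddings.weight", example_map)

# A name with no matching prefix raises the error seen in the traceback.
try:
    find_device("word_embeddings.weight", example_map)
    message = None
except ValueError as err:
    message = str(err)
```

Whether a prefix mismatch of this kind is what actually happened with device_map="auto" here is not established by the traceback alone.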

device_map={'glm.word_embeddings': 0,
 'glm.transformer.embedding_dropout': 0,
 'glm.transformer.position_embeddings': 0,
 'glm.transformer.block_position_embeddings': 0,
 'glm.transformer.layers.0': 0,
 'glm.transformer.layers.1': 0,
 'glm.transformer.layers.2': 0,
 'glm.transformer.layers.3': 0,
 'glm.transformer.layers.4': 0,
 'glm.transformer.layers.5': 0,
 'glm.transformer.layers.6': 0,
 'glm.transformer.layers.7': 0,
 'glm.transformer.layers.8': 0,
 'glm.transformer.layers.9': 0,
 'glm.transformer.layers.10': 0,
 'glm.transformer.layers.11': 0,
 'glm.transformer.layers.12': 0,
 'glm.transformer.layers.13': 0,
 'glm.transformer.layers.14': 0,
 'glm.transformer.layers.15': 0,
 'glm.transformer.layers.16': 0,
 'glm.transformer.layers.17': 0,
 'glm.transformer.layers.18': 0,
 'glm.transformer.layers.19': 0,
 'glm.transformer.layers.20': 0,
 'glm.transformer.layers.21': 0,
 'glm.transformer.layers.22': 0,
 'glm.transformer.layers.23': 0,
 'glm.transformer.layers.24': 0,
 'glm.transformer.layers.25': 0,
 'glm.transformer.layers.26': 0,
 'glm.transformer.layers.27': 0,
 'glm.transformer.layers.28': 0,
 'glm.transformer.layers.29': 0,
 'glm.transformer.layers.30': 0,
 'glm.transformer.layers.31': 0,
 'glm.transformer.layers.32': 0,
 'glm.transformer.layers.33': 0,
 'glm.transformer.layers.34': 0,
 'glm.transformer.layers.35': 0,
 'glm.transformer.layers.36': 0,
 'glm.transformer.layers.37': 0,
 'glm.transformer.layers.38': 0,
 'glm.transformer.layers.39': 0,
 'glm.transformer.layers.40': 0,
 'glm.transformer.layers.41': 0,
 'glm.transformer.layers.42': 0,
 'glm.transformer.layers.43': 0,
 'glm.transformer.layers.44': 0,
 'glm.transformer.layers.45': 0,
 'glm.transformer.layers.46': 0,
 'glm.transformer.layers.47': 0,
 'glm.transformer.final_layernorm': 0}

# ours
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path,
                                    trust_remote_code=True, 
                                    revision="6adb492",
                                    device_map=device_map,
                                    load_in_8bit=True,
                                    )
model.eval()
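The same single-GPU map can be built programmatically rather than typed out. This is only a sketch: the build_device_map helper is hypothetical, and the module names and the count of 48 layers are copied from the map above.

```python
NUM_LAYERS = 48  # layers.0 through layers.47, as in the map above
DEVICE = 0       # every module goes on GPU 0 in the map above


def build_device_map(num_layers: int = NUM_LAYERS, device: int = DEVICE) -> dict:
    """Recreate the hand-written map: embeddings, each layer, final layernorm."""
    device_map = {
        "glm.word_embeddings": device,
        "glm.transformer.embedding_dropout": device,
        "glm.transformer.position_embeddings": device,
        "glm.transformer.block_position_embeddings": device,
    }
    for i in range(num_layers):
        device_map[f"glm.transformer.layers.{i}"] = device
    device_map["glm.transformer.final_layernorm"] = device
    return device_map


device_map = build_device_map()
```

The resulting dictionary can be passed as device_map=device_map in the from_pretrained call shown above.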

larrylawl avatar Apr 05 '23 07:04 larrylawl

Does it work without the load_in_8bit part? Also what is your version of Accelerate?
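A short way to report just the relevant versions, using the standard library's importlib.metadata. The installed_versions helper and the choice of packages are assumptions for this sketch.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version}, or "not installed" when a package is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


report = installed_versions(["accelerate", "transformers", "bitsandbytes", "torch"])
for name, version in report.items():
    print(f"{name}=={version}")
```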

sgugger avatar Apr 05 '23 14:04 sgugger

Nope, same error. Here are my dependencies:

accelerate==0.18.0
aiohttp==3.8.4
aiosignal==1.3.1
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==22.2.0
backcall==0.2.0
beautifulsoup4==4.12.0
bitsandbytes==0.37.2
bleach==6.0.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.1.0
cmake==3.26.1
comm==0.1.3
datasets==2.11.0
debugpy==1.6.6
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
evaluate==0.4.0
executing==1.2.0
fastjsonschema==2.16.3
filelock==3.10.7
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.3.0
huggingface-hub==0.13.3
idna==3.4
importlib-metadata==6.1.0
importlib-resources==5.12.0
ipykernel==6.22.0
ipython==8.12.0
ipython-genutils==0.2.0
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter_client==8.1.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
-e git+https://github.com/larrylawl/prompt-infill-prompt.git@aefd41e421cf30485b2e14b13877cdf1232335c7#egg=lexnorm
lit==16.0.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==2.0.5
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
nbclassic==0.5.4
nbclient==0.7.3
nbconvert==7.3.0
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
notebook==6.5.3
notebook_shim==0.2.2
numpy==1.24.2
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
packaging==23.0
pandas==2.0.0
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
pkgutil_resolve_name==1.3.10
platformdirs==3.2.0
prometheus-client==0.16.0
prompt-toolkit==3.0.38
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==11.0.0
pycparser==2.21
Pygments==2.14.0
pyrsistent==0.19.3
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3
PyYAML==6.0
pyzmq==25.0.2
regex==2023.3.23
requests==2.28.2
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
Send2Trash==1.8.0
sentencepiece==0.1.97
six==1.16.0
sniffio==1.3.0
soupsieve==2.4
stack-data==0.6.2
sympy==1.11.1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.2
torch==2.0.0
tornado==6.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.27.4
triton==2.0.0
typing_extensions==4.5.0
tzdata==2023.3
uri-template==1.2.0
urllib3==1.26.15
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
xxhash==3.2.0
yarl==1.8.2
zhon==1.1.5
zipp==3.15.0

larrylawl avatar Apr 06 '23 01:04 larrylawl

I just tried

from transformers import AutoModelForSeq2SeqLM
model_name_or_path = "THUDM/glm-10b-chinese"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True, revision="6adb492", device_map="auto")

and it worked without any issue.

sgugger avatar Apr 06 '23 13:04 sgugger

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 30 '23 15:04 github-actions[bot]