[BUG] Tested PR #2662; it fails for GPTJ 6B and a few others
**Describe the bug**
We have tested PR https://github.com/microsoft/DeepSpeed/pull/2662 against the following models:
- OPT 1.3B, 2 tp degree, fp16
- OPT 13B, 4 tp degree, [fp16, int8]
- OPT 30B, 8 tp degree [fp16, int8]
- GPT NeoX 20B, 8 tp degree, [fp16, int8]
- GPTJ 6B, 4 tp degree, fp16
Our test involves two steps:
- Load the model to the device, then generate the partitions and save them to a local directory.
  - While generating the partitions, the model can be loaded with or without meta tensors (this affects how the partitions are generated).
  - Meta tensors have the same dimensions as the real tensors but contain no data.
- Load the generated partition files back and run inference.

We used your test suite to test these models; a rough, illustrative sketch of the flow is below.
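The sketch below is only illustrative, assuming the `deepspeed.init_inference` and `deepspeed.OnDevice` APIs; the model name, tp degree, output path, and `checkpoints.json` weight-description file are placeholders, not the exact test-suite code.

```python
# A minimal sketch of the two-step flow above (not the exact test-suite code).
# Assumptions: the deepspeed.init_inference / deepspeed.OnDevice APIs; the model
# name, tp degree, output path, and "checkpoints.json" are placeholders.
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "facebook/opt-1.3b"   # placeholder
tp_degree = 2                      # number of partitions, one per GPU
# (Run under the launcher, e.g. `deepspeed --num_gpus 2 this_script.py`.)

# Step 1a: build the model. On the meta device the weights have shapes but no data,
# so nothing is materialized yet; the traditional (non-meta) path is shown commented.
config = AutoConfig.from_pretrained(model_name)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Step 1b: inject kernels, partition the weights across tp_degree ranks, and write
# the pre-sharded (DS) checkpoint files to disk.
model = deepspeed.init_inference(
    model,
    mp_size=tp_degree,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    checkpoint="checkpoints.json",    # describes the original HF weight files
                                      # (required on the meta path, since meta
                                      #  tensors hold no data)
    save_mp_checkpoint_path="/tmp/ws/sharded-model/",
)

# Step 2 (separate run): load the generated shards back and run inference, pointing
# `checkpoint` at the config json written by step 1.
# model = deepspeed.init_inference(
#     model, mp_size=tp_degree, dtype=torch.float16,
#     replace_with_kernel_inject=True,
#     checkpoint="/tmp/ws/sharded-model/<config json generated in step 1>",
# )
```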
**System info (please complete the following information):**
- 1 machine with 8 GPUs (NVIDIA A10G, 24 GB memory per GPU)
- Ubuntu
The tables below show the results we got from testing.
**Load fully on CPU with HF, save DS-sharded checkpoints, and load back**
| Model | Partitions | Dtype | Result: generate DS pre-sharded checkpoints | Result: load back DS pre-sharded checkpoints and run inference |
|---|---|---|---|---|
| OPT 1.3B | 2 | fp16 | Successfully generates pre-sharded checkpoint files | Successfully loads back the pre-sharded checkpoints, runs inference, and generates outputs. |
| GPTJ 6B | 4 | fp16 | Successfully generates pre-sharded checkpoint files | Loading them back fails with `NotImplementedError: Cannot copy out of meta tensor; no data!` Traceback (most recent call last): File "inference-test.py", line 57, in <module> pipe.model = deepspeed.init_inference(pipe.model, File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 311, in init_inference engine = InferenceEngine(model, config=ds_inference_config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 129, in __init__ self.module.to(device) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 927, in to return self._apply(convert) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 579, in _apply module._apply(fn) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 602, in _apply param_applied = fn(param) File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 925, in convert return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) NotImplementedError: Cannot copy out of meta tensor; no data! |
| GPT NeoX 20B | 8 | fp16 | Successfully generates pre-sharded checkpoint files | Successfully loads back the pre-sharded checkpoints, runs inference, and generates outputs. |
| OPT 13B | 4 | int8 | Successfully generates pre-sharded checkpoint files | Loads the pre-sharded checkpoints but throws an error while generating outputs: Traceback (most recent call last): File "inference-test.py", line 88, in <module> outputs = pipe(inputs, File "/tmp/ws/models/utils.py", line 69, in __call__ outputs = self.generate_outputs(input_list, num_tokens=num_tokens, do_sample=do_sample) File "/tmp/ws/models/utils.py", line 113, in generate_outputs outputs = self.model.generate(**input_tokens, **generate_kwargs) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 537, in _generate return self.module.generate(*inputs, **kwargs) File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_utils.py", line 1422, in generate return self.sample( File "/usr/local/lib/python3.8/dist-packages/transformers/generation_utils.py", line 2049, in sample next_token_scores = logits_warper(input_ids, next_token_scores) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_logits_process.py", line 92, in __call__ scores = processor(input_ids, scores) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_logits_process.py", line 233, in __call__ indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] RuntimeError: CUDA error: an illegal memory access was encountered. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. |
| OPT 30B | 8 | int8 | Successfully generates pre-sharded checkpoint files | Loads the pre-sharded checkpoints but throws an error while generating outputs: Traceback (most recent call last): File "inference-test.py", line 88, in <module> outputs = pipe(inputs, File "/tmp/ws/models/utils.py", line 69, in __call__ outputs = self.generate_outputs(input_list, num_tokens=num_tokens, do_sample=do_sample) File "/tmp/ws/models/utils.py", line 113, in generate_outputs outputs = self.model.generate(**input_tokens, **generate_kwargs) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 537, in _generate return self.module.generate(*inputs, **kwargs) File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_utils.py", line 1422, in generate return self.sample( File "/usr/local/lib/python3.8/dist-packages/transformers/generation_utils.py", line 2049, in sample next_token_scores = logits_warper(input_ids, next_token_scores) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_logits_process.py", line 92, in __call__ scores = processor(input_ids, scores) File "/usr/local/lib/python3.8/dist-packages/transformers/generation_logits_process.py", line 233, in __call__ indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] RuntimeError: CUDA error: an illegal memory access was encountered. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. |
| GPT NeoX 20B | 8 | int8 | Successfully generates pre-sharded checkpoint files | Loading them back throws the following error: Traceback (most recent call last): File "inference-test.py", line 57, in <module> pipe.model = deepspeed.init_inference(pipe.model, File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 311, in init_inference engine = InferenceEngine(model, config=ds_inference_config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 126, in __init__ self._apply_injection_policy(config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 339, in _apply_injection_policy replace_transformer_layer(client_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_module.py", line 850, in replace_transformer_layer load_model_with_checkpoint(replaced_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 252, in load_model_with_checkpoint load_module_recursive(r_module) File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 244, in load_module_recursive layer_policies[child.__class__](child, prefix + name + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 176, in load_transformer_layer load_parameters(child, prefix + n + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 85, in load_parameters assert tmp_data.dtype != torch.int8, \ AssertionError: Merging of the checkpoints are not supported when using INT8 checkpoint! Please use a as many GPUs as TP-size for the checkpoint |
**Load with meta tensor, save DS-sharded checkpoints, and load back**
| Model | Partitions | Dtype | Result: generate DS pre-sharded checkpoints | Result: load back DS pre-sharded checkpoints and run inference |
|---|---|---|---|---|
| OPT 1.3B | 2 | fp16 | Could not even generate pre-sharded checkpoints for the model; the following error occurs during the init_inference API call: Traceback (most recent call last): File "inference-test.py", line 57, in <module> pipe.model = deepspeed.init_inference(pipe.model, File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 311, in init_inference engine = InferenceEngine(model, config=ds_inference_config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 126, in __init__ self._apply_injection_policy(config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 339, in _apply_injection_policy replace_transformer_layer(client_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_module.py", line 820, in replace_transformer_layer load_model_with_checkpoint(replaced_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 252, in load_model_with_checkpoint load_module_recursive(r_module) File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 244, in load_module_recursive layer_policies[child.__class__](child, prefix + name + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 30, in load module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight']) KeyError: 'decoder.embed_tokens.weight' | - |
| OPT 13B | 4 | fp16 | Successfully generates pre-sharded checkpoint files | Successfully loads back the pre-sharded checkpoints, runs inference, and generates outputs. |
| OPT 30B | 8 | fp16 | Successfully generates pre-sharded checkpoint files | Successfully loads back the pre-sharded checkpoints, runs inference, and generates outputs. |
| GPTJ 6B | 4 | fp16 | Successfully generates pre-sharded checkpoint files | Traceback (most recent call last): File "inference-test.py", line 57, in |
| GPT NeoX 20B | 8 | fp16 | Could not even generate pre-sharded checkpoints for the model; the following error occurs during the init_inference API call: Traceback (most recent call last): File "inference-test.py", line 57, in <module> pipe.model = deepspeed.init_inference(pipe.model, File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 311, in init_inference engine = InferenceEngine(model, config=ds_inference_config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 126, in __init__ self._apply_injection_policy(config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 339, in _apply_injection_policy replace_transformer_layer(client_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_module.py", line 820, in replace_transformer_layer load_model_with_checkpoint(replaced_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 252, in load_model_with_checkpoint load_module_recursive(r_module) File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 244, in load_module_recursive layer_policies[child.__class__](child, prefix + name + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 178, in load_transformer_layer replace_policy.load_params(module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_policy.py", line 864, in load_params maybe_copy(module.attention, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_policy.py", line 250, in maybe_copy dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \ File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_module.py", line 116, in copy dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape( | - |
| OPT 13B | 4 | int8 | Successfully generates pre-sharded checkpoint files | Traceback (most recent call last): File "inference-test.py", line 88, in |
| OPT 30B | 8 | int8 | Successfully generates pre-sharded checkpoint files | Traceback (most recent call last): File "inference-test.py", line 88, in |
| GPT NeoX 20B | 8 | int8 | Successfully generates checkpoint files. | An error occurs when we load back the pre-sharded checkpoint files: Traceback (most recent call last): File "inference-test.py", line 57, in <module> pipe.model = deepspeed.init_inference(pipe.model, File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 311, in init_inference engine = InferenceEngine(model, config=ds_inference_config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 126, in __init__ self._apply_injection_policy(config) File "/usr/local/lib/python3.8/dist-packages/deepspeed/inference/engine.py", line 339, in _apply_injection_policy replace_transformer_layer(client_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/replace_module.py", line 850, in replace_transformer_layer load_model_with_checkpoint(replaced_module, File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 252, in load_model_with_checkpoint load_module_recursive(r_module) File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 246, in load_module_recursive load_module_recursive( File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 244, in load_module_recursive layer_policies[child.__class__](child, prefix + name + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 176, in load_transformer_layer load_parameters(child, prefix + n + '.') File "/usr/local/lib/python3.8/dist-packages/deepspeed/module_inject/load_checkpoint.py", line 85, in load_parameters assert tmp_data.dtype != torch.int8, \ AssertionError: Merging of the checkpoints are not supported when using INT8 checkpoint! Please use a as many GPUs as TP-size for the checkpoint |
A quick summary in words:
- For OPT 1.3B and GPT NeoX 20B with dtype fp16, we **could not even generate** the partition files when the model is loaded with meta tensors.
- For GPTJ 6B with dtype fp16, we are able to generate the pre-sharded checkpoint files with or without meta tensors, but loading them back produces an error.
- For the int8 dtype:
  - For the OPT 13B and OPT 30B models, we could generate the pre-sharded checkpoint files when the model is loaded without meta tensors (i.e., the traditional way), but generating output after loading the model back throws an error.
  - For the GPT NeoX 20B model, we could generate the pre-sharded checkpoint files both with and without meta tensors, but loading them back throws an error.

The major problems here are the following:
- The generated GPTJ pre-sharded checkpoint files do not work.
- Pre-sharded checkpoint generation with int8 quantization does not work either.
@RezaYazdaniAminabadi @lekurile we did some experiments based on your PR. It works for a few cases but fails on some corner cases. I would suggest we merge the PR given that some use cases work, and fix the remaining ones with follow-up PRs :)
The BLOOM series was not covered since it is more of a "known to work" case.
Hello @sindhuvahinis @lanking520, thank you for reporting this! With the merge of https://github.com/microsoft/DeepSpeed/pull/2725, the major part of this issue should be resolved. I tested the models you listed with the master branch of DeepSpeed, using meta tensor and int8 checkpoint loading, and these models run smoothly. The GPTJ 6B model gives a different result; I believe it is another issue and I am actively investigating it. Could you do a quick check on your side to see if you still have this issue with the current master branch of DeepSpeed? Thank you!
@HeyangQin we did some tests on #2725 as well and are still observing the major issues with INT8. We will share more details and our setup.
@lanking520 Thank you for the update! If possible, could you share the command line you use to reproduce this issue?
Thanks for the update @HeyangQin. As Qing said, we also tested #2725; you can check the comments in the PR for the errors we faced. We tested on multiple GPU configurations with tp_size greater than 1.
@HeyangQin We used the same test suite as yours. For example, for GPT-NeoX, we first generated the checkpoints using `save_mp_checkpoint_path` and then loaded them back using meta tensors and the checkpoint file:
```bash
deepspeed --num_nodes 1 \
    --num_gpus 8 \
    inference-test.py \
    --use_kernel \
    --ds_inference \
    --use_meta_tensor \
    --name EleutherAI/gpt-neox-20b \
    --checkpoint_path /tmp/ws/gpt-neox-20b/ \
    --save_mp_checkpoint_path /tmp/ws/sharded-gpt-neox-20b/ \
    --dtype int8
```

Then we load the generated shards back and run inference:

```bash
deepspeed --num_nodes 1 \
    --num_gpus 8 \
    inference-test.py \
    --use_kernel \
    --ds_inference \
    --use_meta_tensor \
    --name EleutherAI/gpt-neox-20b \
    --checkpoint_path /tmp/ws/sharded-gpt-neox-20b/ \
    --dtype int8
```
A similar test can be conducted quickly on OPT / GPTJ / GPT-NeoX / BLOOM 7B with INT8; these models are all producing garbage outputs.
- OPT: NCCL communication issue
- GPT-NeoX 20B: producing garbage output
- BLOOM-7B: `shape '[1, 4, 32, 384]' is invalid for input of size 16384`

Just tried these models on DeepSpeed 0.8.1.
Maybe we could close this issue, since meta tensor and checkpoint loading for the other precision types (FP16/FP32) is mostly working, and open a new one for INT8 specifically. @HeyangQin what do you think?
Hi @lanking520 @sindhuvahinis, Thank you for the information. Previously I only tested checkpoint loading with int8. Now when I test checkpoint saving with int8, I see the same error as you reported. After some initial investigation, I think there are multiple underlying issues that caused these errors:
- DeepSpeedExamples tries to load checkpoints even if they don't exist. I fixed it by https://github.com/microsoft/DeepSpeedExamples/commit/efacebb3ddbea86bb20c3af30fd060be0fa41ac8
- `load_params()` should reside in the policy. I fixed it by https://github.com/microsoft/DeepSpeed/pull/2875. I will merge this once it is reviewed.
- Kernel issues. I am working on this.
Once the int8 checkpoint saving works, I plan to add unit tests to prevent such errors in the future. I agree with @lanking520 that opening a new issue would make it more organized as this issue is about a PR that has been merged.
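For reference, a hypothetical sketch of the kind of round-trip smoke test this could be; the small stand-in model, single-GPU setup, file-name globbing, and assertions below are assumptions for illustration, not the actual DeepSpeed test suite:

```python
# Hypothetical round-trip smoke test (sketch only): save an int8 DS-sharded
# checkpoint, load it back, and check that generation is not obviously broken.
# Assumptions: a single GPU, a small stand-in model, the init_inference kwargs
# already used elsewhere in this thread; shard/json file names are guessed.
import glob
import os

import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "facebook/opt-125m"        # small placeholder model
SAVE_DIR = "/tmp/ds_int8_shards"   # placeholder output directory


def test_int8_checkpoint_roundtrip():
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")

    # Step 1: quantize to int8 and write the DS-sharded checkpoint files.
    model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)
    model = deepspeed.init_inference(
        model, mp_size=1, dtype=torch.int8,
        replace_with_kernel_inject=True,
        save_mp_checkpoint_path=SAVE_DIR,
    )
    assert glob.glob(os.path.join(SAVE_DIR, "*.pt")), "no shard files were written"

    # Step 2: load the shards back through the config json written in step 1
    # (the exact filename is whatever step 1 actually produced).
    cfg_json = glob.glob(os.path.join(SAVE_DIR, "*.json"))[0]
    fresh = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)
    fresh = deepspeed.init_inference(
        fresh, mp_size=1, dtype=torch.int8,
        replace_with_kernel_inject=True,
        checkpoint=cfg_json,
    )

    # Step 3: generation should neither crash nor return an empty/degenerate string.
    out = fresh.generate(**inputs, max_new_tokens=16, do_sample=False)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    assert len(text.strip()) > len("DeepSpeed is"), f"suspicious output: {text!r}"


if __name__ == "__main__":
    test_int8_checkpoint_roundtrip()
```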
Sounds good. @sindhuvahinis let's close this issue and open a new issue titled:
`[0.8.1] INT8 model loading/inference issue`
and summarize the findings there.