
[BUG] Inference predictions don't match Hugging Face for GPT-J

Open rahul003 opened this issue 2 years ago • 24 comments

Describe the bug

hf_output [{'generated_text': 'Try without sampling the data.\n\nA:\n\nYou can use the following code to get the data from the database.\n$sql = "SELECT * FROM `table`";\n$result = mysqli_query($conn,'}]
ds output [{'generated_text': 'Try without sampling the ( � hub ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ('}]

To Reproduce

Steps to reproduce the behavior:

import torch
import deepspeed
from transformers import GPTJForCausalLM, pipeline

query_text = "Try without sampling"
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B",
                revision="float16",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True)

pipe = pipeline("text-generation", model=model, tokenizer="EleutherAI/gpt-j-6B", device=0, framework="pt")

pipe.model.half()

hf_output = pipe(query_text, do_sample=False)

pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=1,
    dtype=torch.half,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

ds_output = pipe(query_text, do_sample=False)

print('HUGGINGFACE:', hf_output[0])
print('DEEPSPEED:', ds_output[0])

Expected behavior

The DeepSpeed output predictions should match the HF predictions.
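To make the expected-behavior check concrete, here is a small hypothetical helper (not part of DeepSpeed or Transformers). It takes the list-of-dicts output that the text-generation pipeline returns and reports the first character at which the DeepSpeed text departs from the HF text. It is a sketch that assumes each output holds a single `generated_text` entry.

```python
from typing import Dict, List, Optional


def first_divergence(reference: str, candidate: str) -> Optional[int]:
    """Return the index of the first differing character, or None if equal."""
    for i, (a, b) in enumerate(zip(reference, candidate)):
        if a != b:
            return i
    if len(reference) != len(candidate):
        # One string is a strict prefix of the other.
        return min(len(reference), len(candidate))
    return None


def compare_pipeline_outputs(hf_output: List[Dict[str, str]],
                             ds_output: List[Dict[str, str]],
                             context: int = 20) -> str:
    """Summarize where a DeepSpeed generation departs from the HF one."""
    hf_text = hf_output[0]["generated_text"]
    ds_text = ds_output[0]["generated_text"]
    idx = first_divergence(hf_text, ds_text)
    if idx is None:
        return "outputs match"
    return (f"diverge at char {idx}: "
            f"HF={hf_text[idx:idx + context]!r} "
            f"DS={ds_text[idx:idx + context]!r}")


# Shortened versions of the two outputs reported at the top of this issue.
hf = [{"generated_text": "Try without sampling the data.\n\nA:"}]
ds = [{"generated_text": "Try without sampling the ( \ufffd hub ( ( ("}]
print(compare_pipeline_outputs(hf, ds))
```

Comparing at the character level keeps the helper independent of the tokenizer. A token-level comparison would need the GPT-J tokenizer loaded.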

ds_report output

root@2f0b3a15b3d0:/fsx/huilgolr/inference/rubik# ds_report
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.11.0+cu113
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/deepspeed']
deepspeed info ................... 0.7.1+8b2a6371, 8b2a6371, master
deepspeed wheel compiled w. ...... torch 1.11, cuda 11.3

Screenshots

N/A

System info:

OS: Ubuntu
GPU count and types: A100 GPU
Interconnects (if applicable): N/A
Python version: 3.8.3

Launcher context: inference, single process

rahul003, Aug 17 '22 20:08

Hi @rahul003,

I can reproduce this on my side with your script. However, with my own script below, there is no issue:

import os
import torch
import deepspeed
import transformers

from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock as gpt2_transformer

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))
generator = pipeline('text-generation',
                     model='EleutherAI/gpt-j-6B',
                     device=local_rank)
hf_output = generator("Try without sampling", do_sample=False)
print('HUGGINGFACE:', hf_output)

generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.half,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)
string = generator("Try without sampling", do_sample=False)
print('DEEPSPEED:', string)

Here are the generated texts:

HUGGINGFACE: [{'generated_text': "Try without sampling, and if you don't like it, you can always sample it.\n\n------\njoshu\nI'm not sure I understand the point of this.\n\n~~~\njoshu\nI think I get"}]
DEEPSPEED: [{'generated_text': "Try without sampling, and if you don't like it, you can always sample it.\n\n------\njoshu\nI'm not sure I understand the point of this.\n\n~~~\njoshu\nI think I get"}]

I will try your script again and see what's going on.

Best, Reza

RezaYazdaniAminabadi, Aug 18 '22 17:08

Hi @rahul003, I checked again with your script, and the issue seems to be related to setting the low_cpu_mem_usage flag to True when creating the model. Can you please verify whether the issue is resolved by setting this flag to False? By the way, I have never used this flag before; can you please describe what it does? Thanks, Reza

RezaYazdaniAminabadi, Aug 24 '22 22:08

I have the same problem. I am using a fine-tuned model loaded with GPTJForCausalLM.from_pretrained() without the low_cpu_mem_usage flag, and generation differs with and without DeepSpeed.

I have the latest DeepSpeed version and transformers==4.21.2.

AlexWortega, Aug 27 '22 01:08

Can you try with the above script that I pasted?

RezaYazdaniAminabadi, Aug 27 '22 01:08

Also, can you please show the outputs? Thanks, Reza

RezaYazdaniAminabadi, Aug 27 '22 01:08

I get the same output as HF with

replace_method=None,
replace_with_kernel_inject=False

and a different output with

replace_method='auto',
replace_with_kernel_inject=True

AlexWortega, Aug 27 '22 01:08

The results might differ, but is the output still meaningful when using kernels? Can you please paste the output?

RezaYazdaniAminabadi, Aug 27 '22 01:08

HUGGINGFACE: [{'generated_text': "Try without sampling.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if"}]

With kernels: [{'generated_text': "Try without sampling.\n\nI'm not sure if I'm doing it right.\n\n.\n\n.\n\n.\n\n.\n\n.\n\n as"}]

Without kernels: [{'generated_text': "Try without sampling.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if I'm doing it right.\n\nI'm not sure if"}]

AlexWortega, Aug 27 '22 01:08

Interesting, I am not sure what differs between our system environments that makes us see different results! I am using Torch 1.12 + CUDA 11.6, and I see similar results between HF and DeepSpeed with kernels.

RezaYazdaniAminabadi, Aug 27 '22 01:08

I am using '1.12.0+cu116' too.

AlexWortega, Aug 27 '22 01:08

Can you please paste the whole log? I want to see the transformer configuration.

RezaYazdaniAminabadi, Aug 27 '22 01:08

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
 [WARNING]  On Ampere and higher architectures please use CUDA 11+
transformer_inference .. [NO] ....... [NO]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/alex/anaconda3/lib/python3.9/site-packages/torch']
torch version .................... 1.12.0+cu116
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 10.1
deepspeed install path ........... ['/home/alex/anaconda3/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.7.2, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.6

AlexWortega, Aug 27 '22 02:08
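Two details in the ds_report above are worth checking mechanically: `transformer_inference` is marked `[NO]` in the compatible column (next to the warning "On Ampere and higher architectures please use CUDA 11+"), and `nvcc version` (10.1) differs from `torch cuda version` (11.6). The hypothetical helper below is not part of DeepSpeed. It only pattern-matches the text layout shown in this thread and flags both conditions. Treating a differing major version as a problem is a heuristic assumption, not a rule stated in this thread.

```python
import re
from typing import Dict, List

# Matches op rows such as "transformer_inference .. [NO] ....... [NO]".
OP_ROW = re.compile(r"^(\w+)\s*\.+\s*\[(\w+)\]\s*\.+\s*\[(\w+)\]")
# Matches the two CUDA version rows in the environment section.
VERSION_ROW = re.compile(r"^(torch cuda version|nvcc version)\s*\.+\s*(\S+)")


def audit_ds_report(report: str) -> List[str]:
    """Flag incompatible ops and a CUDA major-version mismatch in ds_report text."""
    problems: List[str] = []
    versions: Dict[str, str] = {}
    for raw in report.splitlines():
        line = raw.strip()
        op = OP_ROW.match(line)
        # Third bracket is the "compatible" column.
        if op and op.group(3) != "OKAY":
            problems.append(f"op {op.group(1)} is not compatible")
        ver = VERSION_ROW.match(line)
        if ver:
            versions[ver.group(1)] = ver.group(2)
    torch_cuda = versions.get("torch cuda version")
    nvcc = versions.get("nvcc version")
    if torch_cuda and nvcc and torch_cuda.split(".")[0] != nvcc.split(".")[0]:
        problems.append(f"nvcc {nvcc} vs torch CUDA {torch_cuda}: major versions differ")
    return problems


# Excerpt of the report above.
report = """
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
 [WARNING]  On Ampere and higher architectures please use CUDA 11+
transformer_inference .. [NO] ....... [NO]
torch cuda version ............... 11.6
nvcc version ..................... 10.1
"""
for problem in audit_ds_report(report):
    print(problem)
```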

Sorry, I meant the output log when you are running the test.

RezaYazdaniAminabadi, Aug 27 '22 02:08

I reproduced the error on Colab (https://colab.research.google.com/drive/1nv-UI30gPx7Hj6laeV2DKwvey3CjCzCM?usp=sharing), but the fix doesn't work there.

Update:

The bug reproduces with the fix as well; just reload the notebook. I put the proof at the end of the notebook.

AlexWortega, Aug 27 '22 02:08

I've tried:

  • Downgrading to 0.5.9: doesn't help
  • Redefining injection_policy as injection_policy={GPTNeoBlock: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')}: doesn't help

What else can I try? I don't know what to do.

AlexWortega, Aug 27 '22 02:08

Hi @AlexWortega,

I retried running this with the same test as above, and I also modified it to be similar to yours. However, I still see similar results between HF and DeepSpeed:

***************** Creating model in RANK (0) with WORLD_SIZE = 1 *****************
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/home/reyazda/.local/lib/python3.8/site-packages/transformers/generation_utils.py:1202: UserWarning: Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to 50 (`self.config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
HUGGINGFACE: [{'generated_text': "Try without sampling, and if you don't like it, you can always sample it.\n\n------\njoshu\nI'm not sure I understand the point of this.\n\n~~~\njoshu\nI think I get"}]
[2022-08-27 12:26:56,262] [INFO] [logging.py:68:log_dist] [Rank -1] DeepSpeed info: version=0.7.1+dce3acaa, git-hash=dce3acaa, git-branch=master
[2022-08-27 12:26:56,262] [INFO] [logging.py:68:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
Using /home/reyazda/.cache/torch_extensions/py38_cu116 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/reyazda/.cache/torch_extensions/py38_cu116/transformer_inference/build.ninja...
Building extension module transformer_inference...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.49053120613098145 seconds
[2022-08-27 12:26:58,417] [INFO] [logging.py:68:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 4096, 'intermediate_size': 16384, 'heads': 16, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': 64, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': False, 'mlp_act_func_type': <ActivationFuncType.GELU: 1>, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False}
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/home/reyazda/.local/lib/python3.8/site-packages/transformers/generation_utils.py:1202: UserWarning: Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to 50 (`self.config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
DEEPSPEED: [{'generated_text': "Try without sampling, and if you don't like it, you can always sample it.\n\n------\njoshu\nI'm not sure I understand the point of this.\n\n~~~\njoshu\nI think I get"}]

I am running on A100-40G. Here is the test snippet:

import os
import torch
import deepspeed
import transformers

from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock as gpt2_transformer

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))
generator = pipeline('text-generation',
                     model='EleutherAI/gpt-j-6B',
                     revision="float16",
                     torch_dtype=torch.float16,
                     device=local_rank)
hf_output = generator("Try without sampling", do_sample=False)
print('HUGGINGFACE:', hf_output)

generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.half,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)
string = generator("Try without sampling", do_sample=False)
print("DEEPSPEED:", string)

Unfortunately, unless I can reproduce the issue you see on your side, I can't be of much help. Also, it looks like you are getting similar results in your own testing, or am I missing something here? Thanks, Reza

RezaYazdaniAminabadi, Aug 27 '22 07:08
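The DeepSpeed-Inference config line in the log above sets `rotate_every_two: True`, `rotate_half: False` and `rotary_dim: 64`. With `hidden_size` 4096 and 16 heads, each head has 256 dimensions, and only the first 64 are rotated. The pure-Python sketch below shows how the two rotary layouts pair dimensions differently. It assumes the standard RoPE frequency schedule with base 10000, as used in Hugging Face's GPT-J code. Given the same input, the two layouts produce different rotated vectors, so a kernel that applied the wrong layout would change the attention scores. This is illustrative only and is not DeepSpeed's kernel code.

```python
import math
from typing import List


def rotate_every_two(x: List[float]) -> List[float]:
    """Interleaved (GPT-J) layout: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)."""
    out: List[float] = []
    for i in range(0, len(x), 2):
        out += [-x[i + 1], x[i]]
    return out


def rotate_half(x: List[float]) -> List[float]:
    """Split layout: [first half, second half] -> [-second half, first half]."""
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]


def apply_rotary(x: List[float], position: int, rotary_dim: int,
                 style: str, base: float = 10000.0) -> List[float]:
    """Rotate the first rotary_dim features of one head vector; pass the rest through."""
    rot, rest = x[:rotary_dim], x[rotary_dim:]
    inv_freq = [base ** (-2 * j / rotary_dim) for j in range(rotary_dim // 2)]
    if style == "every_two":
        # Dims 2j and 2j+1 share frequency j.
        freqs = [f for f in inv_freq for _ in range(2)]
        rotated = rotate_every_two(rot)
    elif style == "half":
        # Dims j and j + rotary_dim/2 share frequency j.
        freqs = inv_freq + inv_freq
        rotated = rotate_half(rot)
    else:
        raise ValueError(f"unknown style: {style}")
    cos = [math.cos(position * f) for f in freqs]
    sin = [math.sin(position * f) for f in freqs]
    out = [a * c + r * s for a, r, c, s in zip(rot, rotated, cos, sin)]
    return out + rest


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(apply_rotary(x, 1, 4, "every_two"))
print(apply_rotary(x, 1, 4, "half"))
```

Both layouts preserve the vector's norm (each is a set of 2-D rotations). They differ only in which dimensions are paired, which is why the config has to match the layout the checkpoint was trained with.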

Hi @RezaYazdaniAminabadi, can you show your ds_report and pip freeze output? Thanks, Alex

AlexWortega, Aug 27 '22 10:08

Hi @AlexWortega ,

Here is the output of the two commands:

 ds_report
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.8/site-packages/torch']
torch version .................... 1.12.1+cu116
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed install path ........... ['/home/reyazda/DeepSpeed/deepspeed']
deepspeed info ................... 0.7.3+9eea4ee4, 9eea4ee4, ds-inference/fix-mp2
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.6
reyazda@webxt7c7400004Q:~$ pip freeze
absl-py==0.12.0
accelerate==0.12.0
adal==1.2.7
-e [email protected]:v3/agicode/TuringModelShare/XLM-E@2ee5665cf8d50051768797120b2e56055f31ffe2#egg=adbpe&subdirectory=src-adbpe
alabaster==0.7.12
antlr4-python3-runtime==4.8
apex==0.1
appdirs==1.4.4
argon2-cffi==20.1.0
ascii-graph==1.5.1
astunparse==1.6.3
async-generator==1.10
attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
audioread==2.1.9
azure-common==1.1.27
azure-core==1.15.0b1
azure-graphrbac==0.61.1
azure-identity==1.4.1
azure-mgmt-authorization==0.61.0
azure-mgmt-containerregistry==8.0.0b1
azure-mgmt-core==1.3.0b2
azure-mgmt-keyvault==9.0.0
azure-mgmt-resource==13.0.0
azure-mgmt-storage==11.2.0
azureml-core==0.1.0.38250401
azureml-dataprep==2.17.0.dev0+0edc67d
azureml-dataprep-native==34.0.0
azureml-dataprep-rslex==1.15.0.dev0+0edc67d
Babel==2.9.1
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
backports.tempfile==1.0
backports.weakref==1.0.post1
beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
bleach==3.3.0
blis @ file:///tmp/build/80754af9/cython-blis_1613319335612/work
boto3==1.11.11
botocore==1.14.17
brotlipy==0.7.0
cachetools==4.2.2
catalogue==1.0.0
certifi==2020.12.5
cffi==1.14.5
cfgv==3.3.0
chardet==4.0.0
clang-format==9.0.0
click==7.1.2
cloudpickle==1.6.0
codecov==2.1.11
colorama==0.4.5
conda==4.10.1
conda-build==3.21.4
conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
contextlib2==0.6.0.post1
coverage==5.5
cryptography==3.4.7
cxxfilt==0.2.2
cycler==0.10.0
cymem @ file:///tmp/build/80754af9/cymem_1613319259039/work
Cython==0.28.4
DataProperty==0.50.1
decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work
-e [email protected]:microsoft/DeepSpeed.git@9eea4ee4abf27e0242a562381f92babbb3718841#egg=deepspeed
defusedxml==0.7.1
distlib==0.3.5
distro==1.5.0
DLLogger @ git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc
docker==4.4.4
docutils==0.15.2
dotnetcore2==2.1.20
einops==0.4.1
entrypoints==0.3
-e [email protected]:v3/agicode/TuringModelShare/XLM-E@2ee5665cf8d50051768797120b2e56055f31ffe2#egg=fairseq&subdirectory=fairseq
filelock==3.7.1
flake8==3.7.9
flash-attn==0.1
Flask==1.1.2
flatbuffers==1.12
future==0.18.2
gast==0.4.0
glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
google-auth==1.30.1
google-auth-oauthlib==0.4.4
google-pasta==0.2.0
graphsurgeon @ file:///workspace/TensorRT-7.2.3.4/graphsurgeon/graphsurgeon-0.4.5-py2.py3-none-any.whl
grpcio==1.34.1
h5py==3.1.0
hjson==3.1.0
html2text==2020.1.16
huggingface-hub==0.9.0
hydra-core==1.0.7
hypothesis==4.50.8
identify==2.2.6
idna==2.10
imageio==2.9.0
imagesize==1.2.0
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work
importlib-resources==5.9.0
-e [email protected]:v3/agicode/TuringModelShare/XLM-E@2ee5665cf8d50051768797120b2e56055f31ffe2#egg=infinibatch&subdirectory=infinibatch
inflect==5.3.0
iniconfig==1.1.1
ipdb==0.13.7
ipykernel==5.5.3
ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
isodate==0.6.0
itsdangerous==1.1.0
jedi==0.17.0
jeepney==0.6.0
Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
jmespath==0.10.0
joblib==1.0.1
json5==0.9.5
jsonpickle==2.0.0
jsonschema @ file:///tmp/build/80754af9/jsonschema_1594303806266/work
jupyter-client==6.1.12
jupyter-core==4.7.1
jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a
jupyterlab==2.3.1
jupyterlab-pygments==0.1.2
jupyterlab-server==1.2.0
jupytext==1.11.1
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
librosa==0.8.0
llvmlite==0.35.0
lmdb==1.2.1
lxml==4.9.1
Mako==1.1.4
Markdown==3.3.4
markdown-it-py==0.6.2
MarkupSafe==1.1.1
maskrcnn-benchmark @ file:///opt/pytorch/examples/maskrcnn/pytorch
matplotlib==3.4.1
mbstrdecoder==1.0.1
mccabe==0.6.1
mdit-py-plugins==0.2.6
mistune==0.8.4
mlperf-compliance==0.0.10
mock @ file:///tmp/build/80754af9/mock_1607622725907/work
msal==1.12.0
msal-extensions==0.2.2
msgfy==0.1.0
msrest==0.6.21
msrestazure==0.6.4
murmurhash @ file:///tmp/build/80754af9/murmurhash_1607456108764/work
nbclient==0.5.3
nbconvert==6.0.7
nbformat==5.1.3
ndg-httpsclient==0.5.1
nest-asyncio==1.5.1
networkx==2.0
ninja==1.10.0.post2
nltk==3.6.2
nodeenv==1.6.0
notebook==6.2.0
numba @ file:///tmp/build/80754af9/numba_1614888130619/work
numpy==1.23.2
nvidia-dali-cuda110==1.0.0
nvidia-dlprof-pytorch-nvtx @ file:///nvidia/opt/dlprof/bin/nvidia_dlprof_pytorch_nvtx-1.1.0-py3-none-any.whl
nvidia-pyprof @ git+https://github.com/NVIDIA/PyProf@5aae754018a7b607ed047de0f3d77dd71640e919
nvidia-tensorboard @ file:///nvidia/opt/tensorboard_install/nvidia_tensorboard-1.15.0%2Bnv21.04-py3-none-any.whl
nvidia-tensorboard-plugin-dlprof @ file:///nvidia/opt/tensorboard_install/nvidia_tensorboard_plugin_dlprof-1.2.0-py3-none-any.whl
oauthlib==3.1.0
omegaconf==2.0.6
onnx @ file:///opt/pytorch/pytorch/third_party/onnx
onnxruntime==1.7.0
opt-einsum==3.3.0
packaging==20.9
pandas==1.1.4
pandocfilters==1.4.3
parso @ file:///tmp/build/80754af9/parso_1617223946239/work
pathspec==0.8.1
pathvalidate==2.4.1
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow==8.2.0
Pillow-SIMD @ file:///tmp/pillow-simd
pkginfo==1.7.0
plac @ file:///tmp/build/80754af9/plac_1594259967336/work
platformdirs==2.5.2
pluggy==0.13.1
polygraphy==0.29.1
pooch==1.3.0
portalocker==1.7.1
pre-commit==2.13.0
preshed==3.0.2
prettytable==2.1.0
progressbar==2.5
prometheus-client==0.10.1
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
protobuf==3.15.8
pssh==2.3.1
psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
py==1.10.0
py-cpuinfo==8.0.0
py-spy==0.3.12
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycocotools @ git+https://github.com/nvidia/cocoapi.git@9a47a76980d02f70a371e12d4fad61f644a209f1#subdirectory=PythonAPI
pycodestyle==2.5.0
pycosat==0.6.3
pycparser==2.20
pycuda==2020.1
pydantic==1.9.2
pydot==1.4.2
pyflakes==2.1.1
Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
PyJWT==2.1.0
pynvml==8.0.4
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1605545627475/work
pyparsing==2.4.7
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytablewriter==0.47.0
pytest==6.2.3
pytest-cov==2.11.1
pytest-forked==1.3.0
pytest-pythonpath==0.7.3
python-dateutil==2.8.1
python-hostlist==1.21
python-nvd3==0.15.0
python-slugify==4.0.1
pytools==2021.2.6
pytorch-quantization==2.1.0
pytorch-transformers==1.1.0
pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
PyWavelets==1.1.1
PyYAML==5.4.1
pyzmq==22.0.3
regex==2020.1.8
requests==2.25.1
requests-oauthlib==1.3.0
resampy==0.2.2
revtok @ git+git://github.com/jekbradbury/revtok.git@f1998b72a941d1e5f9578a66dc1c20b01913caab
rsa==4.7.2
ruamel-yaml==0.15.87
ruamel.yaml.clib==0.2.2
s3transfer==0.3.7
sacrebleu==2.2.0
sacremoses==0.0.35
scikit-image==0.15.0
scikit-learn==0.24.2
scipy @ file:///tmp/build/80754af9/scipy_1618855957096/work
SecretStorage==3.3.1
Send2Trash==1.5.0
sentencepiece==0.1.91
six==1.15.0
snowballstemmer==2.1.0
SoundFile==0.10.3.post1
soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
sox==1.4.1
spacy @ file:///tmp/build/80754af9/spacy_1608321098157/work
Sphinx==3.5.4
sphinx-glpi-theme==0.3
sphinx-rtd-theme==0.5.2
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==1.0.3
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4
srsly @ file:///tmp/build/80754af9/srsly_1607548537638/work
subword-nmt @ git+git://github.com/rsennrich/subword-nmt.git@48ba99e657591c329e0003f0c6e32e493fa959ef
tabledata==1.1.3
tabulate==0.8.9
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorboardX==2.1
tensorflow==2.5.0
tensorflow-estimator==2.5.0
tensorrt @ file:///workspace/TensorRT-7.2.3.4/python/tensorrt-7.2.3.4-cp38-none-linux_x86_64.whl
termcolor==1.1.0
terminado==0.9.4
testpath==0.4.4
text-unidecode==1.3
thinc @ file:///tmp/build/80754af9/thinc_1607710152385/work
threadpoolctl==2.1.0
tokenizers==0.12.1
toml==0.10.2
torch==1.12.1+cu116
torchaudio==0.12.1+cu116
torchtext @ file:///opt/pytorch/text
torchvision @ file:///opt/pytorch/vision
tornado==6.1
tqdm==4.53.0
traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
transformers==4.21.2
typepy==1.1.5
typing-extensions==3.7.4.3
uff @ file:///workspace/TensorRT-7.2.3.4/uff/uff-0.6.9-py2.py3-none-any.whl
Unidecode==1.2.0
urllib3==1.25.11
virtualenv==20.16.3
wasabi @ file:///tmp/build/80754af9/wasabi_1612219178408/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
webencodings==0.5.1
websocket-client==1.0.1
Werkzeug==1.0.1
wrapt==1.12.1
yacs==0.1.8
zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work

Thanks, Reza

RezaYazdaniAminabadi, Aug 28 '22 01:08

Thank you, Reza. I'm going to try to reproduce the environment in Docker and will write back later.

AlexWortega, Aug 28 '22 13:08

low_cpu_mem_usage is an HF feature that creates the model on the meta device (so the model itself uses no CPU memory) and then loads the state_dict, so that peak memory usage is not 2x the model size (model + state_dict). It's odd that the HF output is fine but the DS output gets corrupted when this feature is used in a single-process setting.

rahul003, Aug 30 '22 18:08
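Here are rough numbers for the 2x figure described above. The sketch assumes about 6 billion parameters for GPT-J-6B (an approximation) stored in fp16 at 2 bytes each. The helper below computes an idealized peak CPU memory while loading. It ignores activations, allocator overhead, and tokenizer or pipeline memory, so real measurements will differ.

```python
def peak_load_bytes(num_params: float, bytes_per_param: int,
                    low_cpu_mem_usage: bool) -> float:
    """Idealized peak CPU memory while loading a checkpoint.

    Without low_cpu_mem_usage, the instantiated model and the loaded
    state_dict coexist (about 2x the weights). With it, the model is
    created on the meta device, so roughly one copy of the weights is resident.
    """
    weights = num_params * bytes_per_param
    return weights if low_cpu_mem_usage else 2 * weights


GB = 1e9
n = 6e9  # approximate GPT-J-6B parameter count (assumption)
print(peak_load_bytes(n, 2, low_cpu_mem_usage=False) / GB)  # fp16 -> 24.0
print(peak_load_bytes(n, 2, low_cpu_mem_usage=True) / GB)   # fp16 -> 12.0
```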

I am using a CUDA 11.3 stack, though, unlike your CUDA 11.6 one.

rahul003, Aug 30 '22 18:08

I just tried to reproduce the issue in the DeepSpeed Docker image with the snippet @RezaYazdaniAminabadi posted. The single-process run passed, but the multi-process run didn't.

This is the script I used https://github.com/microsoft/DeepSpeed/issues/2230#issuecomment-1219738438

DS output: [{'generated_text': 'Try without sampling,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'}]

Steps to reproduce:

  • nvidia-docker run -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --shm-size=2048m deepspeed/deepspeed:latest_torch111 bash
  • pip install deepspeed==0.7.2 transformers==4.21.2
  • deepspeed --num_gpus 4 gptj.py

rahul003, Aug 30 '22 19:08
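With `deepspeed --num_gpus 4`, the launcher sets WORLD_SIZE to 4. The linked script then passes `mp_size=world_size` to `init_inference`, so each rank holds only a slice of every layer. The sketch below assumes an even split of attention heads and MLP columns across ranks, which is the usual tensor-parallel layout; DeepSpeed's exact slicing code is not shown in this thread. Using the GPT-J config logged earlier in this thread, it works out the per-rank shapes. Slicing of this kind is only exercised when mp_size > 1. The function name and returned fields are hypothetical.

```python
from typing import Dict


def shard_gptj_layer(hidden_size: int, heads: int, intermediate_size: int,
                     mp_size: int, rank: int) -> Dict[str, object]:
    """Per-rank slice of one transformer layer under an assumed even split."""
    if heads % mp_size or intermediate_size % mp_size:
        raise ValueError("heads and intermediate_size must be divisible by mp_size")
    if not 0 <= rank < mp_size:
        raise ValueError("rank out of range")
    head_dim = hidden_size // heads
    heads_per_rank = heads // mp_size
    ffn_per_rank = intermediate_size // mp_size
    return {
        "head_dim": head_dim,
        # Attention heads owned by this rank.
        "heads": range(rank * heads_per_rank, (rank + 1) * heads_per_rank),
        # Output width of each of this rank's Q, K and V projections.
        "qkv_width": heads_per_rank * head_dim,
        # Columns of the MLP up-projection owned by this rank.
        "mlp_columns": range(rank * ffn_per_rank, (rank + 1) * ffn_per_rank),
    }


# hidden_size, heads and intermediate_size from the logged DeepSpeed-Inference config.
for r in range(4):
    s = shard_gptj_layer(4096, 16, 16384, mp_size=4, rank=r)
    print(r, list(s["heads"]), s["qkv_width"], s["mlp_columns"])
```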

I also downgraded everything to the same versions as @RezaYazdaniAminabadi and set replace_with_kernel_inject=False.

AlexWortega, Aug 30 '22 19:08

Created a separate issue for the low_cpu_mem_usage flag issue. https://github.com/microsoft/DeepSpeed/issues/2275

Let's use this issue to track the multi-GPU correctness problem for GPT-J seen above.

rahul003, Aug 30 '22 19:08

@rahul003, @AlexWortega, the issue has been root-caused, and you can find the fix here: https://github.com/microsoft/DeepSpeed/pull/2489. Could you please try it on your side and confirm whether the issue is fixed?

lokoppakmsft, Nov 08 '22 19:11