OLMo
Output hidden states seem to return None on forward pass of OLMo model
🐛 Describe the bug
OLMo/hf_olmo/modeling_olmo.py: output_hidden_states=True does not seem to be used anywhere, so the outputs do not contain the hidden states of the model.
Versions
Python 3.11.5
ai2-olmo==0.2.4 aiobotocore==2.11.2 aiohttp==3.9.3 aioitertools==0.11.0 aiosignal==1.3.1 alembic==1.13.1 antlr4-python3-runtime==4.9.3 anyascii==0.3.2 anyio==4.2.0 archspec @ file:///croot/archspec_1697725767277/work argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work async-generator==1.10 async-lru==2.0.4 attrs==23.2.0 Babel==2.14.0 beautifulsoup4==4.12.3 bleach==6.1.0 blingfire==0.1.8 boltons @ file:///work/ci_py311/boltons_1677685195580/work boto3==1.34.34 botocore==1.34.34 Brotli @ file:///work/ci_py311/brotli-split_1676830125088/work cached-path==1.5.1 cachetools==5.3.2 certifi @ file:///croot/certifi_1700501669400/work/certifi certipy==0.1.3 cffi @ file:///croot/cffi_1700254295673/work charset-normalizer==3.3.2 click==8.1.7 comm==0.2.1 conda @ file:///croot/conda_1701719518285/work conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1702997573971/work/src conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work contourpy==1.2.0 cryptography @ file:///croot/cryptography_1702070282333/work cycler==0.12.1 datasets==2.16.1 debugpy @ file:///croot/debugpy_1690905042057/work decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work defusedxml==0.7.1 dill==0.3.7 distro @ file:///croot/distro_1701455004953/work dolma==1.0.1 executing @ file:///opt/conda/conda-bld/executing_1646925071911/work fastjsonschema==2.19.1 fasttext-wheel==0.9.2 filelock==3.12.4 fonttools==4.47.2 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.2.0 google-api-core==2.16.2 google-auth==2.27.0 google-cloud-core==2.4.1 google-cloud-storage==2.14.0 google-crc32c==1.5.0 google-resumable-media==2.7.0 googleapis-common-protos==1.62.0 greenlet==3.0.3 huggingface-hub==0.20.3 idna @ file:///work/ci_py311/idna_1676822698822/work 
imageio==2.33.1 inquirerpy==0.3.4 ipykernel @ file:///croot/ipykernel_1691121631942/work ipython @ file:///croot/ipython_1704833016303/work ipywidgets==8.1.1 isoduration==20.11.0 jedi @ file:///work/ci_py311_2/jedi_1679336495545/work Jinja2==3.1.3 jmespath==1.0.1 joblib==1.3.2 json5==0.9.14 jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work jsonpointer==2.1 jsonschema==4.21.0 jsonschema-specifications==2023.12.1 jupyter==1.0.0 jupyter-console==6.6.3 jupyter-events==0.9.0 jupyter-lsp==2.2.2 jupyter-telemetry==0.1.0 jupyter_client @ file:///croot/jupyter_client_1699455897726/work jupyter_core @ file:///croot/jupyter_core_1698937308754/work jupyter_server==2.12.5 jupyter_server_terminals==0.5.1 jupyterhub==4.0.2 jupyterlab==4.0.11 jupyterlab-widgets==3.0.9 jupyterlab_pygments==0.3.0 jupyterlab_server==2.25.2 kiwisolver==1.4.5 langdetect==1.0.9 lazy_loader==0.3 libmambapy @ file:///croot/mamba-split_1698782620632/work/libmambapy LTpycld2==0.42 Mako==1.3.0 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.8.2 matplotlib-inline @ file:///work/ci_py311/matplotlib-inline_1676823841154/work mdurl==0.1.2 menuinst @ file:///croot/menuinst_1702390294373/work mistune==3.0.2 mpmath==1.3.0 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.15 nbclient==0.9.0 nbconvert==7.14.2 nbformat==5.9.2 necessary==0.4.3 nest-asyncio @ file:///work/ci_py311/nest-asyncio_1676823382924/work networkx==3.2.1 nltk==3.8.1 notebook==7.0.7 notebook_shim==0.2.3 numpy==1.26.3 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 omegaconf==2.3.0 overrides==7.4.0 packaging @ file:///croot/packaging_1693575174725/work pamela==1.1.0 pandas==2.1.4 
pandocfilters==1.5.1 parso @ file:///opt/conda/conda-bld/parso_1641458642106/work pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work pfzy==0.3.4 pillow==10.2.0 platformdirs @ file:///croot/platformdirs_1692205439124/work pluggy @ file:///work/ci_py311/pluggy_1676822818071/work prometheus-client==0.19.0 prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work protobuf==4.25.2 psutil @ file:///work/ci_py311_2/psutil_1679337388738/work ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work pyarrow==15.0.0 pyarrow-hotfix==0.6 pyasn1==0.5.1 pyasn1-modules==0.3.0 pybind11==2.11.1 pycosat @ file:///croot/pycosat_1696536503704/work pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work Pygments @ file:///croot/pygments_1684279966437/work pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work pyparsing==3.1.1 PySocks @ file:///work/ci_py311/pysocks_1676822712504/work python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work python-json-logger==2.0.7 pytz==2023.3.post1 PyYAML==6.0.1 pyzmq @ file:///croot/pyzmq_1705605076900/work qtconsole==5.5.1 QtPy==2.4.1 referencing==0.32.1 regex==2023.12.25 requests @ file:///croot/requests_1690400202158/work requirements-parser==0.5.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.0 rpds-py==0.17.1 rsa==4.9 ruamel.yaml @ file:///work/ci_py311/ruamel.yaml_1676838772170/work s3fs==2024.2.0 s3transfer==0.10.0 safetensors==0.4.2 scikit-image==0.22.0 scikit-learn==1.4.0 scipy==1.11.4 Send2Trash==1.8.2 six @ file:///tmp/build/80754af9/six_1644875935023/work smart-open==6.4.0 sniffio==1.3.0 soupsieve==2.5 SQLAlchemy==2.0.25 stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work sympy==1.12 terminado==0.18.0 threadpoolctl==3.2.0 tifffile==2023.12.9 tinycss2==1.2.1 tokenizers==0.15.1 torch==2.1.2 torchaudio==2.1.2 torchvision==0.16.2 
tornado @ file:///croot/tornado_1696936946304/work tqdm @ file:///croot/tqdm_1679561862951/work traitlets @ file:///work/ci_py311/traitlets_1676823305040/work transformers==4.37.2 triton==2.1.0 truststore @ file:///croot/truststore_1695244293384/work types-python-dateutil==2.8.19.20240106 types-setuptools==69.0.0.20240125 typing_extensions==4.9.0 tzdata==2023.4 uniseg==0.7.2 uri-template==1.3.0 urllib3 @ file:///croot/urllib3_1698257533958/work wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 widgetsnbextension==4.0.9 wrapt==1.16.0 xxhash==3.4.1 yarl==1.9.4 zstandard @ file:///work/ci_py311_2/zstandard_1679339489613/work
Hey @idobenshaul10, this is not currently implemented. I just took a look and it's a bit of a heavy lift due to the block arrangement in OLMo (which is much more complicated than something like Llama below), but if you're interested in trying it, we would love a contribution.
Do you want us to raise a warning that it's not implemented? I added a PR to do that.
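For readers following along, here is a minimal sketch of what such a warning could look like. It is not the code from the PR mentioned above; the helper name `check_unsupported_flags` and the message text are hypothetical, and only the standard-library `warnings` module is used.

```python
import warnings


def check_unsupported_flags(output_hidden_states=None):
    """Hypothetical helper: warn when a caller requests hidden states
    from a forward pass that cannot return them.

    Returns True when nothing unsupported was requested, False otherwise.
    """
    if output_hidden_states:
        warnings.warn(
            "output_hidden_states=True is not implemented for this model; "
            "the output will not contain hidden states.",
            UserWarning,
            stacklevel=2,  # point the warning at the caller, not this helper
        )
        return False
    return True
```

A forward method could call a helper like this first, so the flag is no longer silently ignored.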
Hi @idobenshaul10, this should be fixed by #451. Hasn't been merged in yet but should be shortly.
When passing the output_hidden_states=True argument, the output dictionary does not seem to contain the hidden states.
The model initialization:
model = AutoModelForCausalLM.from_pretrained('allenai/OLMo-1B', trust_remote_code=True)
Putting together the tokenized input:
inputs = {'input_ids': inputs['input_ids'].cuda(), 'attention_mask': inputs['attention_mask'].cuda(), 'output_hidden_states': True}
The forward pass:
out = model.forward(**inputs)
The returned output dictionary:
print(out.keys())
==> odict_keys(['logits', 'past_key_values'])
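The check above can be made a little more explicit with a small helper. This is only a sketch: `describe_hidden_states` is a hypothetical name, and it assumes the output behaves like a mapping with `.keys()` and `.get()`, as the `odict_keys` result suggests.

```python
def describe_hidden_states(out):
    """Summarize whether a model output carries hidden states.

    `out` is assumed to be mapping-like (supports .get() and .keys()),
    for example the object returned by the forward pass above.
    """
    hidden = out.get("hidden_states")
    if hidden is None:
        # Either the key is absent or it was explicitly set to None.
        return f"no hidden states; keys present: {list(out.keys())}"
    return f"{len(hidden)} hidden-state tensors returned"
```

Calling `print(describe_hidden_states(out))` on the output above should report the missing key rather than a tensor count.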
Any thoughts on this, @sarahwie?
Hmm, that's bizarre because I tested the code and am using it. Did you upgrade your pip/conda install to the latest package version, which has the fix (ai2-olmo==0.2.5)?
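One way to confirm which version is installed from Python is sketched below. The helper names are hypothetical, the parser assumes plain `X.Y.Z` release strings (no pre-release suffixes), and the 0.2.5 threshold is taken from the comment above.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Turn a plain 'X.Y.Z' release string into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def has_fix(installed, minimum="0.2.5"):
    """Return True if `installed` is at least the version said to contain the fix."""
    return parse_version(installed) >= parse_version(minimum)


def installed_olmo_version(package="ai2-olmo"):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None
```

For example, `v = installed_olmo_version()` followed by `has_fix(v)` (when `v` is not None) tells you whether an upgrade is needed.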
Yep, I was on version 0.2.4. Thanks. This issue can probably be closed lol.