zett
Error when training a hypernetwork
I tried to train a hypernetwork on an English and Chinese dataset in order to transfer a bilingual tokenizer to TinyLlama.
My hardware is 2x A100 80GB GPUs, with CUDA driver version 12.2.
My config is:
{
  "output_dir": "output-debug",
  "train_directory": "data/train",
  "valid_directory": "data/valid",
  "langs": "data/langs.txt",
  "model_name_or_path": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
  "revision": "refs/pr/8",
  "loss": "clm",
  "n_embd": 2048,
  "n_token_subsample": null,
  "random_warmup_steps": 0,
  "identity_n_subsample": 16384,
  "identity_steps": 0,
  "warmup_steps": [
    10000
  ],
  "steps": 200000,
  "dtype": "bfloat16",
  "use_unigram_bias": true,
  "learning_rate": [
    6e-5
  ],
  "max_grad_norm": 0.1,
  "extra_valid_tokenizer_names": [
    "models/TinyLlama-1.1B-intermediate-step-1431k-3T-Ext"
  ],
  "extra_valid_files": [
    "data/valid/en.parquet",
    "data/valid/zh.parquet"
  ],
  "extra_lang_codes": [
    "en",
    "zh"
  ],
  "n_valid_subsample": 4000,
  "do_tokenizer_sampling": true,
  "hn_rescale_embeddings": true,
  "hn_surface_maxlen": 15,
  "tokenizer_sample_mean": 32768,
  "tokenizer_sample_max": 32768,
  "tokenizer_sample_std": 0,
  "tokenizer_batch_size": 32,
  "weight_decay": 0.01,
  "adam_beta2": 0.95,
  "hn_model_name_or_path": "roberta-base",
  "tokenizer_noise_mean": 1e-5,
  "tokenizer_noise_std": 4,
  "hn_embed_lang_id": true,
  "hn_add_inter_token_attention": false,
  "hn_embed_target_priors": false,
  "hn_inter_token_attention_bias_by_priors": true,
  "hn_embed_using_source_embeddings": true,
  "train_batch_size": 2,
  "eval_batch_size": 2,
  "hn_hidden_size": 2048,
  "hn_intermediate_size": 4096,
  "gradient_accumulation_steps": 1,
  "learnable_bias": false,
  "add_target_priors_to_bias": false,
  "lexical_loss_weight": 0.5,
  "debug": false,
  "dataloader_num_workers": 64,
  "mix_languages": false,
  "logging_steps": 10
}
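Before digging into the traceback, it can help to rule out simple inconsistencies in a config this long. The following is a minimal sketch that is not part of zett: check_config and the key pairings in PAIRED_KEYS are hypothetical, and they assume (unverified) that each entry in extra_valid_files goes with one entry in extra_lang_codes, and each learning_rate entry with one warmup_steps entry.

```python
import json

# Hypothetical pairings: assumes these list-valued keys are consumed
# element-wise together. This is NOT zett's own validation logic.
PAIRED_KEYS = [
    ("extra_valid_files", "extra_lang_codes"),
    ("learning_rate", "warmup_steps"),
]


def check_config(cfg: dict) -> list:
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    for a, b in PAIRED_KEYS:
        len_a, len_b = len(cfg.get(a, [])), len(cfg.get(b, []))
        if len_a != len_b:
            problems.append(f"{a} has {len_a} entries but {b} has {len_b}")
    # A sampled vocabulary size above the configured maximum would be suspicious.
    if cfg.get("tokenizer_sample_mean", 0) > cfg.get("tokenizer_sample_max", 0):
        problems.append("tokenizer_sample_mean exceeds tokenizer_sample_max")
    return problems


# Subset of the config above that the checks look at.
config_text = """
{
  "steps": 200000,
  "logging_steps": 10,
  "learning_rate": [6e-5],
  "warmup_steps": [10000],
  "extra_valid_files": ["data/valid/en.parquet", "data/valid/zh.parquet"],
  "extra_lang_codes": ["en", "zh"],
  "tokenizer_sample_mean": 32768,
  "tokenizer_sample_max": 32768
}
"""
cfg = json.loads(config_text)
print(check_config(cfg))
# With logging_steps = 10, the logging code path runs every 10 steps.
print(cfg["steps"] // cfg["logging_steps"])
```

With logging_steps set to 10, any bug in the logging path surfaces after only a handful of steps, which matches a crash that appears right after the training loop starts.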
The contents of data/langs.txt are:
en,1
zh,3
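The langs file lists one language per line as "code,number". A minimal parsing sketch follows; it assumes (not confirmed by the issue) that the second column is a relative sampling weight, and parse_langs is a hypothetical helper rather than zett's own loader.

```python
def parse_langs(text: str) -> dict:
    """Parse 'code,weight' lines and normalise the weights to probabilities.

    Assumption: the second column is a relative sampling weight.
    """
    weights = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        code, weight = line.split(",")
        weights[code.strip()] = float(weight)
    total = sum(weights.values())
    return {code: w / total for code, w in weights.items()}


probs = parse_langs("en,1\nzh,3\n")
print(probs)  # {'en': 0.25, 'zh': 0.75}
```

Under that assumption, Chinese would be sampled three times as often as English.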
The main training loop runs without problems, but training crashes as soon as it reaches a logging step (logging_steps = 10) with the following error:
Traceback (most recent call last):
  File "/home/jnguan/code/zett/train.py", line 1605, in <module>
    main()
  File "/home/jnguan/code/zett/train.py", line 1516, in main
    lambda x: x.flatten(), stack_forest(train_metrics)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/flax/training/common_utils.py", line 68, in <lambda>
    stack_args = lambda *args: np.stack(args)
                               ^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in stack
    arrays = [asanyarray(arr) for arr in arrays]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/numpy/core/shape_base.py", line 443, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
              ^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 390, in __array__
    return np.asarray(self._value, dtype=dtype)
                      ^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 588, in _value
    if self.is_fully_replicated:
       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jnguan/.miniconda/envs/zett/lib/python3.11/site-packages/jax/_src/array.py", line 354, in is_fully_replicated
    return self.sharding.is_fully_replicated
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'UnspecifiedValue' object has no attribute 'is_fully_replicated'
Full log: zett-142044.log
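For context on where this fails: at the logging step, train.py appears to collect per-step metrics in train_metrics and pass them to flax's stack_forest, which applies np.stack leaf by leaf. np.stack converts each leaf with np.asarray, and for a jax.Array that conversion reads the array's sharding; in this run the sharding is an UnspecifiedValue placeholder, hence the AttributeError. The sketch below mirrors only the stacking-and-averaging pattern with plain NumPy values in flat dicts, so it does not reproduce the bug. stack_metrics, mean_metrics and the metric names are hypothetical, and the averaging step is an assumption about what the logging code does with the stacked values.

```python
import numpy as np


def stack_metrics(metrics: list) -> dict:
    """Stack a list of per-step metric dicts into one array per key.

    For flat dicts this matches what flax's stack_forest does leaf-wise
    via np.stack. The np.asarray call is the conversion that, for a
    jax.Array, reads its sharding; that is the step failing in the traceback.
    """
    keys = metrics[0].keys()
    return {k: np.stack([np.asarray(m[k]) for m in metrics]) for k in keys}


def mean_metrics(metrics: list) -> dict:
    """Average each stacked metric into a single float for logging."""
    stacked = stack_metrics(metrics)
    return {k: float(v.flatten().mean()) for k, v in stacked.items()}


# Hypothetical metrics from two training steps.
train_metrics = [{"loss": 2.0, "lr": 6e-5}, {"loss": 4.0, "lr": 6e-5}]
print(mean_metrics(train_metrics))
```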
My environment:
Package                  Version
------------------------ -----------
absl-py                  2.1.0
accelerate               0.30.1
aiohttp                  3.9.5
aiosignal                1.3.1
appdirs                  1.4.4
attrs                    23.2.0
certifi                  2024.2.2
charset-normalizer       3.3.2
chex                     0.1.86
click                    8.1.7
cmake                    3.29.3
contourpy                1.2.1
cycler                   0.12.1
datasets                 2.19.1
dill                     0.3.8
docker-pycreds           0.4.0
etils                    1.8.0
filelock                 3.14.0
flax                     0.8.0
fonttools                4.52.4
frozenlist               1.4.1
fsspec                   2024.5.0
gitdb                    4.0.11
GitPython                3.1.43
h5py                     3.8.0
huggingface-hub          0.23.2
idna                     3.7
importlib_resources      6.4.0
jax                      0.4.23
jax-cuda12-pjrt          0.4.23
jax-cuda12-plugin        0.4.23
jaxlib                   0.4.23
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lit                      18.1.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.2
maturin                  1.3.0
mdurl                    0.1.2
ml-dtypes                0.4.0
mpmath                   1.3.0
msgpack                  1.0.8
multidict                6.0.5
multiprocess             0.70.16
nest-asyncio             1.6.0
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvcc-cu12    12.5.40
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu11        10.9.0.58
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu11       10.2.10.91
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu11     11.7.4.91
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu11         2.14.3
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.5.40
nvidia-nvtx-cu11         11.7.91
nvidia-nvtx-cu12         12.1.105
opt-einsum               3.3.0
optax                    0.1.5
orbax-checkpoint         0.5.14
packaging                24.0
pandas                   2.0.3
pathtools                0.1.2
pillow                   10.3.0
pip                      24.0
protobuf                 4.25.3
psutil                   5.9.8
pyahocorasick            2.0.0
pyarrow                  16.1.0
pyarrow-hotfix           0.6
Pygments                 2.18.0
pyparsing                3.0.9
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.32.3
rich                     13.7.1
rust_utils               0.14.1.dev0
safetensors              0.4.3
scikit-learn             1.4.2
scipy                    1.10.1
sentry-sdk               2.3.1
setproctitle             1.3.3
setuptools               69.5.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12.1
tensorstore              0.1.60
threadpoolctl            3.5.0
tokenizers               0.19.1
toolz                    0.12.1
torch                    2.3.0
tqdm                     4.66.4
transformers             4.41.1
triton                   2.3.0
typing_extensions        4.12.0
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.15.4
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4
zipp                     3.19.0
Hi @jubgjf, can you try the branch mentioned in this issue: https://github.com/bminixhofer/zett/issues/8