
TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'

Open · wangjianqiao111 opened this issue 1 year ago · 4 comments

System Info

absl-py                       1.4.0
accelerate                    0.24.1
addict                        2.4.0
aim                           3.17.5
aim-ui                        3.17.5
aimrocks                      0.4.0
aiofiles                      23.2.1
aiohttp                       3.8.5
aiosignal                     1.3.1
antlr4-python3-runtime        4.9.3
anyio                         3.7.1
appdirs                       1.4.4
asttokens                     2.0.5
async-timeout                 4.0.3
attrs                         23.1.0
backcall                      0.2.0
backoff                       2.2.1
base58                        2.0.1
bitsandbytes                  0.41.1
brotlipy                      0.7.0
cachetools                    5.3.1
certifi                       2023.7.22
cffi                          1.15.1
chardet                       5.2.0
charset-normalizer            2.0.4
click                         8.1.7
colorama                      0.4.6
comm                          0.1.2
contexttimer                  0.3.3
cpm-kernels                   1.0.11
cryptography                  41.0.2
DataProperty                  1.0.1
datasets                      2.14.4
debugpy                       1.6.7
decorator                     5.1.1
deepspeed                     0.10.2
dill                          0.3.7
distlib                       0.3.7
docker-pycreds                0.4.0
docstring-parser              0.15
einops                        0.6.1
evaluate                      0.4.0
executing                     0.8.3
fastapi                       0.103.1
fastjsonschema                2.18.1
filelock                      3.9.0
flash-attn                    2.5.6
frozenlist                    1.4.0
fsspec                        2023.9.0
fuzzywuzzy                    0.18.0
gitdb                         4.0.10
GitPython                     3.1.34
gmpy2                         2.1.2
google-auth                   2.22.0
google-auth-oauthlib          1.0.0
greenlet                      3.0.3
grpcio                        1.57.0
h11                           0.14.0
hjson                         3.1.0
huggingface-hub               0.21.4
idna                          3.4
ipykernel                     6.25.0
ipython                       8.12.2
jedi                          0.18.1
jieba                         0.42.1
Jinja2                        3.1.2
joblib                        1.3.2
jsonlines                     4.0.0
jupyter_client                8.1.0
jupyter_core                  5.3.0
lm-eval                       0.3.0       /slurmhome/aps/xuzhihui/lm-evaluation-harness
Mako                          1.3.0
Markdown                      3.4.4
markdown-it-py                3.0.0
MarkupSafe                    2.1.1
matplotlib-inline             0.1.6
mbstrdecoder                  1.1.3
mdurl                         0.1.2
mkl-fft                       1.3.6
mkl-random                    1.2.2
mkl-service                   2.4.0
monotonic                     1.6
mpmath                        1.3.0
multidict                     6.0.4
multiprocess                  0.70.15
nest-asyncio                  1.5.6
networkx                      3.1
ninja                         1.11.1
numexpr                       2.8.5
numpy                         1.25.2
nvidia-ml-py3                 7.352.0
oauthlib                      3.2.2
omegaconf                     2.3.0
onnx                          1.15.0
opencompass                   0.1.0
packaging                     23.1
pandas                        2.1.0
parso                         0.8.3
pathtools                     0.1.2
pathvalidate                  3.1.0
peewee                        3.16.3
peft                          0.5.0
pexpect                       4.8.0
pickleshare                   0.7.5
Pillow                        9.4.0
pip                           23.2.1
platformdirs                  3.10.0
portalocker                   2.7.0
prompt-toolkit                3.0.36
protobuf                      4.24.2
psutil                        5.9.5
ptyprocess                    0.7.0
pure-eval                     0.2.2
py-cpuinfo                    9.0.0
py3nvml                       0.2.7
pyarrow                       13.0.0
pyasn1                        0.5.0
pyasn1-modules                0.3.0
pybind11                      2.11.1
pycountry                     22.3.5
pycparser                     2.21
pydantic                      1.10.12
Pygments                      2.15.1
pyOpenSSL                     23.2.0
PySocks                       1.7.1
pytablewriter                 1.0.0
python-dateutil               2.8.2
pytz                          2023.3
PyYAML                        6.0.1
pyzmq                         25.1.0
regex                         2023.8.8
requests                      2.31.0
requests-oauthlib             1.3.1
responses                     0.18.0
RestrictedPython              7.0
rich                          13.7.1
rouge-score                   0.0.4
rsa                           4.9
sacrebleu                     1.5.0
safetensors                   0.4.2
sanic-ext                     23.6.0
sanic-routing                 23.6.0
scikit-learn                  1.3.0
scipy                         1.11.2
segment-analytics-python      2.2.3
sentencepiece                 0.1.99
sentry-sdk                    1.30.0
setproctitle                  1.3.2
setuptools                    68.0.0
shtab                         1.7.1
six                           1.16.0
smmap                         5.0.0
sniffio                       1.3.0
SQLAlchemy                    1.4.51
sqlitedict                    2.1.0
stack-data                    0.2.0
starlette                     0.27.0
sympy                         1.11.1
tabledata                     1.3.1
tcolorpy                      0.1.3
tensorboard                   2.14.0
tensorboard-data-server       0.7.1
threadpoolctl                 3.2.0
tokenizers                    0.15.2
torch                         2.0.1
torchaudio                    2.0.2
torchvision                   0.15.2
tornado                       6.3.2
tqdm                          4.66.2
tqdm-multiprocess             0.0.11
traitlets                     5.7.1
transformers                  4.38.2
transformers-stream-generator 0.0.4
triton                        2.0.0
trl                           0.7.12.dev0
typepy                        1.3.1
typing_extensions             4.7.1
tyro                          0.7.3
tzdata                        2023.3
ujson                         5.8.0
urllib3                       1.26.16
uvicorn                       0.23.2
uvloop                        0.17.0
wandb                         0.15.9
wcwidth                       0.2.5
Werkzeug                      2.3.7
wheel                         0.38.4
wrapt                         1.15.0
xmltodict                     0.13.0
xxhash                        3.3.0
yarl                          1.9.2
zipp                          3.16.2
zstandard                     0.21.0

Expected behavior

The training script should run successfully.

wangjianqiao111 commented Mar 13 '24 06:03

Full code:

import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
from datasets import load_from_disk

# Load the offline dataset
offline_dataset_path = "./guanaco-llama2-1k-offline"
dataset = load_from_disk(offline_dataset_path)
print('wjq111', type(dataset))

model_name = "/slurmhome/aps/aps_sft_models/llama2-7B-origin-sft"
new_model = './finetuned_model_llama2-7b'

# LoRA settings
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

# bitsandbytes 4-bit quantization settings
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False

# Training settings
output_dir = "./results"
num_train_epochs = 1
fp16 = False
bf16 = False
per_device_train_batch_size = 1
per_device_eval_batch_size = 1
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 0
logging_steps = 25
max_seq_length = None
packing = False
device_map = {"": 0}

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# Load the base model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

trainer.train()

trainer.model.save_pretrained(new_model)

wangjianqiao111 commented Mar 13 '24 06:03

Error when running the code:

Traceback (most recent call last):
  File "/slurmhome/aps/wjq_test/llama_sft.py", line 196, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/slurmhome/wangjq/.local/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 305, in __init__
    super().__init__(
  File "/slurmhome/wangjq/.local/lib/python3.11/site-packages/transformers/trainer.py", line 367, in __init__
    self.create_accelerator_and_postprocess()
  File "/slurmhome/wangjq/.local/lib/python3.11/site-packages/transformers/trainer.py", line 4127, in create_accelerator_and_postprocess
    self.accelerator = Accelerator(
                       ^^^^^^^^^^^^
TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'

wangjianqiao111 commented Mar 13 '24 06:03
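To make the failure above concrete, here is a minimal diagnostic sketch (an illustration, not part of the original report; it assumes it is run in the same environment as the failing script, with accelerate importable). It only checks whether the installed Accelerator signature accepts the keyword that transformers forwards from create_accelerator_and_postprocess():

# Diagnostic sketch: check whether the installed accelerate release
# exposes the keyword argument that transformers passes internally.
import inspect

import accelerate
from accelerate import Accelerator

print("accelerate version:", accelerate.__version__)
accepts = "use_seedable_sampler" in inspect.signature(Accelerator.__init__).parameters
print("Accelerator accepts use_seedable_sampler:", accepts)
# With accelerate 0.24.1 this prints False, which is exactly why the
# Trainer's internal Accelerator(...) call raises the TypeError above.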

Python version:

Python 3.11.4

wangjianqiao111 commented Mar 13 '24 06:03

The accelerate version you are using is too old; the use_seedable_sampler argument was introduced in v0.26. If you upgrade to the latest version, v0.28, the argument is passed like this:

from accelerate import Accelerator, DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(..., dataloader_config=dataloader_config)

BenjaminBossan commented Mar 13 '24 10:03
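A follow-up note for readers who hit this error from Trainer or SFTTrainer, as in the script above: the Accelerator is constructed inside transformers, so the practical fix there is upgrading accelerate rather than editing the training script. For code that builds an Accelerator directly, a hedged, version-tolerant sketch (assuming accelerate >= 0.26 is installed, the version cited above as introducing the flag):

from accelerate import Accelerator

try:
    # Newer releases (the comment above points to v0.28) group the flag
    # into DataLoaderConfiguration.
    from accelerate import DataLoaderConfiguration
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True)
    )
except ImportError:
    # Releases that predate DataLoaderConfiguration but already have the
    # flag accept it as a plain keyword argument of Accelerator.
    accelerator = Accelerator(use_seedable_sampler=True)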

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] commented Apr 12 '24 15:04