TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'
System Info
absl-py 1.4.0
accelerate 0.24.1
addict 2.4.0
aim 3.17.5
aim-ui 3.17.5
aimrocks 0.4.0
aiofiles 23.2.1
aiohttp 3.8.5
aiosignal 1.3.1
antlr4-python3-runtime 4.9.3
anyio 3.7.1
appdirs 1.4.4
asttokens 2.0.5
async-timeout 4.0.3
attrs 23.1.0
backcall 0.2.0
backoff 2.2.1
base58 2.0.1
bitsandbytes 0.41.1
brotlipy 0.7.0
cachetools 5.3.1
certifi 2023.7.22
cffi 1.15.1
chardet 5.2.0
charset-normalizer 2.0.4
click 8.1.7
colorama 0.4.6
comm 0.1.2
contexttimer 0.3.3
cpm-kernels 1.0.11
cryptography 41.0.2
DataProperty 1.0.1
datasets 2.14.4
debugpy 1.6.7
decorator 5.1.1
deepspeed 0.10.2
dill 0.3.7
distlib 0.3.7
docker-pycreds 0.4.0
docstring-parser 0.15
einops 0.6.1
evaluate 0.4.0
executing 0.8.3
fastapi 0.103.1
fastjsonschema 2.18.1
filelock 3.9.0
flash-attn 2.5.6
frozenlist 1.4.0
fsspec 2023.9.0
fuzzywuzzy 0.18.0
gitdb 4.0.10
GitPython 3.1.34
gmpy2 2.1.2
google-auth 2.22.0
google-auth-oauthlib 1.0.0
greenlet 3.0.3
grpcio 1.57.0
h11 0.14.0
hjson 3.1.0
huggingface-hub 0.21.4
idna 3.4
ipykernel 6.25.0
ipython 8.12.2
jedi 0.18.1
jieba 0.42.1
Jinja2 3.1.2
joblib 1.3.2
jsonlines 4.0.0
jupyter_client 8.1.0
jupyter_core 5.3.0
lm-eval 0.3.0 /slurmhome/aps/xuzhihui/lm-evaluation-harness
Mako 1.3.0
Markdown 3.4.4
markdown-it-py 3.0.0
MarkupSafe 2.1.1
matplotlib-inline 0.1.6
mbstrdecoder 1.1.3
mdurl 0.1.2
mkl-fft 1.3.6
mkl-random 1.2.2
mkl-service 2.4.0
monotonic 1.6
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.15
nest-asyncio 1.5.6
networkx 3.1
ninja 1.11.1
numexpr 2.8.5
numpy 1.25.2
nvidia-ml-py3 7.352.0
oauthlib 3.2.2
omegaconf 2.3.0
onnx 1.15.0
opencompass 0.1.0
packaging 23.1
pandas 2.1.0
parso 0.8.3
pathtools 0.1.2
pathvalidate 3.1.0
peewee 3.16.3
peft 0.5.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.4.0
pip 23.2.1
platformdirs 3.10.0
portalocker 2.7.0
prompt-toolkit 3.0.36
protobuf 4.24.2
psutil 5.9.5
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
py3nvml 0.2.7
pyarrow 13.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pybind11 2.11.1
pycountry 22.3.5
pycparser 2.21
pydantic 1.10.12
Pygments 2.15.1
pyOpenSSL 23.2.0
PySocks 1.7.1
pytablewriter 1.0.0
python-dateutil 2.8.2
pytz 2023.3
PyYAML 6.0.1
pyzmq 25.1.0
regex 2023.8.8
requests 2.31.0
requests-oauthlib 1.3.1
responses 0.18.0
RestrictedPython 7.0
rich 13.7.1
rouge-score 0.0.4
rsa 4.9
sacrebleu 1.5.0
safetensors 0.4.2
sanic-ext 23.6.0
sanic-routing 23.6.0
scikit-learn 1.3.0
scipy 1.11.2
segment-analytics-python 2.2.3
sentencepiece 0.1.99
sentry-sdk 1.30.0
setproctitle 1.3.2
setuptools 68.0.0
shtab 1.7.1
six 1.16.0
smmap 5.0.0
sniffio 1.3.0
SQLAlchemy 1.4.51
sqlitedict 2.1.0
stack-data 0.2.0
starlette 0.27.0
sympy 1.11.1
tabledata 1.3.1
tcolorpy 0.1.3
tensorboard 2.14.0
tensorboard-data-server 0.7.1
threadpoolctl 3.2.0
tokenizers 0.15.2
torch 2.0.1
torchaudio 2.0.2
torchvision 0.15.2
tornado 6.3.2
tqdm 4.66.2
tqdm-multiprocess 0.0.11
traitlets 5.7.1
transformers 4.38.2
transformers-stream-generator 0.0.4
triton 2.0.0
trl 0.7.12.dev0
typepy 1.3.1
typing_extensions 4.7.1
tyro 0.7.3
tzdata 2023.3
ujson 5.8.0
urllib3 1.26.16
uvicorn 0.23.2
uvloop 0.17.0
wandb 0.15.9
wcwidth 0.2.5
Werkzeug 2.3.7
wheel 0.38.4
wrapt 1.15.0
xmltodict 0.13.0
xxhash 3.3.0
yarl 1.9.2
zipp 3.16.2
zstandard 0.21.0
Expected behavior
The script should run successfully.
All code:
```python
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
from datasets import load_from_disk

offline_dataset_path = "./guanaco-llama2-1k-offline"
dataset = load_from_disk(offline_dataset_path)
print('wjq111', type(dataset))

model_name = "/slurmhome/aps/aps_sft_models/llama2-7B-origin-sft"
new_model = './finetuned_model_llama2-7b'
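
# Hyperparameters for QLoRA-style fine-tuning: LoRA adapter settings,
# bitsandbytes 4-bit quantization flags, and Trainer options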
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 1
fp16 = False
bf16 = False
per_device_train_batch_size = 1
per_device_eval_batch_size = 1
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 0
logging_steps = 25
max_seq_length = None
packing = False
device_map = {"": 0}
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
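
# 4-bit NF4 quantization config used when loading the base model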
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)
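
# Load the base model in 4-bit and pin it to GPU 0 via device_map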
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
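
# TrainingArguments are passed to trl's SFTTrainer, which forwards them
# to the underlying transformers Trainer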
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)
trainer.train()
trainer.model.save_pretrained(new_model)
```
Error when running the code:
```
Traceback (most recent call last):
  File "/slurmhome/aps/wjq_test/llama_sft.py", line 196, in <module>
    ...
TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'
```
Python version:
Python 3.11.4
The accelerate version you are using is too old; the argument was introduced in v0.26. If you upgrade to the latest version, v0.28, the way to pass the argument is:
```python
from accelerate import Accelerator, DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(..., dataloader_config=dataloader_config)
```
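Note that the script above never instantiates `Accelerator` directly: transformers' `Trainer` (which trl's `SFTTrainer` builds on) creates one internally and, in the transformers release installed here (4.38.2), passes `use_seedable_sampler` to it, which is what raises the TypeError on accelerate 0.24.1. So simply upgrading accelerate (e.g. `pip install -U accelerate`) should resolve the error. Below is a minimal sketch of a pre-flight check for this, assuming the environment above; `check_accelerate_version` is an illustrative helper, not part of any of these libraries:

```python
# Hypothetical sanity check: fail early with a clear message when the
# installed accelerate predates use_seedable_sampler (added in v0.26),
# instead of crashing inside Trainer's internal Accelerator construction.
from importlib.metadata import version
from packaging.version import Version  # packaging 23.1 is already installed


def check_accelerate_version(minimum: str = "0.26.0") -> None:
    installed = Version(version("accelerate"))
    if installed < Version(minimum):
        raise RuntimeError(
            f"accelerate {installed} is too old: use_seedable_sampler was "
            f"introduced in v{minimum}. Run `pip install -U accelerate`."
        )


check_accelerate_version()
```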
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.