LLaMA Implementation
What does this PR do?
Implementation of LLaMA models (https://arxiv.org/abs/2302.13971). Model weights can be requested here. A weight conversion script is included.
Weight conversion can be run via:
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights \
--model_size 7B \
--output_dir /output/path
Models can then be loaded via:
tokenizer = transformers.LLaMATokenizer.from_pretrained("/output/path/tokenizer/")
model = transformers.LLaMAForCausalLM.from_pretrained("/output/path/llama-7b/")
Example:
batch = tokenizer(
    "The primary use of LLaMA is research on large language models, including",
    return_tensors="pt",
    add_special_tokens=False
)
batch = {k: v.cuda() for k, v in batch.items()}
generated = model.generate(batch["input_ids"], max_length=100)
print(tokenizer.decode(generated[0]))
Fixes https://github.com/huggingface/transformers/issues/21796
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada
does this work with int8?
> does this work with int8?
No idea! I haven't messed with int8 too much myself. It ought to be compatible with whatever is already supported in the HF models.
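If it is, loading would presumably look something like the usual bitsandbytes path (untested sketch, reusing the converted checkpoint path from the PR description; requires bitsandbytes and accelerate to be installed):
import transformers

model = transformers.LLaMAForCausalLM.from_pretrained(
    "/output/path/llama-7b/",
    load_in_8bit=True,   # route the linear layers through bitsandbytes' int8 kernels
    device_map="auto",   # let accelerate place the weights
)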
Nice work! Thanks for the upload, and I hope it gets pulled.
It looks like the tests which are currently failing are unrelated to the LLaMA code, so this should be good to review/use.
If folks can try it out (particularly with the larger, sharded models) and see if there are any issues, that will be helpful!
At least the conversion script seems to work fine. I was able to convert the 7B through 30B models; I do not have enough RAM to convert 65B.
Great work. Thanks for putting this together.
After replacing the transformers install in Kobold with this PR, I am able to load the shards as expected. I just can't generate anything yet because Kobold still needs some changes.

> does this work with int8?
> No idea! I haven't messed with int8 too much myself. It ought to be compatible with whatever is already supported in the HF models.
Int8 does not seem to work, but float16 is fine in my hastily put-together test at https://github.com/zsc/llama_infer . Please leave a comment if you find something!
@zphang I'm not able to get something like tokenizer = AutoTokenizer.from_pretrained("/data/llama/hf/7b/tokenizer/") to work. Is this intentional or just leaving AutoTokenizer for future work?
> @zphang I'm not able to get something like tokenizer = AutoTokenizer.from_pretrained("/data/llama/hf/7b/tokenizer/") to work. Is this intentional or just leaving AutoTokenizer for future work?
What issue are you having / what is the error?
I have tested the code and these are my findings:
- The conversion script works.
- Loading the model works.
- Loading the tokenizer with transformers.LLaMATokenizer.from_pretrained works.
- Loading the tokenizer with AutoTokenizer.from_pretrained does not work and generates this error:
OSError: /tmp/converted/tokenizer/ does not appear to have a file named config.json. Checkout
'https://huggingface.co//tmp/converted/tokenizer//None' for available files.
- The generated text seems to be incoherent. If I try these default values for the generation parameters:
model.generate(input_ids, eos_token_id=2, do_sample=True, temperature=1, top_p=1, typical_p=1, repetition_penalty=1, top_k=50, min_length=0, no_repeat_ngram_size=0, num_beams=1, penalty_alpha=0, length_penalty=1, early_stopping=False, max_new_tokens=200).cuda()
with this prompt:
Common sense questions and answers
Question: What color is the sky?
Factual answer:
I get
Common sense questions and answers
Question: What color is the sky?
Factual answer: Tags: python, django, django-models
Question: Using Django with multiple databases
I am attempting to use django with multiple databases, and I have the following code:
\begin{code}
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:',
},
'db_one': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': 'db_one',
},
'db_two': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': 'db_two',
},
}
It seems to me that prompts are being completely ignored.
- Loading in 8-bit mode with load_in_8bit=True works.
This is OK: tokenizer = transformers.LLaMATokenizer.from_pretrained("/data/llama/hf/7b/tokenizer/")
If I use tokenizer = AutoTokenizer.from_pretrained("/data/llama/hf/7b/tokenizer/") instead, it complains that there is no config.json:
OSError: /data/llama/hf/7b/tokenizer/ does not appear to have a file named config.json. Checkout
'https://huggingface.co//data/llama/hf/7b/tokenizer//None' for available files.
I then worked around it by symlinking /data/llama/hf/7b/tokenizer/special_tokens_map.json to /data/llama/hf/7b/tokenizer/config.json, and that works. So maybe just rename the file?
Anyway, I can now happily play with LLaMA in the Hugging Face world. Thanks for the great work!
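For anyone who wants to script that same workaround rather than symlink by hand, something like this (same local path as above; purely a stopgap until the tokenizer saving is fixed properly):
import os

tok_dir = "/data/llama/hf/7b/tokenizer"
# Stopgap: point config.json at the existing special_tokens_map.json so that
# AutoTokenizer stops complaining about the missing file.
os.symlink(
    os.path.join(tok_dir, "special_tokens_map.json"),
    os.path.join(tok_dir, "config.json"),
)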
Thanks for the comments. Looks like the saved tokenizer doesn't work for AutoTokenizer but works if you directly instantiate from LLaMATokenizer. Maybe one of the HF folks can chime in on the best way to address that.
> The generated text seems to be incoherent. If I try these default values for the generation parameters:
Can you check the input_ids you're using to generate? The tokenizer currently adds both BOS and EOS tokens by default, and an EOS might cause the model to ignore your prompt.
Perhaps I can set EOS to not be added by default so it operates closer to expected behavior.
For this prompt:
'Common sense questions and answers\n\nQuestion: What color is the sky?\nFactual answer:'
these are the input_ids:
tensor([[ 1, 13103, 4060, 5155, 322, 6089, 13, 13, 16492, 29901,
1724, 2927, 338, 278, 14744, 29973, 13, 29943, 19304, 1234,
29901, 2]], device='cuda:0')
I do not know how to interpret these numbers, but if there is an EOS token in that tensor and that token is causing the text generation to derail, changing that default would be valuable.
1 is BOS and 2 is EOS. Can you try without the last input id?
I also added an example in my PR message.
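For reference, a quick way to double-check which IDs are the special tokens (a sketch, using the converted tokenizer path from the PR description):
import transformers

tokenizer = transformers.LLaMATokenizer.from_pretrained("/output/path/tokenizer/")
print(tokenizer.bos_token_id, tokenizer.eos_token_id)  # 1 and 2 for this vocab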
I confirm that doing this
input_ids = input_ids[:, :-1]
to remove the last input id before calling model.generate(...) causes the text generation to become coherent:
Common sense questions and answers
Question: What color is the sky?
Factual answer: The sky is blue. The sky is blue, and it is a fact that it is blue. The sky is indisputably blue.
Added a commit that should fix the tokenizer issues, and not add BOS and EOS by default.
Awesome, I confirm that the text generation is coherent by default now.
I still cannot load the tokenizer with AutoTokenizer.from_pretrained. The error has now changed to this:
File "/tmp/transformers/src/transformers/models/auto/tokenization_auto.py", line 694, in from_pretrained
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
File "/tmp/transformers/src/transformers/models/auto/auto_factory.py", line 610, in __getitem__
raise KeyError(key)
KeyError: <class 'transformers.models.llama.configuration_llama.LLaMAConfig'>
> does this work with int8?
> No idea! I haven't messed with int8 too much myself. It ought to be compatible with whatever is already supported in the HF models.
After the fix with EOS, int8 (bitsandbytes) looks decent. Example in https://github.com/zsc/llama_infer/blob/main/README.md
After https://github.com/huggingface/transformers/pull/21955/commits/459e2ac9f551650ced58deb1c65f06c3d483d606, AutoTokenizer.from_pretrained now works as expected.
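i.e. this now loads without the config.json complaint (same converted path as in the PR description):
tokenizer = transformers.AutoTokenizer.from_pretrained("/output/path/tokenizer/")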
KoboldAI now works
I'd like to see a more memory-efficient conversion script; the current version loads everything into system memory, which makes converting the 30B and 65B variants challenging on some systems.
Yes, this is a quick and dirty version that loads everything into memory. One issue is that the way the weights are sharded (for tensor parallelism) is orthogonal to the way that HF shards the weights (by layer). So either we have to load everything in at once, or we have to load/write multiple times. The latter would be slower but useful for folks with less memory.
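A rough sketch of what the load/write-multiple-times variant could look like (illustrative only: the shard file names below are the 13B layout, and the real conversion has to pick the concatenation dim per weight depending on how it was split for tensor parallelism):
import torch

shard_paths = ["consolidated.00.pth", "consolidated.01.pth"]  # 13B has two TP shards
num_layers = 40  # 13B

for layer in range(num_layers):
    parts = {}
    for path in shard_paths:
        shard = torch.load(path, map_location="cpu")
        for name, tensor in shard.items():
            if name.startswith(f"layers.{layer}."):
                parts.setdefault(name, []).append(tensor)
        del shard  # keep at most one TP shard resident in memory at a time
    # Placeholder merge: the real script concatenates row- vs column-parallel
    # weights along different dims (and leaves replicated weights alone).
    merged = {name: torch.cat(tensors, dim=0) for name, tensors in parts.items()}
    torch.save(merged, f"layer_{layer:02d}.bin")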
Has anyone tested loading 65B with accelerate to load on multiple GPUs?
I can't load the 7B model onto CUDA with a single A4000. Should I just change the GPU?
I'm observing some strange behavior with the tokenizer when encoding sequences beginning with a newline:
>>> t = AutoTokenizer.from_pretrained("llama_hf/tokenizer")
>>> res = t.encode("\nYou:")
>>> res
[29871, 13, 3492, 29901]
>>> t.decode(res)
'You:'
The newline seems to get lost somewhere along the way.
EDIT: Looking into this, it seems it might be the expected behavior of sentencepiece.
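One quick way to see where it goes missing is to look at the token strings themselves (continuing the snippet above; the output shown is what I'd expect from a standard sentencepiece byte-fallback vocab, not verified against this exact model):
>>> t.convert_ids_to_tokens(res)
['▁', '<0x0A>', 'You', ':']
If the newline is still present as the <0x0A> byte token after encoding, then the drop happens on the decode side.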
> Has anyone tested loading 65B with accelerate to load on multiple GPUs?
|      | fp16                         | int8 (bitsandbytes)                    |
|------|------------------------------|----------------------------------------|
| V100 | OK, 5xV100                   | Bad results, short generated sequences |
| A100 | OK, 6xA100 when using "auto" | OK, 3xA100                             |
Yes, I currently have a 65B fp16 model running on 6xV100 (5x should be enough). My working code is at https://github.com/zsc/llama_infer/ . If there are CUDA OOMs due to a bad distribution of weights among the cards, one thing worth trying is tweaking the device_map (accelerate seems to only count weights when enforcing the memory cap in device_map, so there is an art to setting a custom cap a little lower for every card, especially card 0).
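Roughly what that looks like (illustrative caps for 6x32GB V100, not tuned numbers; the checkpoint path is hypothetical):
import torch
import transformers

# Leave a little headroom on every card, and noticeably more on card 0,
# where activations and the cache tend to pile up on top of the weights.
max_memory = {i: "26GiB" for i in range(1, 6)}
max_memory[0] = "20GiB"

model = transformers.LLaMAForCausalLM.from_pretrained(
    "/path/to/llama-65b-hf/",
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory=max_memory,
)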
Strangely, int8 (LLM.int8 to be specific) for the 65B model works like a charm on A100, but leads to bad results on V100 with abnormally short generated sequences.
> Strangely, int8 (LLM.int8 to be specific) for the 65B model works like a charm on A100, but leads to bad results on V100 with abnormally short generated sequences.
I will have a look at this later next week. The V100 takes a different code path than the A100 because the V100 does not support int8 tensor cores; I think that is the issue here. We will soon publish FP4 inference, which should be more universal and easier to use.
Jumping on @thomasw21's comment: we sadly cannot accept any code licensed under GPLv3, as it would taint the whole Transformers library under that license. This means that the modeling code should be copied from GPT-NeoX whenever possible (with Copied from statements), since I believe this model is very close to it and you should be super familiar with it @zphang ;-) , and that no parts of the modeling code should be copy-pasted from the original Llama code.
We also cannot attribute copyright to Meta AI / Meta in all those files, as attributing that copyright would imply the code in the PR is based on theirs and thus bring us back to the license problem.