peft
Gemma with softprompt raises error
System Info
I am doing soft-prompt tuning on Gemma 2B (google/gemma-2-2b). Generation fails with the following error:
File "/home/krishna/PII/fs-llm/libs/peft/src/peft/peft_model.py", line 1920, in prepare_inputs_for_generation
    model_kwargs["attention_mask"] = torch.cat(
RuntimeError: Tensors must have same number of dimensions: got 2 and 4
Who can help?
No response
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder
- [ ] My own task or dataset (give details below)
Reproduction
Run soft-prompt (prompt tuning) generation with Gemma, using the sample code below.
Expected behavior
No error
Sample code
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, TaskType, PromptTuningConfig, PromptTuningInit

# Load the pre-trained Gemma 2 2B model and tokenizer
model_name = "google/gemma-2-2b"
cache_dir = "/assets/hub"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    attn_implementation="eager" if "gemma" in model_name else None,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

# Wrap the base model with a prompt-tuning (soft-prompt) adapter
config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    # Only used with PromptTuningInit.TEXT; has no effect with RANDOM init
    prompt_tuning_init_text="email",  # "phone" # "address"
    num_virtual_tokens=20,
    tokenizer_name_or_path=model_name,
)
model = get_peft_model(model, config)

# Define a batch of text for generation
input_texts = [
    "In the world of artificial intelligence,",
]

# Tokenize the batch of input texts
inputs = tokenizer(
    input_texts, return_tensors="pt", padding=True, truncation=True
)
print(f"Input attention mask shape: {inputs['attention_mask'].shape}")

# Generate text with the model for the batch (this call raises the error)
generated_output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
)

# Decode and print the generated text for each example in the batch
generated_texts = tokenizer.batch_decode(
    generated_output, skip_special_tokens=True
)
for idx, generated_text in enumerate(generated_texts):
    print(f"Generated text for input {idx + 1}: {generated_text}")
The error happens in peft_model.py during the concatenation at https://github.com/huggingface/peft/blob/e2262d29a93dd190c9d52267a8b7c386e8bce1b2/src/peft/peft_model.py#L1920
Just before the concatenation step, model_kwargs["attention_mask"] has shape [1, 1, 8, 49] and prefix_attention_mask has shape [1, 20].
Here 49 is max_length minus 1 (when I change max_length, this value changes accordingly), and 8 is the number of tokens in the input.
I tested other models such as EleutherAI/pythia-6.9b: the code works, and these variables are printed as 2-D tensors with shapes [1, 8] and [1, 20].
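The shape mismatch can be reproduced without loading any model. The snippet below is a minimal sketch that uses dummy tensors with the shapes reported above; the variable names and the argument order passed to torch.cat are assumptions for illustration, not copied from peft_model.py.

```python
import torch

# Dummy tensors with the shapes reported in this issue
# (20 virtual tokens, 8 input tokens, max_length=50).
prefix_attention_mask = torch.ones(1, 20)       # 2-D: [batch, num_virtual_tokens]
base_attention_mask = torch.zeros(1, 1, 8, 49)  # 4-D mask as seen with Gemma

try:
    # torch.cat requires all inputs to have the same number of dimensions
    torch.cat((prefix_attention_mask, base_attention_mask), dim=1)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```

With a 2-D mask such as the [1, 8] tensor seen with pythia, the same call succeeds and yields a [1, 28] mask.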
Edit: The issue possibly comes from the transformers function called via self.base_model_prepare_inputs_for_generation(), which returns the 4-D attention mask.
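Until this is handled properly, a defensive check would at least make the failure easier to diagnose. The following is only a sketch: prepend_prefix_mask is a hypothetical helper, not part of PEFT or transformers, and it deliberately does not try to extend a 4-D mask, since how that should be done is exactly what is unresolved here.

```python
import torch


def prepend_prefix_mask(attention_mask: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    """Hypothetical helper: prepend an all-ones mask for the virtual tokens.

    Only the 2-D case [batch, seq_len] is handled; any other rank is
    rejected with an explicit message instead of failing inside torch.cat.
    """
    if attention_mask.dim() != 2:
        raise ValueError(
            f"Expected a 2-D attention mask, got shape {tuple(attention_mask.shape)}"
        )
    prefix = torch.ones(
        attention_mask.shape[0],
        num_virtual_tokens,
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    return torch.cat((prefix, attention_mask), dim=1)


# 2-D mask, as observed with pythia: 8 input tokens plus 20 virtual tokens
extended = prepend_prefix_mask(torch.ones(1, 8, dtype=torch.long), 20)
print(extended.shape)
```

Passing the 4-D [1, 1, 8, 49] mask to this helper raises a ValueError that names the offending shape, which points more directly at the base model's input preparation than the torch.cat error does.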
Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
accelerate 1.5.2 pypi_0 pypi
blobfile 3.0.0 pypi_0 pypi
bzip2 1.0.8 h5eee18b_6
ca-certificates 2025.2.25 h06a4308_0
certifi 2025.1.31 pypi_0 pypi
charset-normalizer 3.4.1 pypi_0 pypi
expat 2.6.4 h6a678d5_0
filelock 3.18.0 pypi_0 pypi
fsspec 2025.3.0 pypi_0 pypi
huggingface-hub 0.29.3 pypi_0 pypi
idna 3.10 pypi_0 pypi
jinja2 3.1.6 pypi_0 pypi
ld_impl_linux-64 2.40 h12ee557_0
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libmpdec 4.0.0 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
lxml 5.3.1 pypi_0 pypi
markupsafe 3.0.2 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.4.2 pypi_0 pypi
numpy 2.2.4 pypi_0 pypi
nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
nvidia-nccl-cu12 2.21.5 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
openssl 3.0.16 h5eee18b_0
packaging 24.2 pypi_0 pypi
peft 0.14.0 pypi_0 pypi
pip 25.0 py313h06a4308_0
protobuf 6.30.1 pypi_0 pypi
psutil 7.0.0 pypi_0 pypi
pycryptodomex 3.22.0 pypi_0 pypi
python 3.13.2 hf623796_100_cp313
python_abi 3.13 0_cp313
pyyaml 6.0.2 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2024.11.6 pypi_0 pypi
requests 2.32.3 pypi_0 pypi
safetensors 0.5.3 pypi_0 pypi
setuptools 75.8.0 py313h06a4308_0
sqlite 3.45.3 h5eee18b_0
sympy 1.13.1 pypi_0 pypi
tiktoken 0.8.0 pypi_0 pypi
tk 8.6.14 h39e8969_0
tokenizers 0.21.1 pypi_0 pypi
torch 2.6.0 pypi_0 pypi
tqdm 4.67.1 pypi_0 pypi
transformers 4.49.0 pypi_0 pypi
triton 3.2.0 pypi_0 pypi
typing-extensions 4.13.0 pypi_0 pypi
tzdata 2025a h04d1e81_0
urllib3 2.3.0 pypi_0 pypi
wheel 0.45.1 py313h06a4308_0
xz 5.6.4 h5eee18b_1
zlib 1.2.13 h5eee18b_1
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale, waiting for #2458 to be merged