
Gemma with softprompt raises error

Open krishnakanthnakkav2 opened this issue 8 months ago • 3 comments

System Info

I am doing soft-prompt tuning on Gemma 2B (google/gemma-2-2b). Generation fails with the following error:

File "/home/krishna/PII/fs-llm/libs/peft/src/peft/peft_model.py", line 1920, in prepare_inputs_for_generation
    model_kwargs["attention_mask"] = torch.cat(
RuntimeError: Tensors must have same number of dimensions: got 2 and 4

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder
  • [ ] My own task or dataset (give details below)

Reproduction

Apply soft-prompt (prompt) tuning to Gemma and call model.generate().

Expected behavior

Generation completes without error.

krishnakanthnakkav2, Mar 26 '25 10:03

Sample code

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import get_peft_model, TaskType, PromptTuningConfig, PromptTuningInit

# Load the pre-trained Gemma 2 2B base model and its tokenizer
model_name = "google/gemma-2-2b"
cache_dir = "/assets/hub"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    attn_implementation="eager" if "gemma" in model_name else None,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

# Prompt tuning with 20 randomly initialized virtual tokens
config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    # "phone" # "address"; the init text is only used with PromptTuningInit.TEXT
    prompt_tuning_init_text="email",
    num_virtual_tokens=20,
    tokenizer_name_or_path=model_name,
)

model = get_peft_model(model, config)

# Define a batch of text for generation (a single example here)
input_texts = [
    "In the world of artificial intelligence,",
]

# Tokenize the batch and move it to the model's device
inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(model.device)

print(f"Input attention mask shape: {inputs['attention_mask'].shape}")

# Generate text with the model for the batch (this call raises the RuntimeError)
generated_output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
)

# Decode and print the generated text for each example in the batch
generated_texts = tokenizer.batch_decode(generated_output, skip_special_tokens=True)

for idx, generated_text in enumerate(generated_texts):
    print(f"Generated text for input {idx + 1}: {generated_text}")

The error is raised in peft_model.py by the torch.cat call that concatenates the attention masks, at https://github.com/huggingface/peft/blob/e2262d29a93dd190c9d52267a8b7c386e8bce1b2/src/peft/peft_model.py#L1920

Just before the concatenation, model_kwargs["attention_mask"] has shape [1, 1, 8, 49] and prefix_attention_mask has shape [1, 20].
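The failing call can be reproduced in isolation with dummy tensors of the reported shapes. This is a minimal sketch, assuming the prefix mask is the first argument to torch.cat (consistent with the "got 2 and 4" order in the traceback); the variable names are illustrative, and only the number of dimensions matters, not the values.

```python
import torch

num_virtual_tokens = 20

# 2-D prefix mask for the virtual tokens: [batch, num_virtual_tokens]
prefix_attention_mask = torch.ones(1, num_virtual_tokens)
# 4-D mask as observed with Gemma 2: [batch, 1, query_len, key_len]
attention_mask_4d = torch.zeros(1, 1, 8, 49)

error_message = None
try:
    # Concatenating along the sequence axis requires both tensors to have the same rank
    torch.cat((prefix_attention_mask, attention_mask_4d), dim=1)
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```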

Here 49 equals max_length minus 1 (changing max_length changes this dimension accordingly), and 8 is the number of input tokens.

With other models such as EleutherAI/pythia-6.9b the code works, and both variables are printed as 2-D tensors of shape [1, 8] and [1, 20].
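The difference between the two models comes down to the rank of the mask that reaches the concatenation. The sketch below uses a hypothetical helper, prepend_prefix_mask, that mimics the prefix concatenation but checks its input first: a Pythia-style 2-D mask of shape [1, 8] grows to [1, 28] (20 virtual tokens plus 8 input tokens), while a Gemma-style 4-D mask is rejected with a readable message. It is an illustration only, not PEFT's code and not the upstream fix.

```python
import torch


def prepend_prefix_mask(attention_mask: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    """Hypothetical helper: prepend an all-ones mask for the virtual tokens.

    Only a 2-D [batch, seq_len] padding mask can be extended this way.
    """
    if attention_mask.dim() != 2:
        raise ValueError(
            "expected a 2-D [batch, seq_len] attention mask, "
            f"got shape {tuple(attention_mask.shape)}"
        )
    prefix = torch.ones(
        attention_mask.shape[0],
        num_virtual_tokens,
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    # Virtual tokens come first, so the prefix goes on the left of the sequence axis
    return torch.cat((prefix, attention_mask), dim=1)


# Pythia-style 2-D mask: works
mask_2d = torch.ones(1, 8, dtype=torch.long)
print(prepend_prefix_mask(mask_2d, 20).shape)  # torch.Size([1, 28])

# Gemma-style 4-D mask: rejected with an explicit message
mask_4d = torch.zeros(1, 1, 8, 49)
try:
    prepend_prefix_mask(mask_4d, 20)
except ValueError as err:
    print(err)  # expected a 2-D [batch, seq_len] attention mask, got shape (1, 1, 8, 49)
```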

Edit: The issue possibly comes from the base model's prepare_inputs_for_generation in transformers (called via self.base_model_prepare_inputs_for_generation()), which returns the 4-D attention mask.
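For context on where a 4-D shape can come from: transformers models commonly turn the 2-D padding mask into a 4-D additive mask of shape (batch, 1, query_len, key_len) that also encodes causality, where key_len is the length of the key/value cache. The following is a simplified sketch of that expansion under those assumptions; the helper expand_to_4d is hypothetical and is not the actual Gemma 2 code, but with the reported numbers (8 input tokens, 49 cache positions) it yields the same [1, 1, 8, 49] shape, which cannot be concatenated with a [1, 20] prefix mask.

```python
import torch


def expand_to_4d(mask_2d: torch.Tensor, target_length: int, dtype=torch.float32) -> torch.Tensor:
    """Hypothetical helper: turn a [batch, query_len] padding mask into an additive
    [batch, 1, query_len, target_length] causal mask (0 = attend, dtype min = blocked).
    Assumes target_length >= query_len."""
    batch, query_len = mask_2d.shape
    min_value = torch.finfo(dtype).min

    # Causal rule: query position i may attend to key position j only if j <= i
    query_pos = torch.arange(query_len).unsqueeze(1)    # [query_len, 1]
    key_pos = torch.arange(target_length).unsqueeze(0)  # [1, target_length]
    causal_ok = key_pos <= query_pos                    # [query_len, target_length]

    # Padding rule: only real (mask == 1) tokens are valid keys; unfilled cache slots are not
    key_ok = torch.zeros(batch, target_length, dtype=torch.bool)
    key_ok[:, :query_len] = mask_2d.bool()

    allowed = causal_ok.unsqueeze(0) & key_ok.unsqueeze(1)  # [batch, query_len, target_length]
    mask_4d = torch.full((batch, 1, query_len, target_length), min_value, dtype=dtype)
    return mask_4d.masked_fill(allowed.unsqueeze(1), 0.0)


mask_4d = expand_to_4d(torch.ones(1, 8, dtype=torch.long), target_length=49)
print(mask_4d.shape)  # torch.Size([1, 1, 8, 49])
```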

Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.5.2                    pypi_0    pypi
blobfile                  3.0.0                    pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2025.2.25            h06a4308_0
certifi                   2025.1.31                pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
expat                     2.6.4                h6a678d5_0
filelock                  3.18.0                   pypi_0    pypi
fsspec                    2025.3.0                 pypi_0    pypi
huggingface-hub           0.29.3                   pypi_0    pypi
idna                      3.10                     pypi_0    pypi
jinja2                    3.1.6                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0
libffi                    3.4.4                h6a678d5_1
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libmpdec                  4.0.0                h5eee18b_0
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
lxml                      5.3.1                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
networkx                  3.4.2                    pypi_0    pypi
numpy                     2.2.4                    pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
openssl                   3.0.16               h5eee18b_0
packaging                 24.2                     pypi_0    pypi
peft                      0.14.0                   pypi_0    pypi
pip                       25.0            py313h06a4308_0
protobuf                  6.30.1                   pypi_0    pypi
psutil                    7.0.0                    pypi_0    pypi
pycryptodomex             3.22.0                   pypi_0    pypi
python                    3.13.2          hf623796_100_cp313
python_abi                3.13                    0_cp313
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
safetensors               0.5.3                    pypi_0    pypi
setuptools                75.8.0          py313h06a4308_0
sqlite                    3.45.3               h5eee18b_0
sympy                     1.13.1                   pypi_0    pypi
tiktoken                  0.8.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.21.1                   pypi_0    pypi
torch                     2.6.0                    pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.49.0                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-extensions         4.13.0                   pypi_0    pypi
tzdata                    2025a                h04d1e81_0
urllib3                   2.3.0                    pypi_0    pypi
wheel                     0.45.1          py313h06a4308_0
xz                        5.6.4                h5eee18b_1
zlib                      1.2.13               h5eee18b_1

krishnakanthnakkav2, Mar 26 '25 11:03

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot], Apr 25 '25 15:04

Not stale; waiting for #2458 to be merged.

BenjaminBossan, Apr 25 '25 15:04