peft
Fix warning messages about `config.json` when the base `model_id` is local.
It should be possible for the user to specify a local directory as the base model.
However, the library currently only checks whether `config.json` exists on the Hub, and never checks for it in a local directory.
This PR adds a check for a local `model_id` and fixes the behavior.
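The fix described above amounts to "look on disk first, then fall back to the Hub". Below is a minimal sketch of that lookup order. `base_config_exists` is a hypothetical helper, not PEFT's API. The Hub check is passed in as a callable (also hypothetical) so the sketch runs offline. It is assumed to return `True`/`False`, or `None` when the lookup itself fails.

```python
import os
from typing import Callable, Optional


def base_config_exists(
    model_id: str,
    hub_file_exists: Callable[[str, str], Optional[bool]],
) -> Optional[bool]:
    """Hypothetical helper: does the base model `model_id` have a config.json?

    Returns True/False when this could be determined, or None when the
    (injected) Hub lookup could not decide, e.g. because it failed.
    """
    # model_id may be a local directory: check the filesystem first.
    if os.path.isfile(os.path.join(model_id, "config.json")):
        return True
    # Otherwise treat model_id as a Hub repo id and defer to the remote check.
    return hub_file_exists(model_id, "config.json")
```

Checking the local path first means a local snapshot never triggers a network lookup, which is the case the PR is about.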
Thanks for the PR, what you describe sounds reasonable. Do you have a small example where this change would apply? Ideally, we can use that to create a unit test.
Also, could you please remove the \ for line breaks?
Sure, a simple example is to create a LoRA adapter for a local base model and save it.
For example, create a PeftModel from a local snapshot of mistralai/Mistral-7B-v0.1. Saving the adapter then issues this warning:
```python
warnings.warn(
    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
```
As a result, the configuration is never checked.
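The warning text suggests that the base `config.json` is used to detect whether the vocabulary was modified during fine-tuning. When that file is not found, the check is skipped. Below is a minimal sketch of such a comparison. It assumes the check compares a `vocab_size` field in the base `config.json` with the fine-tuned model's vocabulary size. `vocab_was_resized` is a hypothetical name, and the sketch only reads a local directory.

```python
import json
import os
import warnings


def vocab_was_resized(model_id: str, current_vocab_size: int) -> bool:
    """Hypothetical sketch: compare the fine-tuned vocab size with the base config.

    Assumes model_id is a local directory; no Hub lookup is attempted.
    """
    config_path = os.path.join(model_id, "config.json")
    if not os.path.isfile(config_path):
        # Nothing to compare against, so fall back to assuming the
        # vocabulary is unchanged (the behavior the warning describes).
        warnings.warn(
            f"Could not find a config file in {model_id} - "
            "will assume that the vocabulary was not modified."
        )
        return False
    with open(config_path) as f:
        base_vocab_size = json.load(f).get("vocab_size")
    return base_vocab_size is not None and base_vocab_size != current_vocab_size
```

This shows why a missed local config matters: a resized vocabulary would silently go undetected.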
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Local snapshot of mistralai/Mistral-7B-v0.1
local_dir = "path/to/model"
base_model = AutoModelForCausalLM.from_pretrained(local_dir)

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)
peft_model = get_peft_model(base_model, peft_config)
# Saving the adapter emits the "Could not find a config file" warning
peft_model.save_pretrained("test")
```
Is this sufficient? @BenjaminBossan
Nice, thanks for providing the example. I used it to create a test:
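```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model


class TestLocalModel:
    def test_local_model_saving_no_warning(self, recwarn, tmp_path):
        # Save a Hub model to a local directory and reload it from there,
        # so that the base model path points to a local directory.
        model_id = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        local_dir = tmp_path / model_id
        model.save_pretrained(local_dir)
        del model

        base_model = AutoModelForCausalLM.from_pretrained(local_dir)
        peft_config = LoraConfig()
        peft_model = get_peft_model(base_model, peft_config)
        peft_model.save_pretrained(local_dir)

        # The local config.json exists, so no "missing config" warning may be raised.
        for warning in recwarn.list:
            assert "Could not find a config file" not in str(warning.message)
```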
We could for instance put it into tests/test_hub_features.py, WDYT? Running it locally on main, it currently fails, but on your branch, it should pass.
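The test relies on pytest's built-in `recwarn` fixture. Outside pytest, the same assertion can be written with the standard library's `warnings.catch_warnings(record=True)`. Below is a minimal sketch. `save_adapter` is a hypothetical stand-in for `peft_model.save_pretrained(...)`, so the example runs without downloading a model.

```python
import warnings
from typing import Callable, List


def collect_warning_messages(fn: Callable[[], None]) -> List[str]:
    """Run fn and return the text of every warning it emits."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" disables deduplication so repeated warnings are recorded too.
        warnings.simplefilter("always")
        fn()
    return [str(w.message) for w in caught]


def save_adapter(model_id: str, config_found: bool) -> None:
    """Hypothetical stand-in for peft_model.save_pretrained(...)."""
    if not config_found:
        warnings.warn(
            f"Could not find a config file in {model_id} - "
            "will assume that the vocabulary was not modified."
        )


messages = collect_warning_messages(lambda: save_adapter("path/to/model", config_found=True))
assert not any("Could not find a config file" in m for m in messages)
```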
Certainly, I followed the conventions in testing_common.py and created a unit test for the issue.
Should I add any further checks? @BenjaminBossan
As you added it, the test is never executed. You would have to add corresponding methods that call it in test_decoder_models.py, test_encoder_decoder_models.py, etc. But that is overkill: we don't need to check this with every model architecture. Instead, as I suggested earlier, just add this test to tests/test_hub_features.py and it should be good. Let's also add a comment that explains why we need this test.
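The reviewer's point follows from how test discovery typically works. A helper method on a shared mixin class that is not itself a test case is not collected on its own. It only runs when a collected `test_*` method on a test case calls it. Below is a minimal sketch using the standard `unittest` loader. `CommonTesterMixin`, `_check_local_saving`, and `DecoderModelTests` are hypothetical names, not PEFT's.

```python
import unittest


class CommonTesterMixin:
    """Hypothetical shared mixin, similar in spirit to a helper class in testing_common.py."""

    def _check_local_saving(self, model_id: str) -> None:
        # A real check would build and save a model; this stand-in only asserts.
        self.assertTrue(model_id)


class DecoderModelTests(unittest.TestCase, CommonTesterMixin):
    # Only test_* methods on a TestCase subclass are collected, so the
    # shared helper runs only through a wrapper like this one.
    def test_local_saving(self) -> None:
        self._check_local_saving("facebook/opt-125m")
```

Run it with `python -m unittest` or pytest. Repeating such wrappers for every architecture is the overhead the reviewer wants to avoid.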
I fixed the test case as advised and added comments explaining the issue. Please review the changes, thanks.
@elementary-particle Thanks for the update. Could you please run make style?
@elementary-particle This PR is almost good to go, just a small merge conflict, could you please check it out?
Thanks for keeping up with this PR. The merge conflict is resolved.