peft
Fix warning messages about `config.json` when the base `model_id` is local.
It should be possible for the user to specify a local directory as the base model. Currently, however, the library only checks for the remote presence of `config.json` and fails to check the actual `config.json` when the base model is a local directory. This PR adds a check for a local `model_id` and fixes the behavior.
Thanks for the PR, what you describe sounds reasonable. Do you have a small example where this change would apply? Ideally, we can use that to create a unit test.
Also, could you please remove the `\` for line breaks?
Sure, a simple example is creating a LoRA adapter for a local base model and saving it. For example, create a PeftModel for a local snapshot of mistralai/Mistral-7B-v0.1; saving the adapter then issues this warning:

```python
warnings.warn(
    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
```

The configuration is therefore not checked.
```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

local_dir = "path/to/model"
base_model = AutoModelForCausalLM.from_pretrained(local_dir)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)
peft_model = get_peft_model(base_model, peft_config)
# Saving triggers the spurious "Could not find a config file" warning
peft_model.save_pretrained("test")
```
Is this sufficient? @BenjaminBossan
Nice, thanks for providing the example. I used it to create a test:
```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model


class TestLocalModel:
    def test_local_model_saving_no_warning(self, recwarn, tmp_path):
        model_id = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        local_dir = tmp_path / model_id
        model.save_pretrained(local_dir)
        del model

        base_model = AutoModelForCausalLM.from_pretrained(local_dir)
        peft_config = LoraConfig()
        peft_model = get_peft_model(base_model, peft_config)
        peft_model.save_pretrained(local_dir)

        for warning in recwarn.list:
            assert "Could not find a config file" not in warning.message.args[0]
```
We could for instance put it into tests/test_hub_features.py, WDYT? Running it locally on main, it currently fails, but on your branch, it should pass.
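As an aside, the same assertion pattern works outside pytest with the standard `warnings` module; here `save_with_warning` is only a stand-in for the save path that emits the warning, not a PEFT function:

```python
import warnings


def save_with_warning(model_id):
    # Stand-in for the save path that emits the warning in question
    warnings.warn(
        f"Could not find a config file in {model_id} - "
        "will assume that the vocabulary was not modified."
    )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    save_with_warning("path/to/local-model")

# The warning was recorded, mirroring what recwarn captures under pytest
assert any("Could not find a config file" in str(w.message) for w in caught)
```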
Certainly, I followed the syntax in testing_common.py and created a unit test for the issue. Maybe some further checks are needed? @BenjaminBossan
The way you added the test, it's not executed. You would have to add corresponding methods that call this method in test_decoder_models.py, test_encoder_decoder_models.py, etc. But this is overkill; we don't really need to check this with all kinds of different model architectures. Instead, as I suggested earlier, just add this test to tests/test_hub_features.py and it should be good. Let's also add a comment that explains why we need this test.
The test case is fixed as advised, and comments were added to explain the issue. Please review the changes, thanks.
@elementary-particle Thanks for the update. Could you please run `make style`?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@elementary-particle This PR is almost good to go, just a small merge conflict, could you please check it out?
Thanks for keeping up with this PR. The merge conflict is resolved.