FIX: Correct adapter dtype with bnb weights
Resolves #2889
Description
The reported bug: when the base model is quantized with 4-bit bitsandbytes, the adapter weights were cast to float32 even if autocast_adapter_dtype=False was passed. The cause was that the dtype of the base layer was not determined correctly in that case. This PR fixes the dtype determination.
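To illustrate the intended behavior, here is a minimal sketch of how an adapter dtype could be resolved for a quantized base layer. It is not the code in this PR: FakeLinear4bit and resolve_adapter_dtype are hypothetical names, and the sketch assumes that the layer stores packed integer weights and exposes a compute_dtype attribute.

```python
import torch

# Integer dtypes that a quantized layer may report for its packed weights.
_STORAGE_DTYPES = {torch.uint8, torch.int8}


class FakeLinear4bit:
    """Hypothetical stand-in for a 4-bit layer: packed uint8 weight plus a compute dtype."""

    def __init__(self, compute_dtype: torch.dtype):
        self.weight = torch.zeros(4, dtype=torch.uint8)
        self.compute_dtype = compute_dtype


def resolve_adapter_dtype(base_layer, autocast_adapter_dtype: bool = True) -> torch.dtype:
    """Pick the dtype for adapter weights (illustrative helper, not a PEFT function)."""
    dtype = base_layer.weight.dtype
    if dtype in _STORAGE_DTYPES:
        # The storage dtype says nothing about the float precision in use,
        # so look at the compute dtype instead (assumed attribute).
        compute_dtype = getattr(base_layer, "compute_dtype", None)
        dtype = compute_dtype if compute_dtype is not None else torch.float32
    if autocast_adapter_dtype and dtype in (torch.float16, torch.bfloat16):
        # Default behavior: upcast half-precision adapter weights to float32.
        dtype = torch.float32
    return dtype


layer = FakeLinear4bit(compute_dtype=torch.bfloat16)
kept = resolve_adapter_dtype(layer, autocast_adapter_dtype=False)
upcast = resolve_adapter_dtype(layer, autocast_adapter_dtype=True)
```

With autocasting disabled, the adapter keeps the compute dtype of the base layer; with the default, it is upcast to float32.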
While working on this, I noticed that the peft_model.add_adapter method lacked an option to disable autocasting. This option has now been added and is covered by the tests. I also updated some of the corresponding docstrings.
Tangential changes
While debugging, I noticed an unrelated issue: at one point, OSF calls if not hasattr(module, "osf_svd_params"). This would raise an error when the module was a ModulesToSaveWrapper, because ModulesToSaveWrapper._hasattr_wrapped did not account for the case where there is no active adapter. This is now fixed too.
Moreover, OSF implemented its own _cast_adapter_dtype. This effectively bypassed the upcasting of the OSF weights to float32 when the base model is loaded in lower precision. However, unless the user explicitly passes autocast_adapter_dtype=False, the default in PEFT is to upcast the adapter weights to float32. With the changes in this PR, the upcasting is now performed for OSF as well. To make this work in the forward pass, the input x is cast to the dtype of the weight. We assume that the output dtype should be the same as the original dtype of x.
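The casting round trip described above can be sketched as follows. This is an illustration only, using a plain linear operation rather than the actual OSF layer; forward_with_cast is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def forward_with_cast(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Run a linear op in the weight's dtype and return the result in x's dtype."""
    orig_dtype = x.dtype
    # Cast the input to the (possibly upcast) weight dtype so the matmul is valid.
    out = F.linear(x.to(weight.dtype), weight)
    # Assumption from the description: the caller expects the input dtype back.
    return out.to(orig_dtype)


x = torch.randn(2, 8, dtype=torch.bfloat16)       # activations in low precision
weight = torch.randn(4, 8, dtype=torch.float32)   # adapter weight upcast to float32
y = forward_with_cast(x, weight)
```

The computation runs in float32, while callers still see bfloat16 outputs, so downstream layers are not affected by the upcast.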
TODOs
There is still an issue left with 8bit bnb weights. They don't have a compute dtype, so at a layer level, it is not possible to determine what the dtype of the PEFT adapter should be (of course, it cannot be int8). Therefore, the corresponding tests for 8bit bnb are x-failing for now. One possible solution could be to pass down the dtype of the base model (if any) and use that as a fallback. This could be implemented in a later PR.
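The fallback idea could look roughly like this. It is a sketch of the proposal, not an implemented API: adapter_dtype_with_fallback and the model_dtype argument are hypothetical, and defaulting to float32 when nothing else is known is an assumption.

```python
from typing import Optional

import torch


def adapter_dtype_with_fallback(
    layer_dtype: torch.dtype, model_dtype: Optional[torch.dtype] = None
) -> torch.dtype:
    """Choose an adapter dtype when the layer itself may only report int8."""
    if layer_dtype.is_floating_point:
        # The layer already reports a usable float dtype.
        return layer_dtype
    if model_dtype is not None and model_dtype.is_floating_point:
        # Fallback: use the dtype passed down from the base model.
        return model_dtype
    # Last resort (assumption): float32 is always a valid adapter dtype.
    return torch.float32


from_model = adapter_dtype_with_fallback(torch.int8, torch.bfloat16)
no_hint = adapter_dtype_with_fallback(torch.int8, None)
```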
@NikhilNayak-debug this PR contains some updates to OSF. Could you please check if they make sense? See the PR description above for the reasoning behind the changes.
@githubnemo We haven't heard back from Nikhil for two weeks, let's just proceed. PR is ready for review.
@githubnemo I removed the change to _hasattr_wrapped, PR should be good to review.