trlx
initial commit for trlx LORA support
Basic support for low-rank adaptation (LoRA).
This should take a similar form to how hydra models are built. It shouldn't be required, nor directly integrated into the ILQL or PPO models.
Relevant issue: https://github.com/CarperAI/trlx/issues/80
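For reference, a rough sketch of what opendelta-based LoRA wrapping could look like on a plain causal LM (the `modified_modules` values here are GPT-2-specific and purely illustrative, not the final trlx API):

```python
from transformers import AutoModelForCausalLM
from opendelta import LoraModel

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Inject low-rank adapters into the attention projections only.
# "attn.c_attn" is a GPT-2 module name, used here just for illustration.
delta_model = LoraModel(
    backbone_model=model,
    modified_modules=["attn.c_attn"],
    lora_r=8,
)
# Freeze everything except the new LoRA parameters before training.
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
```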
@ethankim00 just a gentle push on when you expect to finish this?
cc @Sayanc93
I can get to it tomorrow or Monday. I'm wondering what the API should be to avoid modifying the model definitions?
I think it would be like: instead of modifying the `CausalLMWithValueHeads` or `GPTHydraHeadWithValueModel` class definitions, the delta versions could be subclasses, and then the config could treat them as just another architecture to be trained.
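i.e. roughly this (just a sketch; the constructor signature and the `base_model` attribute are assumptions):

```python
from opendelta import LoraModel

# Hypothetical sketch of the subclass approach: the delta variant reuses
# the existing hydra model (GPTHydraHeadWithValueModel from trlx's model
# definitions) and only adds adapter injection on top, so the config can
# select it like any other architecture.
class LoraGPTHydraHeadWithValueModel(GPTHydraHeadWithValueModel):
    def __init__(self, config, delta_kwargs=None):
        super().__init__(config)
        if delta_kwargs:
            delta_model = LoraModel(backbone_model=self.base_model, **delta_kwargs)
            delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
```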
Circling back around on this.
@cat-state does it make sense to do lora + hydra or just have lora be entirely separate...
We could have a function to modify the base model of each different model type rather than creating subclasses.
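e.g. something like this (the map and function name are illustrative, not the final API):

```python
from opendelta import AdapterModel, LoraModel

# Illustrative modifier map keyed by delta type; the actual mapping in
# the PR may differ.
DELTA_MODIFIERS = {
    "lora": LoraModel,
    "adapter": AdapterModel,
}

def apply_delta(base_model, delta_type="lora", **delta_kwargs):
    """Attach a parameter-efficient delta to `base_model` in place and
    freeze everything except the new delta parameters."""
    delta_cls = DELTA_MODIFIERS[delta_type]
    delta_model = delta_cls(backbone_model=base_model, **delta_kwargs)
    delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
    return base_model
```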
Hm... I defer to better software engineers haha. @jon-tow your input would be great here too.
@ethankim00 This looks great! I've made a few changes based on some testing on our cluster. Here's the summary:
- Updates the `gpt_neo` model type name in the modifier map.
- Fixes the layer regex pattern, as the previous one could not capture ranges with multiple digits, e.g. adapting LoRA to a model's 8 through 11 block layers failed since `[8-11]` is an invalid regex character range (see the sketch after this list).
- Moves the delta model modifications to the base trainer to avoid unnecessary duplication.
- Changes `_opendelta_available` to `HAS_OPENDELTA` for consistency with other modules (see `HAS_BNB`).
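For reference on the regex fix, roughly this (the helper name is hypothetical):

```python
import re

def layer_range_pattern(start, end):
    # `[8-11]` cannot express a multi-digit range (it parses as the
    # character range 8-1, which is invalid), so expand the range into
    # an explicit alternation instead.
    return "(?:" + "|".join(str(i) for i in range(start, end + 1)) + ")"

# Matches `transformer.h.8.` through `transformer.h.11.` but not `.12.`
pattern = re.compile(rf"transformer\.h\.{layer_range_pattern(8, 11)}\.")
assert pattern.match("transformer.h.10.attn")
assert not pattern.match("transformer.h.12.attn")
```

And the availability flag is the usual optional-dependency guard, same shape as `HAS_BNB` (the exact guard in the code may differ):

```python
try:
    import opendelta  # noqa: F401
    HAS_OPENDELTA = True
except ImportError:
    HAS_OPENDELTA = False
```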
Overall things look very promising. Check out these runs from the PPO sentiments task here. I'm going to begin testing on ILQL and once that's cleared up, we can get this ready for review and a merge.
This looks fantastic! Definitely worth including for the 0.4 release next week. Let's get this merged :)