mlx-examples
Creating a module version of lora.py (for referencing the functions in other scripts)
Just adding a module version of lora.py, so that you can import the training, testing, and generation functions into another Python script. It was quite useful for me: I can fully automate the process, test in notebooks, and add my own code around it. It required minimal changes, so I thought I would see if this is useful. Please let me know and/or give me feedback; I'd be happy to help.
Main changes made:
- Taking the `if __name__ == "__main__"` block and distributing it across the applicable functions.
- Removing the command-line `--args` parsing and surrounding logic, and passing those values as function arguments instead.
- Creating some separate functions for logic that is needed in more than one place, such as `prepare_for_training()`.
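The first two changes follow a common Python pattern: the work lives in ordinary functions with explicit parameters, and a thin `main()` only parses command-line flags and forwards them. A minimal sketch, assuming a hypothetical `run_training` function and flag names loosely modeled on the training arguments in this PR; it is not the actual lora.py code.

```python
import argparse
from typing import Optional, Sequence


def run_training(
    model_path: str,
    data_path: str = "data/",
    iters: int = 1000,
    batch_size: int = 4,
    learning_rate: float = 1e-5,
) -> dict:
    """Hypothetical training entry point that takes plain arguments.

    Because nothing here reads sys.argv, notebooks and other scripts can
    call it directly. The real work (load, train, save) is elided; the
    sketch just returns the configuration it received.
    """
    return {
        "model_path": model_path,
        "data_path": data_path,
        "iters": iters,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
    }


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="LoRA fine-tuning (sketch)")
    parser.add_argument("--model", required=True)
    parser.add_argument("--data", default="data/")
    parser.add_argument("--iters", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=1e-5)
    return parser


def main(argv: Optional[Sequence[str]] = None) -> dict:
    """CLI layer: translate flags into function arguments and nothing else."""
    args = build_parser().parse_args(argv)
    return run_training(
        model_path=args.model,
        data_path=args.data,
        iters=args.iters,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
    )


# In the script itself, the __main__ guard shrinks to:
#     if __name__ == "__main__":
#         main()
```

From a notebook, `run_training(model_path=..., iters=20)` is called directly, while the command line goes through `main()`; both end up in the same function.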
General notes:
- I named the file `lora_module.py`, but if you have a better name in mind, please feel free to change it. I just didn't know what else to call it.
- Semantics here, but should it be `run_lora_finetuning()` or `run_lora_fine_tuning()`?
Example usage:
```python
from lora_module import run_lora_finetuning, run_lora_generate, run_lora_test
# Training args
model_path = './models/Mistral-7B-Instruct-v0.2'
data_path = './data/'
lora_layers = 16
batch_size = 4
iters = 20
seed = 0
resume_adapter_file = None
adapter_file = "adapters.npz"
learning_rate = 1e-5
val_batches = 25
steps_per_report = 10
steps_per_eval = 200
# Fine tuning
run_lora_finetuning(model_path, data_path, lora_layers, batch_size, iters, seed, resume_adapter_file, adapter_file, learning_rate, val_batches, steps_per_report, steps_per_eval)
# Test args
model_path = './models/Mistral-7B-Instruct-v0.2'
data_path = './data/'
adapter_file = "adapters.npz"
test_batches = 500
batch_size = 4
# Testing
run_lora_test(model_path, data_path, adapter_file, test_batches, batch_size)
# Generate args
model_path = './models/Mistral-7B-Instruct-v0.2'
num_tokens = 100
temp = 0.8
adapter_file = "adapters.npz"
prompt = """This is a test prompt for generation, what do you think? """
# Generating
generated_text = run_lora_generate(model_path, num_tokens, temp, adapter_file, prompt)
```
This is cool! We should have something like that. What do you think about integrating it with the main lora.py instead as I see there is a lot of code duplication?
I think with just a few changes to that script you could import the module and have a few functions like `lora.train`, `lora.generate`, and `lora.test`. The other thing I would add is that it's worth decoupling the model loading from the train, test, and generate functions. That makes it easier to avoid loading big models multiple times.
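Decoupling loading from the task functions means the expensive load happens once and the resulting object is passed in. A minimal sketch, assuming hypothetical `load_model`, `train`, `evaluate`, and `generate` functions (not the ones in lora.py); a counter stands in for the cost of reading multi-gigabyte weights.

```python
from dataclasses import dataclass, field

LOAD_CALLS = 0  # Counts how often the "expensive" load step runs.


@dataclass
class LoadedModel:
    """Hypothetical container for what the load step produces."""
    path: str
    weights: dict = field(default_factory=dict)


def load_model(path: str) -> LoadedModel:
    global LOAD_CALLS
    LOAD_CALLS += 1
    # A real loader would read weights and the tokenizer from disk here.
    return LoadedModel(path=path, weights={"w": 1.0})


# Each task receives an already-loaded model rather than a path,
# so none of them can trigger a reload.
def train(model: LoadedModel, iters: int) -> None:
    model.weights["w"] += 0.1 * iters


def evaluate(model: LoadedModel) -> float:
    return model.weights["w"]


def generate(model: LoadedModel, prompt: str) -> str:
    return f"{prompt} (from {model.path})"


model = load_model("./models/example")  # load once...
train(model, iters=2)                    # ...then reuse for every task
score = evaluate(model)
text = generate(model, "Hello")
```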
Also you have a bug in the prepare_for_training. It should be loading the LoRA model even in test mode.
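The fix amounts to inserting the LoRA layers unconditionally and gating only the training-specific steps on the mode. A minimal sketch, assuming a hypothetical `prepare_model` helper whose steps are recorded as strings so the ordering is easy to check; it is not the PR's actual `prepare_for_training()`.

```python
from typing import List, Optional


def prepare_model(train: bool, adapter_file: Optional[str] = None) -> List[str]:
    """Return the ordered setup steps for train or test mode (hypothetical)."""
    steps = ["load_base_model", "freeze_base_weights"]
    # The fix: insert the LoRA layers in BOTH modes. Otherwise the saved
    # adapter weights have no LoRA layers to load into at test time.
    steps.append("apply_lora_layers")
    if adapter_file is not None:
        steps.append(f"load_adapter:{adapter_file}")
    if train:
        # Only training needs optimizer setup.
        steps.append("init_optimizer")
    return steps
```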
Re: integrating with main lora.py, I agree! I had thought about this but didn't want to step on any toes so to speak. Will get on this.
I'll also decouple the model loading; agreed that it's probably not necessary in every function.
Re: bug, thanks! I'll fix that as well, then push the changes and ping you.
I'll probably also wait for the pending QLoRA pull request, so I can pull it in and make all the modifications in one fell swoop. Does that sound OK?
Thanks! Also you should rebase on my updates in #219. I will merge them today!
Just pushed the (proposed) final changes. @awni take a look and let me know what you think.
Some miscellaneous notes:
- Decoupled model loading from the `train`/`evaluate`/`generate` functions, so we're not loading the model more than we have to.
- Have we considered loading data with Hugging Face Datasets instead of rolling our own?
- I've noticed that some parts of the mlx-examples repo support `weights.npz` only, others support `safetensors` only, and others support both. Given that the Hugging Face MLX community space has a mix of both, I'll create an issue after this is done to look into it further.
- Some things may change on this end once I connect with @chimezie regarding #235. At a quick glance, there shouldn't be much, if anything, to change in this specific PR, but I need to look further into it.
- I'm also considering having the `generate` function return the generated response, so that it can be stored in a variable.
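For the `weights.npz` versus `safetensors` split, one option is to decide by what the model directory actually contains. A minimal sketch, assuming a hypothetical `choose_weight_files` helper and a preference for safetensors over `weights.npz` when both are present; reading the chosen files is left to whatever loader the script already uses.

```python
from typing import Iterable, List


def choose_weight_files(filenames: Iterable[str]) -> List[str]:
    """Pick which weight files to load from a model directory listing.

    Hypothetical helper: prefers *.safetensors (possibly sharded, as in many
    Hugging Face repos) and falls back to a single weights.npz.
    """
    names = list(filenames)
    safetensors = sorted(n for n in names if n.endswith(".safetensors"))
    if safetensors:
        return safetensors
    if "weights.npz" in names:
        return ["weights.npz"]
    raise FileNotFoundError("expected *.safetensors or weights.npz")
```

In a script it would be called as `choose_weight_files(os.listdir(model_path))`; taking a plain list of names keeps the selection rule separate from filesystem access.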
Quick example usage:
```python
from lora import load_pretrained_model, load_datasets, train, evaluate, generate
# Model args
model_path = './models/Mistral-7B-Instruct-v0.2-4bit-mlx'
lora_layers = 16
seed = 0
resume_adapter_file = None
# Training args (passed to train())
model_kwargs = {
    'batch_size': 4,
    'iters': 20,
    'adapter_file': "adapters.npz",
    'learning_rate': 1e-5,
    'val_batches': 25,
    'steps_per_report': 10,
    'steps_per_eval': 200,
}
data_path = './data/'
model, tokenizer, config = load_pretrained_model(model_path, lora_layers, seed, resume_adapter_file)
train_set, valid_set, test_set = load_datasets(data_path)
train(model, train_set, valid_set, tokenizer, **model_kwargs)
# Testing
# Test args
model_path = './models/Mistral-7B-Instruct-v0.2-4bit-mlx'
data_path = './data/'
adapter_file = "adapters.npz"
test_batches = 100
batch_size = 4
test_loss = evaluate(model, test_set, tokenizer, batch_size, test_batches)
print(f"Test loss: {test_loss}")
# Generate
# Generate args
model_path = './models/Mistral-7B-Instruct-v0.2-4bit-mlx'
num_tokens = 100
temp = 0.8
adapter_file = "adapters.npz"
prompt = """input: What is the meaning of life?
output: """
generated_text = generate(model, prompt, tokenizer, num_tokens, temp, adapter_file)
```
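Returning the text, as floated in the notes above, mostly means accumulating the decoded pieces instead of only printing them. A minimal sketch, assuming a hypothetical `generate_text` function where a plain iterable of strings stands in for the model's sampling loop; it is not the signature of the `generate` function in this PR.

```python
import itertools
from typing import Callable, Iterable, List, Optional


def generate_text(
    token_stream: Iterable[str],
    max_tokens: int,
    on_token: Optional[Callable[[str], None]] = None,
) -> str:
    """Collect up to max_tokens decoded pieces and return them as one string.

    token_stream stands in for the model's sampling loop (hypothetical).
    on_token lets callers still stream output as it is produced.
    """
    pieces: List[str] = []
    for piece in itertools.islice(token_stream, max_tokens):
        if on_token is not None:
            on_token(piece)
        pieces.append(piece)
    return "".join(pieces)


streamed: List[str] = []
generated_text = generate_text(
    ["The", " meaning", " of", " life"], max_tokens=3, on_token=streamed.append
)
```

Keeping the callback optional means the same function serves both the interactive case (stream to the console) and the scripted case (store the result in a variable).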
@ProjectProgramAMark I'm currently working on moving the LoRA example into mlx-lm to remove the duplicated code between lora and mlx-lm. Also, since the current LoRA example mostly applies to LLMs, it makes more sense for it to live in mlx-lm. I have a similar idea for modularizing LoRA; however, I think we may need to break the LoRA module down into smaller components. For instance, the Python API might look like this:
```python
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten
from mlx_lm.lora.dataset import dataset_loader
from mlx_lm.lora.linear import LoRALinear
from mlx_lm.lora.trainer import LoraTrainer, TrainingArguments
from mlx_lm.utils import generate, get_model_path, load_model
from transformers import AutoTokenizer
model_path = get_model_path("./mlx_model")
# Loading models in training will be handled by the mlx-lm function.
model = load_model(model_path, load_train=True)
# Leave the load tokenizer to the user in case they want to configure it with specific token configurations.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Training arguments from the Lora module for fine-tuning Lora parameters.
args = TrainingArguments(iters=750, lora_layers=32, batch_size=4)
# Leave the dataset load outside of the main Lora module. We will provide a default dataset load,
# but ideally it should be decoupled from the Lora module.
train_dataset, val_dataset, test_dataset = dataset_loader(
    "data", load_train=True, load_test=False
)
# Currently, we have to manually apply the Lora layer.
# Therefore, we leave it outside of the main lora trainer so that users can flexibly apply the lora layers based on their needs.
# However, I feel that we may need to provide a way for users to specify a list of linear layer names and apply those Lora linear layers during trainer initialization.
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
    l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj, rank=64, scale=6)
    l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj, rank=64, scale=6)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")
# The user can choose any optimizer they want to use and pass it to the trainer.
learning_rate = 1e-6
opt = optim.AdamW(learning_rate=learning_rate)
trainer = LoraTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    optimizer=opt,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)
trainer.train()
# The trainer will be responsible for saving adapter weights.
# Currently, we don't have a place to save the information about which adapter layers have been trained for the weight.
# So, during the process of saving the adapter, we may need to save both the weight and some information about which adapter layers were used in the lora fine-tune.
trainer.save_adapter("adapter")
model.load_adapter("adapter")
# Generate would simply use mlx-lm generate to avoid duplication.
generate(
    model,
    tokenizer,
    prompt="""table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: """,
)
```
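Two comments in this proposal point at the same gap: letting users name which linear layers get LoRA, and saving that choice next to the adapter weights. A minimal sketch of both, assuming hypothetical `apply_lora_by_name` and `adapter_config_json` helpers; the `wrap` callable is a parameter, so with the proposal's API it could be `lambda lin: LoRALinear.from_linear(lin, rank=64, scale=6)`, while the demo uses plain objects so it runs without MLX.

```python
import json
from types import SimpleNamespace
from typing import Any, Callable, List, Sequence


def apply_lora_by_name(
    layers: Sequence[Any],
    num_layers: int,
    target_names: Sequence[str],
    wrap: Callable[[Any], Any],
) -> List[str]:
    """Wrap the named sub-modules in the last num_layers blocks.

    target_names are dotted paths relative to each block, such as
    "self_attn.q_proj". Returns the full paths that were wrapped.
    """
    wrapped = []
    start = max(0, len(layers) - num_layers)
    for idx in range(start, len(layers)):
        for name in target_names:
            *parents, leaf = name.split(".")
            owner = layers[idx]
            for attr in parents:  # walk down to the attribute's parent
                owner = getattr(owner, attr)
            setattr(owner, leaf, wrap(getattr(owner, leaf)))
            wrapped.append(f"layers.{idx}.{name}")
    return wrapped


def adapter_config_json(wrapped: List[str], rank: int, scale: float) -> str:
    """Serialize what is needed to rebuild the same adapter layers later."""
    return json.dumps(
        {"rank": rank, "scale": scale, "wrapped_layers": wrapped}, indent=2
    )


# Toy demo: plain namespaces stand in for transformer blocks, and a tuple
# stands in for a LoRA-wrapped linear layer.
blocks = [
    SimpleNamespace(self_attn=SimpleNamespace(q_proj=f"q{i}", v_proj=f"v{i}"))
    for i in range(4)
]
wrapped = apply_lora_by_name(
    blocks, 2, ["self_attn.q_proj", "self_attn.v_proj"], wrap=lambda lin: ("lora", lin)
)
config_json = adapter_config_json(wrapped, rank=64, scale=6.0)
```

Recording the wrapped paths, rather than only a layer count, would let a loader rebuild the exact layer set even if a later run targets different projections.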
I think this PR is a good starting point for considering how to properly structure the LoRA module. Additionally, we may be able to use the components of the mlx-lm LoRA module as building blocks in the LoRA example, and demonstrate how to write a LoRA script that fine-tunes with different settings.
@awni Let me know your thoughts. In the meantime, I will create a WIP PR just for demonstration purposes. We can also make adjustments to the implementation if we find a better approach.
@mzbac sounds good to me! I like this a lot and agree with it; I think now that the goal is to move this functionality into a package, we should definitely decouple a lot of this. That way the user can have much more control, and use it as a bona fide library rather than just example code to take and modify. I appreciate the work and templates you've done so far. I'm taking a look at the WIP PR you've made thus far, and I'll try and see how to help out. Please feel free to let me know if you need me to tackle specific things, that way you and I can make sure we're working together efficiently and not accidentally working on the same thing or using different standards in implementing it.
@ProjectProgramAMark thanks for the contribution and for leading the charge on packaging up LoRA.
Since we decided to merge LoRA and mlx-lm, I think a good goal for this LoRA example is to keep it as a really simple reference for people who want to hack on it or learn how it works (rather than use it as a package). To that end, it probably makes sense to close this PR in favor of #337. OK with you?
@awni sorry, just saw this. That sounds good to me, I would agree. I'd like to still contribute in some way, so if you've got backlogged work needed please reach out and let me know!