
Missing support for LoRA chat, tools, and completions data formats

Open ronaldmannak opened this issue 10 months ago • 7 comments

The loadLoRAData(url:) method currently only seems to support the text data format, e.g. {"text": "This is an example for the model."}. See here.

I was planning to add support for all data formats MLX supports (besides text: chat, tools, and completions)

Before I proceed with implementing loading of the missing formats, I would like to confirm a couple of points:

  1. Is the lack of support for these formats simply due to an oversight in the existing code to load different data formats since the example only uses text?
  2. Alternatively, is there a limitation with LoRATrain.train(model:, train:, validate: ...) that restricts it to handling only text data formats?

Edit: I am referring to the data formats as described in mlx-examples
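For reference, each of the mlx-examples data formats is one JSON object per line (JSONL). A sketch of Codable types that could model them; the field names for chat and completions follow the mlx-examples docs, while the Swift type names are illustrative, not from the repo:

```swift
import Foundation

// {"text": "This is an example for the model."}
struct TextLine: Codable {
    let text: String
}

// {"messages": [{"role": "user", "content": "..."}, ...]}
struct ChatMessage: Codable {
    let role: String    // "system" | "user" | "assistant"
    let content: String
}

struct ChatLine: Codable {
    let messages: [ChatMessage]
}

// {"prompt": "...", "completion": "..."}
struct CompletionLine: Codable {
    let prompt: String
    let completion: String
}
```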

ronaldmannak avatar Jan 30 '25 02:01 ronaldmannak

I was planning to add support for all data formats MLX supports (besides text: chat, tools, and completions)

Sounds great!

Before I proceed with implementing loading of the missing formats, I would like to confirm a couple of points:

  1. Is the lack of support for these formats simply due to an oversight in the existing code to load different data formats since the example only uses text?

The code it was ported from only did text:

  • https://github.com/ml-explore/mlx-examples/tree/main/lora

so just a limitation in the examples I had when porting.

  1. Alternatively, is there a limitation with LoRATrain.train(model:, train:, validate: ...) that restricts it to handling only text data formats?

There might be limitations around that but I think it would be just not expecting different types of data rather than a limitation of MLX itself -- that should work fine.

One thing to consider is where the LoraTrain.swift file should live -- currently it is in MLXLLM, but it might make sense to move it to MLXLMCommon.

davidkoski avatar Jan 30 '25 02:01 davidkoski

Thanks @davidkoski I'll start adding the other file formats and I'll see if I run into any issues with LoRATrain along the way.

Re: moving LoraTrain.swift to MLXLMCommon, I can definitely do that, but do you think that will break projects that use LoRA?

ronaldmannak avatar Jan 30 '25 03:01 ronaldmannak

What data format is LoRATrain.train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:) expecting in the train and validate parameters?

From this line in LoadJSON(url:) I understand that in the case of a text data structure, only the text values are passed, instead of a stringified JSON object. I guess that makes sense in the case of text since the JSON object only stores a single value (text).

However, the other three data structures (chat, tools, completions as described here) store multiple properties.

What kind of encoding is LoRATrain.train(model:...) expecting?

return try String(contentsOf: url)
    .components(separatedBy: .newlines)
    .filter {
        $0.first == "{"
    }
    .compactMap {
        try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
    }

ronaldmannak avatar Jan 31 '25 00:01 ronaldmannak

Thanks @davidkoski I'll start adding the other file formats and I'll see if I run into any issues with LoRATrain along the way.

Re: moving LoraTrain.swift to MLXLMCommon, I can definitely do that, but do you think that will break projects that use LoRA?

It may, but it will be very minor -- just import MLXLMCommon (which they may already be doing). We can document this.

davidkoski avatar Jan 31 '25 01:01 davidkoski

What data format is LoRATrain.train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:) expecting in the train and validate parameters?

From this line in LoadJSON(url:) I understand that in the case of a text data structure, only the text values are passed, instead of a stringified JSON object. I guess that makes sense in the case of text since the JSON object only stores a single value (text).

However, the other three data structures (chat, tools, completions as described here) store multiple properties.

What kind of encoding is LoRATrain.train(model:...) expecting?

return try String(contentsOf: url)
    .components(separatedBy: .newlines)
    .filter {
        $0.first == "{"
    }
    .compactMap {
        try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
    }

The format is JSONL -- a file containing many JSON blocks. I am not aware of any native swift parser for it, so this is what we have. It may not stand up to a more complicated structure.

Perhaps the way to think about it is what data structure should the LORA training call take? It might be something like:

protocol LORAInput {
    func lmInput() async throws -> LMInput
    func target() async throws -> String
}

and the LoRA training would take an array of these. This would encapsulate anything from simple text to something that had images or video. The training loop already has to prepare the prompt, so this would cut out a few layers there.

Anyway, this could then be independent of the file format.
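To make the protocol idea concrete, here is a sketch of a conforming type. LMInput is replaced by a stand-in struct so the snippet compiles on its own (the real type lives in MLXLMCommon); ChatSample and its fields are hypothetical, not part of the repo:

```swift
import Foundation

// Stand-in for MLXLMCommon's LMInput, which carries tokenized model input.
struct LMInput {
    let tokens: [Int]
}

// The proposed protocol: one training sample, independent of file format.
protocol LORAInput {
    func lmInput() async throws -> LMInput
    func target() async throws -> String
}

// Hypothetical conformance for a chat-style sample: the user turn becomes
// the model input and the assistant turn becomes the training target.
struct ChatSample: LORAInput {
    let userMessage: String
    let assistantMessage: String

    func lmInput() async throws -> LMInput {
        // Real code would run the model's tokenizer here; Unicode scalar
        // values stand in to keep the sketch self-contained.
        LMInput(tokens: userMessage.unicodeScalars.map { Int($0.value) })
    }

    func target() async throws -> String {
        assistantMessage
    }
}
```

The training loop would then take `[any LORAInput]` and never touch the on-disk representation, whether the samples came from a text, chat, tools, or completions file.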

davidkoski avatar Jan 31 '25 01:01 davidkoski

@davidkoski Sorry, the JSONL part I understand and actually already have working. I reused your approach of splitting the file on newlines and then decoding a single JSON object per line. That works for the training data I've seen so far.

I've just created a draft pull request for the update I've made so far.

From your comment I understand we don't really have trainers for the different data formats, and we'll need to create separate trainers for each data format, is that correct?

ronaldmannak avatar Jan 31 '25 03:01 ronaldmannak

From your comment I understand we don't really have trainers for the different data formats, and we'll need to create separate trainers for each data format, is that correct?

We don't have something that will take inputs other than straight text and run the inference pass. That is really just a matter of which API is used and the age of the code (VLMs weren't a thing, at least in mlx-swift, when this was built). I think it is just a matter of giving it the right input (LMInput) and everything should work.

But feel free to contribute whatever you can -- we can build this up in pieces!

davidkoski avatar Jan 31 '25 03:01 davidkoski