mlx-swift-examples

How to handle application becoming inactive?

Open MilanNosal opened this issue 8 months ago • 17 comments

Right now, when I start inference in the foreground and then put the app in the background, the app crashes with libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Insufficient Permission (to submit GPU work from background) (00000006:kIOGPUCommandBufferCallbackErrorBackgroundExecutionNotPermitted). I think this is expected, since Metal is not allowed to submit GPU work from the background.

What is the recommended way to handle this? Is there a built-in way to pause or cancel the execution, or to have it throw a catchable Swift error so that I can handle it gracefully? Are there any examples I could start from?

MilanNosal, Mar 10 '25 15:03

Probably the simplest way is to make sure all the outstanding work on any active stream is done, which you can do with synchronize.

Then it should be safe to move the app to the background.

Presumably there is a callback you can hook into for when the app is moved to the background; see e.g. https://developer.apple.com/documentation/metal/preparing-your-metal-app-to-run-in-the-background?language=objc

awni, Mar 10 '25 15:03
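A minimal sketch of the hook awni describes, assuming a Foundation-only setting: an observer runs a synchronize closure when a resign-active notification is posted. In an iOS app the notification would be UIApplication.willResignActiveNotification and the closure would call Stream.gpu.synchronize(); here both are replaced by stand-ins (the hypothetical demoWillResignActive name and a counter) so the code runs without UIKit or MLX. Synchronizing only waits for work that has already been submitted; a generation loop that keeps running will submit more.

```swift
import Foundation

extension Notification.Name {
    /// Hypothetical stand-in for UIApplication.willResignActiveNotification.
    static let demoWillResignActive = Notification.Name("demoWillResignActive")
}

/// Hypothetical helper: runs `synchronize` whenever the given notification is posted.
final class ResignActiveSynchronizer {
    private var observer: NSObjectProtocol?

    init(name: Notification.Name, synchronize: @escaping () -> Void) {
        // queue: nil delivers the notification synchronously on the posting thread.
        observer = NotificationCenter.default.addObserver(
            forName: name, object: nil, queue: nil
        ) { _ in
            // In the app, this closure would call Stream.gpu.synchronize().
            synchronize()
        }
    }

    deinit {
        if let observer {
            NotificationCenter.default.removeObserver(observer)
        }
    }
}

// Usage: count synchronize calls instead of touching the GPU.
var syncCalls = 0
let synchronizer = ResignActiveSynchronizer(name: .demoWillResignActive) {
    syncCalls += 1 // in the app: Stream.gpu.synchronize()
}
NotificationCenter.default.post(name: .demoWillResignActive, object: nil)
```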

OK, I'll need a bit more context, if you don't mind. Right now I'm using MLXLLM with the following approach, which is based on the Swift example code (simplified a bit for brevity):

func generate(modelName: String, input: UserInput) async -> String {
    guard !running else {
        return "Already processing a request"
    }

    running = true
    cancelled = false
    output = ""
    do {
        let modelContainer = try await load(modelName: modelName) // 1
        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
        let result = try await modelContainer.perform { context in // 2
            let input = try await context.processor.prepare(input: input) // 3
            return try MLXLMCommon.generate( // 4
                input: input,
                parameters: generateParameters,
                context: context
            ) { tokens in
                // Read cancellation synchronously: a Task { @MainActor in ... } hop
                // would run later, so a local copy would still be false when checked below.
                // Task.isCancelled becomes true once the Task calling generate(modelName:input:)
                // is cancelled (assuming perform runs this closure in the caller's task).
                let cancelled = Task.isCancelled

                // update the output -- this will make the view show the text as it generates
                if tokens.count % displayEveryNTokens == 0 {
                    let text = context.tokenizer.decode(tokens: tokens)
                    Task { @MainActor in
                        self.output = text
                    }
                }

                if tokens.count >= maxTokens || cancelled {
                    return .stop
                } else {
                    return .more
                }
            }
        }

        if result.output != output {
            output = result.output
        }
    } catch {
        self.output = ""
        running = false
        return "Failed: \(error)"
    }

    let output = self.output
    running = false
    return output
}

As far as I can tell, the crash can happen in any of the calls marked // 1, // 2, // 3 and // 4. If it gets as far as // 4, I can just return .stop in the didGenerate callback, but the other calls don't seem to have an API for cancellation. I don't see the Stream type anywhere in those calls or the objects associated with them, so I'm not sure how I'm supposed to use it from here. Is there a way to do it with this API, or will I have to edit the code of the MLXLLM libraries?

MilanNosal, Mar 10 '25 16:03
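If cancellation is driven by a UI flag rather than by cancelling the surrounding Task, the token callback needs a value it can read synchronously from whatever thread runs generation. A minimal sketch, assuming a hypothetical lock-protected CancellationFlag and a toy generation loop; fakeGenerate and the Disposition enum are stand-ins for MLXLMCommon.generate and its .more / .stop results, not the library's API.

```swift
import Foundation

/// Hypothetical thread-safe flag: cancel() may be called from the main thread
/// while isCancelled is read from the thread running generation.
final class CancellationFlag: @unchecked Sendable {
    private let lock = NSLock()
    private var cancelled = false

    func cancel() {
        lock.lock()
        cancelled = true
        lock.unlock()
    }

    var isCancelled: Bool {
        lock.lock()
        defer { lock.unlock() }
        return cancelled
    }
}

/// Stand-in for the .more / .stop values returned by the token callback.
enum Disposition { case more, stop }

/// Toy generation loop: "samples" one token per step and asks the callback whether to continue.
func fakeGenerate(maxTokens: Int, didGenerate: ([Int]) -> Disposition) -> [Int] {
    var tokens: [Int] = []
    while tokens.count < maxTokens {
        tokens.append(tokens.count) // pretend to sample one token
        if didGenerate(tokens) == .stop { break }
    }
    return tokens
}

// Usage: the callback reads the flag synchronously on every token.
let flag = CancellationFlag()
let generated = fakeGenerate(maxTokens: 100) { tokens in
    if tokens.count == 3 { flag.cancel() } // simulate a cancel arriving mid-generation
    return flag.isCancelled ? .stop : .more
}
```

In the app, the view model would hold a single CancellationFlag, call cancel() from the main thread or the resign-active handler, and the didGenerate callback would return .stop once flag.isCancelled is true.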

First the easy part:

Stream.gpu.synchronize()

is the call to wait for GPU activity to be done.

davidkoski, Mar 10 '25 16:03

Call 1 could potentially observe task cancellation (see #227), but one of the calls inside it, eval(model), can take several seconds. Perhaps it could iterate over the parameters in batches and check for task cancellation between batches.

Call 2 doesn't do anything with the GPU -- it is just acquiring the synchronization context.

Calls 3 and 4 are the items discussed in #227.

None of these are directly addressable in the UIApplicationDelegate -- you would have to figure out how to either block this call:

  • https://developer.apple.com/documentation/uikit/uiapplicationdelegate/applicationwillresignactive(_:)

with e.g. a lock, or find some way to cancel the generation task (which in turn requires the individual pieces to be cancellable, see #227).

davidkoski, Mar 10 '25 16:03
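Cancelling work from a lifecycle callback requires holding on to the Task that runs it. A minimal sketch, assuming a hypothetical GenerationController whose stages stand in for the load, prepare and generate calls: it calls Task.checkCancellation() before each stage, and appWillResignActive() cancels the stored Task. Cancellation is cooperative, so a stage that is already blocked in a synchronous eval still runs to completion before the task notices.

```swift
import Foundation

/// Hypothetical controller that owns the generation Task so a lifecycle
/// callback can cancel it.
final class GenerationController: @unchecked Sendable {
    private let lock = NSLock()
    private var task: Task<String, Error>?

    /// Runs the stages in order, checking for cancellation before each one.
    /// The stages stand in for the calls marked // 1, // 3 and // 4.
    func start(stages: [@Sendable () async throws -> String]) -> Task<String, Error> {
        let newTask = Task.detached { () async throws -> String in
            var output = ""
            for stage in stages {
                try Task.checkCancellation() // throws CancellationError once cancelled
                output = try await stage()
            }
            return output
        }
        lock.lock()
        task = newTask
        lock.unlock()
        return newTask
    }

    /// Call from the resign-active handler; only requests cancellation,
    /// it does not interrupt a stage that is already running.
    func appWillResignActive() {
        lock.lock()
        let current = task
        lock.unlock()
        current?.cancel()
    }
}
```

In an iOS app, appWillResignActive() would be called from an observer of UIApplication.willResignActiveNotification; waiting for the cancelled task to finish before the app suspends is not shown here.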

AFAIK, blocking the app from resigning active is not possible, so the only way to avoid the crash is to cancel or pause any scheduled work as soon as I get the notification that the app is going to the background.

OK, so from what you're saying, it sounds like I cannot just use the MLXLLM library as is; I'll have to fork MLXLLM and MLXLMCommon and add cancellation support to them, right?

Thanks for your feedback!

MilanNosal, Mar 10 '25 17:03

Yes. I am not sure whether the requester for #227 intended to submit a PR, but as is, it doesn't support cancellation. The notes in that issue should help if you want to build it.

davidkoski, Mar 10 '25 17:03

Alright, I quickly put together a prototype in

https://github.com/MilanNosal/mlx-swift-examples/pull/1

Basically, I "sprinkled" Task.checkCancellation() calls through load, generate and prepare. I have to go to bed now, as it's late in my timezone, so I haven't had time to test it properly. If either of you has a bit of time to take a quick look and check that I'm not doing anything obviously wrong with regard to MLX, I would really appreciate it.

MilanNosal, Mar 10 '25 23:03

Hey @davidkoski! Can I ask a (hopefully) quick question?

but one of the calls, eval(model) can potentially take several seconds. Perhaps it could iterate over the parameters in batches and interleave calls to check task cancellation.

Referring to this method in mlx-swift:

/// Evaluate one or more `MLXArray`.
///
/// See ``eval(_:)``
public func eval(_ values: [Any]) {
    var arrays = [MLXArray]()
    for item in values {
        collect(item, into: &arrays)
    }
    eval(arrays)
}

I tried to "batch" it this way:

public func batchedEval(_ values: Any..., batchSize: Int = 5) throws {
    var arrays = [MLXArray]()
    for item in values {
        collect(item, into: &arrays)
    }
    for batch in arrays.chunked(into: batchSize) {
        try Task.checkCancellation()
        eval(batch)
    }
}

eval(batch) delegates the call to MLX:

public func eval(_ arrays: [MLXArray]) {
    let vector_array = new_mlx_vector_array(arrays)
    mlx_eval(vector_array)
    mlx_vector_array_free(vector_array)
}

It seems to work the same as the original version, but from the documentation for mlx.core.eval I am not sure whether splitting the array of MLXArrays and processing the pieces separately is semantically equivalent to processing them all together, as in the original code.

Could you please share some thoughts on whether this is correct?

Or did you mean to chunk the model directly, even before collecting its parameters into MLXArrays (so instead of evaluating batches of MLXArrays, collect and evaluate batches of Any model parameters)?

Thanks in advance!

MilanNosal, Mar 13 '25 19:03
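Both batchedEval above and the loadWeights variant in the reply call chunked(into:), which is not part of the Swift standard library (the swift-algorithms package offers chunks(ofCount:) instead), and the thread does not show its definition. A common way to write it, offered as an assumption rather than the helper the participants actually used:

```swift
extension Array {
    /// Splits the array into consecutive chunks of at most `size` elements;
    /// the last chunk holds the remainder.
    func chunked(into size: Int) -> [[Element]] {
        precondition(size > 0, "chunk size must be positive")
        return stride(from: 0, to: count, by: size).map { start in
            Array(self[start ..< Swift.min(start + size, count)])
        }
    }
}
```

Swift.min is spelled out because, inside an Array extension, a bare min would resolve to the array's own min methods.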

@MilanNosal I think what you did looks reasonable, though we would have to test it to make sure. eval on a single MLXArray will synchronously evaluate the graph that produces the backing of the MLXArray. This might involve hundreds or thousands of operations.

In the case of model loading:

public func loadWeights(
    modelDirectory: URL, model: LanguageModel, quantization: BaseConfiguration.Quantization? = nil
) throws {
...
    for case let url as URL in enumerator {
        if url.pathExtension == "safetensors" {
            let w = try loadArrays(url: url)
...
    // apply the loaded weights
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    eval(model)

We load the safetensors files and push them into the model, then evaluate the model. The MLXArray instances here are backed by lazily loaded data from the safetensors file. If you were to eval a single array, it should call the safetensors API and load the values from disk into memory.

So this code:

public func eval(_ values: Any...) {
    var arrays = [MLXArray]()

    for item in values {
        collect(item, into: &arrays)
    }

    eval(arrays)
}

at the point where it is about to call eval(arrays), has:

(lldb) p arrays.count
(Int) 1364

for the model I am loading. The batched eval should run these in chunks and check for task cancellation in between.

or did you mean to suggest to chunk the model directly even before collecting them into MLXArrays? (so instead of evaluating batches of MLXArrays collect and evaluate batches of Any model parameters?

It might be easier to chunk them at the loadWeights() level rather than writing a custom eval with the recursive traversal. Inside loadWeights you could do this instead:

    // apply the loaded weights
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    // evaluate a few parameters at a time, checking for cancellation between batches
    for batch in model.innerState().chunked(into: 5) {
        try Task.checkCancellation()
        eval(batch)
    }

(lldb) p model.innerState().count
(Int) 1364

it seems to be working the same as with the original version, but from the documentation for mlx.core.eval I am not sure if splitting the array of MLXArray and processing them separately is semantically equivalent to processing them all together as in original code...

Yes, the semantics are the same: the requirement of eval() is that, when it returns, every MLXArray it collected can be read without waiting.

davidkoski, Mar 13 '25 19:03
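A toy model of the point that chunked evaluation reaches the same end state as one big eval: every value is materialized either way, and the batched version gets a cancellation check before each batch. LazyValue, evalAll and evalInBatches are hypothetical and have no MLX dependency. With the 1364 arrays from the lldb session and a batch size of 5, that is 273 batches (272 full batches plus one of 4), so 273 cancellation checks.

```swift
/// Hypothetical stand-in for a lazily loaded MLXArray: `value` stays nil until
/// materialize() runs, much like weights backed by a safetensors file.
final class LazyValue {
    private let load: () -> Int
    private(set) var value: Int?

    init(_ load: @escaping () -> Int) { self.load = load }

    func materialize() {
        if value == nil { value = load() }
    }
}

/// Evaluates everything in one call, the way eval(model) does.
func evalAll(_ values: [LazyValue]) {
    values.forEach { $0.materialize() }
}

/// Evaluates in batches, calling `isCancelled` before each batch.
/// Returns false if it stopped early; values in skipped batches stay lazy.
func evalInBatches(_ values: [LazyValue], batchSize: Int, isCancelled: () -> Bool) -> Bool {
    precondition(batchSize > 0)
    for start in stride(from: 0, to: values.count, by: batchSize) {
        if isCancelled() { return false }
        values[start ..< min(start + batchSize, values.count)].forEach { $0.materialize() }
    }
    return true
}
```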

Thanks for the long explanation! I'll give it a bit more thought and tests over the weekend.

MilanNosal, Mar 14 '25 15:03

OK, a follow-up question:

Before generating from the input, I'm calling:

let input = try await context.processor.prepare(input: input)

where the call stack is prepare -> step -> convertToToken -> processor?.didSample(token: y),

which for RepetitionContext is the following:

mutating public func didSample(token: MLXArray) {
    if tokens.count >= repetitionContextSize {
        tokens[index] = token.item(Int.self)
        index = (index + 1) % repetitionContextSize
    } else {
        tokens.append(token.item(Int.self))
    }
}

The problem is the call to token.item(Int.self), which is part of mlx-swift and takes around 1.5 s for me. When I put the app in the background right after this call starts, that time span exceeds the grace period iOS gives, so it crashes.

Any smart suggestions on how to work around that? The only option I see right now is to fork mlx-swift as well, but I am not sure that would help: just by looking at the code of the token.item(Int.self) method, I have no idea how to "chunk" it.

MilanNosal, Mar 18 '25 21:03

I am not sure either; that may be the time-to-first-token cost. It requires evaluating the entire graph (the model) for a token, and it may also require JIT-compiling some of the Metal shaders.

I am not sure these are good ideas, but they might work:

  • somewhere else in the process run some evals to "prime" the JIT cache, etc.
  • eval various pieces along the model's graph (basically eval MLXArrays in the middle of the graph) to get partial evals. This would have to be done synchronously, otherwise you still have to wait for all the work to complete, and a bunch of synchronous evals in here won't do anything good for performance

Once it is up and scanning tokens it should be generating them pretty quickly and it is less likely to have this problem.

I don't think the Python side has any such needs or considerations, so there may not be any clues there.

@awni any thoughts or suggestions?

davidkoski, Mar 18 '25 23:03

Hm, OK. What makes it worse is that the 1.5 s is for that particular prompt; the longer the prompt, the longer it takes, which is unfortunate, as I need it to summarize longer texts.

I've seen LLMFarm core deal with this using an inelegant but effective approach: they wrap every call in a .mm exception catcher that catches the crashing C++ exception and instead propagates an ObjC exception that can be caught in Swift. Would something similar be possible in the MLX case?

MilanNosal, Mar 19 '25 14:03

It is possible via:

  • https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/seterrorhandler(_:data:dtor:)

You would probably set a global variable indicating that an error occurred and then check that.

If the prompt is long you could potentially evaluate that in chunks -- tune it to see what size chunk is reasonable.

See also https://github.com/ml-explore/mlx/pull/1970 -- something Awni is looking at that might let us interrupt an eval. This is more medium term as we would need to pick it up in mlx-c and mlx-swift.

davidkoski, Mar 19 '25 15:03
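A minimal sketch of the global-variable approach, assuming the handler passed to setErrorHandler(_:data:dtor:) is a C-convention function taking a message pointer and a user-data pointer (check the mlx-swift documentation for the exact type). Because a C function pointer cannot capture context, the handler records the message in a lock-protected global, and Swift code checks it after each step. ErrorSlot, recordingHandler, MLXFailure and throwIfMLXFailed are hypothetical names. The handler is invoked directly here so the code runs without MLX; whether the background-execution failure is actually routed through this handler is not established in the thread.

```swift
import Foundation

/// Lock-protected storage for the last error message reported by the handler.
final class ErrorSlot: @unchecked Sendable {
    private let lock = NSLock()
    private var message: String?

    func record(_ text: String) {
        lock.lock()
        message = text
        lock.unlock()
    }

    /// Returns the pending message, if any, and clears it.
    func take() -> String? {
        lock.lock()
        defer { lock.unlock() }
        let pending = message
        message = nil
        return pending
    }
}

let lastMLXError = ErrorSlot()

// ASSUMPTION: (message, user data) mirrors the handler type setErrorHandler expects.
// The closure only touches a global, so it can be formed as a C function pointer.
let recordingHandler: @convention(c) (UnsafePointer<CChar>?, UnsafeMutableRawPointer?) -> Void = { msg, _ in
    lastMLXError.record(msg.map { String(cString: $0) } ?? "unknown MLX error")
}

struct MLXFailure: Error {
    let message: String
}

/// Call after each eval or generation step to turn a recorded error into a Swift error.
func throwIfMLXFailed() throws {
    if let message = lastMLXError.take() {
        throw MLXFailure(message: message)
    }
}
```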

Cool, thanks.

Minor update: setting a smaller prefillStepSize on the parameters stopped the crashes. I guess that gives it more chances to be cancelled between steps?

MilanNosal, Mar 19 '25 16:03

That controls the size of the chunks in which it feeds the prompt. If a smaller prefill size does the trick, I think that is a great way to handle it (and it is the same as what I suggested above).

davidkoski, Mar 19 '25 17:03
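A minimal sketch of why a smaller prefillStepSize helps, assuming prefill feeds the prompt to the model in chunks of that many tokens, as described above: the longest stretch of work that cannot be interrupted becomes one chunk rather than the whole prompt. The prefill function and its closures are hypothetical stand-ins; in MLX, processChunk would run the model on the chunk and evaluate the result before the next chunk starts. The cancellation check between chunks is added by the sketch; whether the library's own prefill loop has one is not covered in the thread.

```swift
/// Hypothetical chunked prefill: feeds `promptTokens` in chunks of
/// `prefillStepSize` and checks `isCancelled` before each chunk.
/// Returns false if it stopped early.
func prefill(
    promptTokens: [Int],
    prefillStepSize: Int,
    isCancelled: () -> Bool,
    processChunk: (ArraySlice<Int>) -> Void
) -> Bool {
    precondition(prefillStepSize > 0)
    var start = 0
    while start < promptTokens.count {
        if isCancelled() { return false }
        let end = min(start + prefillStepSize, promptTokens.count)
        // In MLX: run the model on this chunk and evaluate it before continuing.
        processChunk(promptTokens[start ..< end])
        start = end
    }
    return true
}
```

For a 1,000-token prompt, a step size of 256 yields chunks of 256, 256, 256 and 232 tokens.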

Just FYI we are prototyping an interruptable eval: https://github.com/ml-explore/mlx/pull/1970

It's probably better not to add it if it's not necessary, so if a smaller chunk size is sufficient, that would be great. Keep us posted.

awni, Mar 19 '25 20:03