[V1] [P/D] Add Support for KV Load Failure Recovery

Open sdavidbd opened this issue 5 months ago • 5 comments

🎯 Purpose

This PR implements the suggested design from RFC #19329.

It introduces a mechanism to recover from KV load failures during inference in vLLM’s KV connector v1 path by:

  • Detecting failed KV block loads,
  • Automatically rescheduling affected requests for recomputation from a valid prefix.

This feature improves the robustness of vLLM when using external systems for KV cache offload or transfer.
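
To make the mechanism concrete, here is a minimal sketch of the prefix-recovery idea (illustrative only; BLOCK_SIZE and all names are assumptions, not the PR's actual code): once the failed block IDs are known, the longest still-valid prefix determines how many tokens can be kept, and everything after it is rescheduled for recomputation.

BLOCK_SIZE = 16  # tokens per KV block (assumed for this sketch)

def valid_prefix_tokens(request_block_ids: list[int],
                        invalid_block_ids: set[int]) -> int:
    """Tokens covered by blocks before the first failed block."""
    for i, block_id in enumerate(request_block_ids):
        if block_id in invalid_block_ids:
            return i * BLOCK_SIZE  # recompute from here onward
    return len(request_block_ids) * BLOCK_SIZE

# Blocks 3 and 7 failed to load: the request keeps its first two blocks
# (32 tokens) and is rescheduled to recompute the rest.
print(valid_prefix_tokens([5, 9, 3, 7], {3, 7}))  # -> 32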


🧪 Test Plan

A self-contained test has been added under:

examples/offline_inference/kv_load_failure_recovery/

This test demonstrates the recovery mechanism using a fault-injecting connector that simulates failed KV loads. The test flow includes:

  • A prefill stage that saves KV data to the local filesystem (sketched after this list).
  • Two decode stages:
    1. Normal decode loading KV data.
    2. Simulated failure decode using a custom connector.
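
For context, the prefill stage of such an example might look roughly like the following (hedged sketch: the connector name, extra-config key, and model are assumptions for illustration; see the example directory for the real script).

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Prefill-only run that persists KV to local storage via a connector.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # assumed; any model works
    kv_transfer_config=KVTransferConfig(
        kv_connector="SharedStorageConnector",  # assumed connector
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ),
)

# max_tokens=1 makes this effectively prefill-only; the saved KV is then
# loaded (or deliberately fails to load) in the decode stages.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=1))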

✅ Test Results

To run the test:

cd ~/vllm/examples/offline_inference/kv_load_failure_recovery
./run.sh

The script:

  1. Clears local KV state.
  2. Executes the prefill stage.
  3. Runs a normal decode that loads KV data from disk.
  4. Runs a second decode with simulated KV load failure.

✅ Both decode runs produce identical outputs, confirming that requests referencing failed blocks were successfully rescheduled and recomputed.

sdavidbd avatar Jun 08 '25 14:06 sdavidbd

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Jun 08 '25 14:06 github-actions[bot]

@sdavidbd the problem I see here is in the multi-worker case, for example if you use tensor parallelism or pipeline parallelism. Each worker may fail loading on different block ids. Although each worker reports its own invalid_block_ids, the scheduler will only get the invalid_block_ids of the first-rank worker. This is due to the implementation of MultiprocExecutor.execute_model, which simply discards the ModelRunnerOutput of all but the first worker.

In theory, the same problem holds for the existing fields ModelRunnerOutput.finished_sending and ModelRunnerOutput.finished_recving. The current solution in the code for this is having the first-rank worker communicate with all other workers to aggregate input from all workers to its own output. This is what was done in NixlConnector.get_finished.
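
To make the reduction semantics concrete, a minimal sketch (illustrative names, not vLLM code): invalid block ids should be combined with a union across ranks, since a block is unusable if any rank failed to load its shard of it, whereas "finished" signals naturally reduce via intersection (a request is done only when every rank is done).

def aggregate_invalid_blocks(per_rank: list[set[int]]) -> set[int]:
    # A block must be recomputed if ANY rank failed to load it: union.
    return set().union(*per_rank)

def aggregate_finished_reqs(per_rank: list[set[str]]) -> set[str]:
    # A request is finished only once EVERY rank is done: intersection.
    return set.intersection(*per_rank)

print(aggregate_invalid_blocks([{4}, {9}]))             # -> {4, 9}
print(aggregate_finished_reqs([{"r1", "r2"}, {"r1"}]))  # -> {'r1'}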

However, I think it is better to have the aggregation done by the scheduler, not by the workers, and to have a single aggregation covering all fields of ModelRunnerOutput.

I'm actually planning on opening up a PR in the upcoming days that will add a generic KVConnectorMetadata field to ModelRunnerOutput. This field will be used to transfer ALL required connector information from workers to the scheduler. This includes finished_sending and finished_recving, and potentially your invalid_block_ids.

With this approach, the get_block_ids_with_load_errors will become a scheduler-side connector call and will extract the invalid_block_ids from the abstract KVConnectorMetadata.

My motivation for this is my plan to add an offloading connector, and be able to report offloading status from workers to the scheduler.

cc @njhill

orozery avatar Jun 10 '25 06:06 orozery

@orozery Thanks a lot for the detailed feedback — you're absolutely right. I'm aware of this gap in the multi-worker setup, and I'm already working on adding aggregation logic for invalid_block_ids at the MultiprocExecutor level to address it.

I also really like your idea of encapsulating KV connector–related metadata in a dedicated KVConnectorOutput (or similar — probably best to distinguish it from the existing KVConnectorMetadata) field within ModelRunnerOutput. It seems like a clean and extensible way to surface connector-specific signals to the scheduler.

Looking forward to your upcoming PR — happy to align with it or adapt my changes accordingly.

sdavidbd avatar Jun 10 '25 08:06 sdavidbd

Converting this PR to draft for now - I'm working on extending support for tensor-parallel (TP) setups and adding more unit tests to improve coverage and reliability. Will mark as ready for review once that's complete.

sdavidbd avatar Jun 11 '25 12:06 sdavidbd

@sdavidbd we need generic failure handling logic... would it make more sense to just have the returned list of request ids that have finished loading reflect a failed/succeeded status (or equivalently include a mutually exclusive list of failed load req ids)?

The scheduler already knows the associated block ids and so would be able to recompute accordingly. This could then cover other kinds of failure too.

njhill avatar Jun 20 '25 00:06 njhill

Thanks @njhill — appreciate the suggestion!

I actually think that returning the list of failed block IDs is already a fairly generic and extensible approach, for a few reasons:

  1. Connector visibility: The connector is inherently aware of which blocks failed to load, but it may not have enough context to map failures back to specific request IDs. We’d prefer not to require connectors to track per-request mappings, as this would increase implementation complexity and reduce flexibility.
  2. One-to-many dependency: Multiple requests in the same batch may share the same KV blocks. In the case of a block load failure, all dependent requests should be rescheduled. Reporting failed blocks allows the scheduler to identify the affected requests based on its existing request-to-blocks mapping (see the sketch after this list).
  3. Granularity of recovery: Treating a request as failed based on any load error may be too coarse. A request might attempt to load 100 blocks and fail only on the last one. In such cases, it’s more efficient to recompute only the failed blocks rather than discarding all previously loaded ones.
  4. Future enhancement of KV cache reuse: Currently, KV reuse in vLLM is limited to a prefix of the prompt. We're planning to enhance this to support reuse of arbitrary subsets of the prompt blocks. This is particularly important when offloading KV cache to external backends with their own eviction policies or when sporadic load failures occur. By allowing connectors to report the exact set of failed blocks, we lay the groundwork for fine-grained recovery in the future.
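
A minimal sketch of point 2 (hypothetical names; the scheduler's real bookkeeping differs): the scheduler's request-to-blocks mapping can be inverted to find every request touched by a failed block, including requests that merely share it.

def affected_requests(request_blocks: dict[str, list[int]],
                      invalid_block_ids: set[int]) -> set[str]:
    # Return the ids of all requests referencing any invalid block.
    return {
        req_id
        for req_id, blocks in request_blocks.items()
        if invalid_block_ids.intersection(blocks)
    }

# Two requests share block 7; a single failed load affects both.
mapping = {"req-a": [1, 2, 7], "req-b": [7, 8], "req-c": [3]}
print(affected_requests(mapping, {7}))  # -> {'req-a', 'req-b'}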

sdavidbd avatar Jun 22 '25 10:06 sdavidbd

Thanks @sdavidbd, and sorry for the delay getting back to this.

I guess I'm still not convinced that block ids would be "simpler" - since we already pass back ids of finished requests, just having these flagged as failed or succeeded seems simpler to me, and it would be straightforward on the scheduler side, for example, to just fall back to locally prefilling any failed requests.

Since errors would presumably be quite rare, the granularity optimization seems like it might be unnecessary.

One thing I have been wondering about related to this, though, is how we should handle cases where concurrent requests want to load overlapping prefixes. If there is an async kv load already in flight for a subset of the blocks, the request should ideally just subscribe / wait behind that (in addition to loading additional blocks in parallel, if needed). I don't think that's how it currently works.

njhill avatar Jul 03 '25 12:07 njhill

Thanks for the work @sdavidbd !

re: @orozery

I'm actually planning on opening up a PR in the upcoming days that will add a generic KVConnectorMetadata field to ModelRunnerOutput.

This is actually independent of having the reduction carried out in the scheduler, though. We could still have rank 0 be in charge of aggregating data from the other ranks, so that the actual logic remains where the data is produced. Perhaps I am just missing the advantage of scheduler-side reduction. But we definitely need to aggregate, as you pointed out.

having these (request_id) flagged as failed or succeeded seems simpler to me

I agree here; I think it would make the aggregation above simpler and would move less data around the ranks. I believe we can overload the single communication step where we already exchange finished_req_ids.

NickLucche avatar Jul 03 '25 15:07 NickLucche

Thanks @njhill - really appreciate the follow-up.

You're absolutely right that today errors are relatively rare. However, as KV cache offloading gains traction - especially with the rise of disaggregated storage solutions - we do expect failures to become more common. Failures could stem from transient disconnections or external eviction policies, and designing for them upfront helps ensure resilience.

Looking forward, KV cache offloading is expected to play a key role in enabling KV reuse across various scenarios: preempted requests, multi-turn conversations, long shared document queries, and more. In these cases, externally computed tokens are loaded as part of running prefill requests and can potentially be shared by multiple requests in the same batch.

Critically, the connector typically does not have full visibility into all requests that might share these blocks - specifically, requests beyond the first will treat the blocks as if they were computed locally rather than loaded externally. Reporting block-level failures allows the scheduler, which does maintain a global view of request-to-block mappings, to identify and reschedule only the affected requests efficiently.

Moreover, the future enhancement I mentioned - supporting reuse of arbitrary prompt blocks rather than only the prefix - is an important and feasible optimization we are actively working on. The block-level failure reporting design directly aligns with and enables this finer-grained reuse and recovery path.

Regarding your point about overlapping async KV loads: I agree - currently, overlapping blocks are duplicated locally instead of being shared. While this gap is certainly addressable, it also raises a broader question: Why do disaggregated decode requests need to fully wait for remotely computed KV blocks to finish loading before starting at all? Couldn't we schedule the requests immediately and load the KV blocks asynchronously, layer by layer, during the forward pass? This approach could resolve the duplication issue and allow us to overlap block loading with computation, further optimizing end-to-end latency.

Thanks again for the constructive discussion - happy to further align on this if you'd like!

sdavidbd avatar Jul 03 '25 15:07 sdavidbd

Thanks @sdavidbd. Maybe it's better to have the connector API be block-based? (see discussion in other PR https://github.com/vllm-project/vllm/pull/19555#discussion_r2187014032).

Couldn't we schedule the requests immediately and load the KV blocks asynchronously, layer by layer, during the forward pass?

This is already supported by the connector API, but there may be a trade-off, since we need to load at least the first layer before the forward pass starts, which will add latency to every other request in the batch. It might also slow down subsequent layers if the loading takes longer than the computation of the prior layer. There might also be greater overhead associated with multiple per-layer transfers. But it's something we could experiment with. We want to at least support async for the initial handshake.
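
As a toy sketch of that per-layer overlap (hedged: the hook names loosely follow the v1 connector API, but this is a stand-in, not a real connector), only the current layer's KV must be resident before its attention runs, while later layers keep loading in the background:

class StubConnector:
    def start_load_kv(self):
        print("kicked off async KV loads for all layers")

    def wait_for_layer_load(self, layer_name: str):
        print(f"{layer_name}: KV ready")

def forward_pass(layer_names, connector):
    connector.start_load_kv()
    for name in layer_names:
        connector.wait_for_layer_load(name)  # block only on this layer
        print(f"{name}: attention computes while later layers still load")

forward_pass(["layer0", "layer1", "layer2"], StubConnector())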

I guess we need to think through the sync and async loading cases with respect to block sharing. You're right that the sharing happens automatically/implicitly in the sync case right now, it would be good to have that work for the async case too.

njhill avatar Jul 05 '25 10:07 njhill

✅ Added unit tests.

This PR currently assumes that invalid-block-set aggregation for multi-worker setups is handled within the connector. I plan to align this with the aggregation logic introduced in PR #19555 (at the multi-executor level) in a follow-up PR. @njhill @orozery

Additionally, there is a known gap in PP setups where concurrent batches may share KV blocks. If a block is marked invalid in an earlier batch due to a load failure, a subsequent batch might incorrectly treat it as valid, potentially leading to incorrect outputs. I already have an approach in mind to resolve this and will address it in a follow-up PR.

sdavidbd avatar Jul 13 '25 15:07 sdavidbd

Thanks @sdavidbd ! I like the idea of this PR. This will also be very useful for LMCache's use cases.

ApostaC avatar Jul 14 '25 18:07 ApostaC

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sdavidbd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Jul 20 '25 15:07 mergify[bot]

This feature is critically important. Without a failure handling mechanism, any cache storage system is essentially unusable.

yoo-kumaneko avatar Jul 24 '25 06:07 yoo-kumaneko

  • Rebased the PR and resolved merge conflicts.
  • Refactored GpuModelRunner.execute_model to encapsulate the KV connector lifecycle using a context manager.
    This simplifies propagating additional data back from the KV connector and ensures clear_connector_metadata is properly called; a rough sketch of the shape follows below.
    (The refactor can be extracted to a separate PR if it needs to be merged independently.)
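
The shape of that refactor might look roughly like this (assumed/hypothetical interface; the real PR code may differ): the connector metadata is bound on entry and always cleared on exit, even if the forward pass raises.

from contextlib import contextmanager

@contextmanager
def kv_connector_lifecycle(connector, metadata):
    # Bind metadata for one forward pass and guarantee cleanup.
    connector.bind_connector_metadata(metadata)
    try:
        yield connector
    finally:
        # Runs even on error, so stale metadata never leaks into the
        # next scheduler step.
        connector.clear_connector_metadata()

# Usage inside a model runner (sketch):
#   with kv_connector_lifecycle(connector, meta) as kv:
#       output = run_forward_pass()
#       invalid = kv.get_block_ids_with_load_errors()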

Currently WIP in this PR:

  • Add support for async KV load failure recovery
  • Implement get_block_ids_with_load_errors API in MultiConnector

ETA: July 27

sdavidbd avatar Jul 25 '25 00:07 sdavidbd

Added failure recovery support for asynchronous KV load. Covered by the system test in examples/offline_inference/kv_load_failure_recovery. Additional unit tests will be added in a follow-up commit.

sdavidbd avatar Jul 28 '25 00:07 sdavidbd

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sdavidbd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Jul 28 '25 05:07 mergify[bot]

@sdavidbd sorry again for the delay getting back to this and thanks a lot for the great work and your patience.

To recap the high-level context: there are various overlapping discussions underway regarding the evolution of the KV connector API. It would be nice to reach agreement on a unified north-star design, but that's practically going to take some time.

The hesitation of making incremental changes/improvements to the interface in the meantime is that they might not align with the final unified (re-)design and so may then need to be changed again. But from a pragmatic pov there are incremental features needed immediately in support of production use of the existing integrations. So the best course of action i.m.o. is to proceed with such API changes knowing that it's possible they will be short-lived once the broader plan is solidified. And we should make it clear to connector implementors that the API should still not yet be considered stable.

For this reason, though, I am hesitant to make such changes more complicated than they need to be for immediate purposes. In particular, I don't feel the ability to deal with partial failures / fine-grained recovery and non-contiguous block reuse are urgent aspects. But basic error recovery is somewhat urgent.

The existing API is very request-oriented and my feeling is still that it would be cleaner / simpler from the connector implementations' pov to have the error reporting be that way for the initial implementation.

It's true that the connector doesn't have visibility into other affected requests, but:

  • In the sync case, the per-request blocks passed to the connector will be mutually exclusive, and are known by the scheduler, so it should be just as easy to determine other affected requests from the prior step.
  • In the async case there's currently no sharing of the blocks until after they're loaded (successfully), so we shouldn't even have to check for other requests.

I'm not sure in general that it's simpler for the connector to report failed block ids; it probably depends on the implementation. For async-loading connectors like NixlConnector in particular this wouldn't be the case, since it's managing a transfer per request. And again, since the block ids are mutually exclusive from the connectors' pov, it should only need to report a single failed request for any given failed block.

Please feel free to correct me if there's something I have got wrong here.

I am not saying that we should not ultimately move to the block-oriented approach, but I feel that would make more sense to do in conjunction with the wider connector re-think, i.e. making it more block-oriented as a whole. And there are more things to consider as part of that, such as eviction-based offloading, progressive offload-completion reporting (so that some blocks can be freed earlier), allowing the scheduler to request immediate cancellation/release of async offloads, etc.

Would also welcome more input from others on this. Let's have a call to discuss soon.

njhill avatar Jul 30 '25 13:07 njhill

Thanks again @sdavidbd I think the code looks great.

Refactored GpuModelRunner.execute_model to encapsulate the KV connector lifecycle using a context manager. This simplifies propagating additional data back from the KV connector and ensures clear_connector_metadata is properly called.

I like this but as you suggested given it's orthogonal it may be better to move it to a dedicated PR which we can hopefully then get merged much quicker 😅

Moved the refactor into a separate PR: #21980

sdavidbd avatar Jul 30 '25 23:07 sdavidbd

Thanks @njhill for the thoughtful and detailed response - I really appreciate you taking the time.

I agree that the scheduler-side connector API is naturally request-oriented, since load/save operations are initiated in the context of a request. However, from the worker-side connector's point of view, there's often no need to retain request context - it operates at the level of block identifiers. You're right that the current NixlConnector handles async transfers per request, and that NIXL treats each multi-block transfer as a transaction - all-or-nothing. But it's worth noting that many existing external connectors integrate seamlessly via the KVConnectorBase_V1 interface, and some of them already support finer-grained control. These connectors could benefit immediately from partial failure support, even if non-contiguous block reuse is not a short-term priority.

In fact, partial failure handling is quite low-hanging fruit. To support it, we need to identify the last valid block - something #19330 enables by reporting failed block IDs, and #21534 by reporting a valid prefix length.

In my view, reporting failed block IDs is the more general approach. Most worker-side connectors maintain an external-to-local block mapping already, and aren’t inherently request-aware - making block-level reporting natural and straightforward.

I also don’t think space complexity (i.e. reporting one failed request vs. several failed blocks) should be a concern - this only affects the error path. Looking ahead, we may even want finer-grained failure signaling, down to tokens and layers.

A key concern though is API stability: we want to avoid breaking changes to the connector API, given the number of third-party implementations in the wild. #19330 addresses this by only adding a new method - no breakages, and no changes to the existing API surface.

As I see it, the long-term direction should converge toward a block-oriented model where the connector reports back which blocks it finished handling (whether save or load), and whether they succeeded or failed.

On the topic of shared blocks: while async loading currently doesn't support sharing, enabling it should be relatively straightforward - something I may contribute soon. Note that a request-based approach would retain the same level of complexity in the scheduler logic for detecting affected requests, while likely making it less readable.

To summarize, I believe block-oriented error handling provides the most general, backward-compatible solution - with no overhead in the common case. It also enables connector implementers to choose the level of fidelity appropriate to their backend:

  • Basic: fail the whole request - report just the first block per affected request.
  • Intermediate: salvage the valid prefix - report just the first failed block per affected request.
  • Advanced: report all failed blocks.

As a fallback, if we do decide to land a lighter-weight, request-based approach in the short term, I believe the KVConnectorOutput abstraction (#21980) offers a clean path to supporting both models side by side.

Happy to join a call to discuss further.

sdavidbd avatar Jul 31 '25 11:07 sdavidbd

Hi @sdavidbd, thanks for the PR. One additional point: for KV connectors (e.g. NIXLConnector) that use the async API (get_finished), there are cases where KV fetching from the remote keeps failing, and in that case we want to notify the scheduler to abort/finish/remove that request from further processing and free up its blocks. The current get_finished API only notifies the scheduler which requests are finished, with no way to tell which request can't be finished at all.

Do you think we can also cover this case in this PR?

liuzijing2014 avatar Jul 31 '25 20:07 liuzijing2014

Thank you, @liuzijing2014!

First, just to clarify: this PR focuses on adding the infrastructure for automatic recovery on the decoder side in the event of KV load failures. Specifically, when a connector reports failed block loads via get_block_ids_with_load_errors, the decoder can fall back to recomputing the missing blocks. However, this PR does not yet include changes to NixlConnector to support this mechanism (i.e., it doesn't implement get_block_ids_with_load_errors), so failure recovery is not currently enabled for that use case. That will be handled in a follow-up PR.
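
For illustration, a worker-side connector implementing the new hook might look roughly like this (hedged sketch: the class and callback are invented; only the hook name get_block_ids_with_load_errors comes from this PR):

class FaultTrackingConnector:
    def __init__(self):
        self._failed_blocks: set[int] = set()

    def _on_block_load(self, block_id: int, ok: bool) -> None:
        # Assumed callback from the transfer backend per completed block.
        if not ok:
            self._failed_blocks.add(block_id)

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Report-and-clear the blocks that failed to load this step.
        failed, self._failed_blocks = self._failed_blocks, set()
        return failed

c = FaultTrackingConnector()
c._on_block_load(12, ok=False)
print(c.get_block_ids_with_load_errors())  # -> {12}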

As for your question about the prefiller side: you're right that if a remote load fails, the connector may not report the request as finished - which would prevent the scheduler from releasing the associated blocks. To handle this, we can rely on the timeout mechanism introduced in #20139, which ensures that such requests are eventually cleaned up.

sdavidbd avatar Aug 02 '25 19:08 sdavidbd

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sdavidbd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Aug 03 '25 19:08 mergify[bot]

@WoosukKwon Since this PR involves scheduler change, just wondering if we can get your feedback on this.

KuntaiDu avatar Aug 04 '25 23:08 KuntaiDu

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sdavidbd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Aug 11 '25 02:08 mergify[bot]

Thanks for the explanation @sdavidbd! It sounds like this PR covers the error handling for KV connectors that do KV injection/load "in-place":

start_load_kv on request 1 -> wait_kv_load on request 1 layer 1 -> attn run for layer 1 -> wait_kv_load on request 1 layer 2 -> attn run for layer 2 -> and so on ...

I am referring more to the case of KV connectors that do KV injection/load fully asynchronously with respect to model execution, where a request is only scheduled/batched for a model run (forward execution) after its KV loading is fully complete on all layers:

start_load_kv on request 1 triggers a fully async load where request 1 is not even scheduled to run in the current model forward iteration -> wait_kv_load no-ops -> attn run for layer 1 -> wait_kv_load no-ops -> attn run for layer 2 -> and so on -> get_finished informs the scheduler that the request's KV injection/loading is complete -> request 1 is scheduled for a model run in the next iteration.

I think NIXLKVConnector and our vendor-internal KV connector solution both adopt this approach.

liuzijing2014 avatar Aug 13 '25 18:08 liuzijing2014

Thanks @liuzijing2014 — this PR handles both asynchronous and synchronous loading. In either case, the worker-side connector reports any failed block IDs via the get_block_ids_with_load_errors API after the forward pass. The connector itself is agnostic to whether the blocks were loaded for an actively running request (sync loading) or for a waiting request (async loading). The Scheduler’s _handle_invalid_blocks method covers both scenarios, and the included unit tests and example test demonstrate each case.

sdavidbd avatar Aug 13 '25 20:08 sdavidbd

Updated PR description to highlight support for both sync and async load handling, and to detail corresponding test coverage.

sdavidbd avatar Aug 13 '25 21:08 sdavidbd

Got it, @sdavidbd. Just to double-confirm: in the fully async case, when the scheduler handles a failed request (whose KV loading has failed), would it reschedule such a request for start_load_kv in the next round?

liuzijing2014 avatar Aug 14 '25 05:08 liuzijing2014

@liuzijing2014 Yes and no. start_load_kv is a worker-connector API invoked once per forward pass and is basically agnostic to individual requests—the connector determines what to load based on its metadata. In each forward pass, the worker-connector lifecycle starts with bind_connector_metadata and ends with clear_connector_metadata. While some connector implementations may keep internal state, the general rule is: once a block fails to load and is reported as such, it should not be retried unless explicitly requested.

If you’re asking about whether the scheduler-side API get_num_new_matched_tokens will be called again for requests that waited for async KV load and that load failed (or partially failed) — it depends:

  • If the request already has some computed tokens (either locally generated or successfully loaded externally), the API will not be called again for that request.
  • If the request has no computed tokens (no local tokens and all external tokens failed to load), it will be treated as a new request. In that case, the Scheduler will re-check for local/external cache hits, get_num_new_matched_tokens will be called again, and the connector must reflect the actual available state—i.e., if external cache cannot be loaded, it should not be returned as available.

sdavidbd avatar Aug 14 '25 09:08 sdavidbd