Fallback to a different model in RetryChatModel
- We have a recurring problem with Anthropic `Overloaded` exceptions
- We want to fall back to a different model (either a different version or a different provider)
- It would be cool if we could do something like the following:

```python
RetryChatModel(
    ...,
    retries=[
        AnthropicChatModel("claude-3-7-sonnet-latest"),
        AnthropicChatModel("claude-3-5-sonnet-latest"),
        OpenaiChatModel("gpt-4o"),
    ],
)
```
- Maybe also support selecting a different model depending on the exception?

```python
def _model_selector(attempt: int, exception: Exception) -> ChatModel | None:
    ...

RetryChatModel(
    ...,
    retries=_model_selector,
)
```
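To make the selector idea concrete, here is a minimal sketch of the retry loop such a `retries=` callable implies. This is stand-in code, not magentic internals: `complete_with_fallback` and the one-argument `complete` method are invented for illustration.

```python
from collections.abc import Callable
from typing import Any

# A selector maps (attempt number, exception) to the next model to try,
# or None to give up.
ModelSelector = Callable[[int, Exception], Any]


def complete_with_fallback(prompt: str, model: Any, selector: ModelSelector) -> str:
    attempt = 0
    current = model
    while True:
        try:
            # Models here only need a `complete` method (duck typing);
            # the real magentic signature takes messages/functions/output_types.
            return current.complete(prompt)
        except Exception as exc:  # narrow to e.g. Anthropic's overloaded error
            attempt += 1
            current = selector(attempt, exc)
            if current is None:
                raise  # selector gave up: surface the original exception
```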
Hi @ashwin153 That makes sense! The existing `RetryChatModel` is focused on handling errors from the LLM output / schema validation, rather than service-level issues like network connectivity errors, an invalid API key, or running out of tokens/credits. I think there's a clean split here because service-level errors should be retried without modifying the chat messages, unlike errors in the LLM's logic/output. But maybe it makes sense to handle these all in the one class. There is an existing ticket to handle other LLM logic/output exceptions: https://github.com/jackmpcollins/magentic/issues/417
Either way, since `model` is a parameter to `prompt`/`Chat`/etc., to unblock for the moment you should be able to create your own `ChatModel` subclass with the behavior you want and use it where needed. Something like:
```python
from collections.abc import Callable, Iterable, Sequence
from typing import Any

from magentic.chat_model.base import ChatModel, OutputT
from magentic.chat_model.message import AssistantMessage, Message


class NoFallbackWorkedError(Exception):
    """Raised when every model in the fallback chain failed."""


class FallbackChatModel(ChatModel):
    def __init__(
        self,
        chat_models: Sequence[ChatModel],
        # maybe accept a list of exception types here too
    ):
        self._chat_models = chat_models

    # must match signature from parent class
    def complete(
        self,
        messages: Iterable[Message[Any]],
        functions: Iterable[Callable[..., Any]] | None = None,
        output_types: Iterable[type[OutputT]] | None = None,
        *,
        stop: list[str] | None = None,
    ) -> AssistantMessage[OutputT]:
        messages = list(messages)  # so the iterable isn't exhausted on retry
        for model in self._chat_models:
            # Catch the exceptions you care about here
            # and customize behavior as desired
            try:
                return model.complete(messages, functions, output_types, stop=stop)
            except Exception:
                continue  # log a warning here?
        raise NoFallbackWorkedError()

    # Similar for `acomplete` method
```

```python
# USAGE: pass as `model` param
@prompt(
    "Create a Superhero named {name}.",
    model=FallbackChatModel([OpenaiChatModel("gpt-4o"), ...]),
)
def create_superhero(name: str) -> Superhero: ...
```
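Since `acomplete` mirrors `complete`, the async variant is just the same loop with `await`. Here is a sketch with stand-in names (`acomplete_with_fallback` and the one-argument `acomplete` are invented so this runs without magentic installed):

```python
import asyncio


class NoFallbackWorkedError(Exception):
    """Raised when every model in the fallback chain failed."""


async def acomplete_with_fallback(models, prompt):
    # Try each model in order; remember the last failure for the traceback.
    last_exc: Exception | None = None
    for model in models:
        try:
            return await model.acomplete(prompt)
        except Exception as exc:  # narrow to the exceptions you care about
            last_exc = exc
    raise NoFallbackWorkedError() from last_exc
```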
See the code for `RetryChatModel` for reference too: https://github.com/jackmpcollins/magentic/blob/f9533efe67a0679083be74509764358b0e99b58a/src/magentic/chat_model/retry_chat_model.py#L10
Please share here what ends up working for you and we can figure out how best to generalize it and include it in magentic directly. And let me know if you have any questions/issues. Thanks!
My code is in a private repository so I can't share, but I'll explain what I did in case it helps you design the feature.
I have a `ChatModel` subclass that takes a `fallback: ChatModel | None` argument. In `complete` and `acomplete` it tries the model and retries with the fallback if one is defined. You can chain these together in the obvious way, `ChatModel(fallback=ChatModel(fallback=...))`, but I also added a `then` method, `ChatModel(...).then(ChatModel(...))`, for when the chain gets longer.
I did it this way because I thought the `fallbacks: list[ChatModel]` approach might make it harder to have different retry behavior for different fallbacks. I'm not sure yet which approach I'll end up with; I was planning to see how complicated my retry behavior gets before deciding.
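Roughly, the pattern looks like this (a from-scratch sketch rather than my actual code, with invented names like `SimpleModel`):

```python
from __future__ import annotations


class SimpleModel:
    def __init__(self, name: str, *, fails: bool = False,
                 fallback: SimpleModel | None = None):
        self.name = name
        self._fails = fails
        self.fallback = fallback

    def then(self, other: SimpleModel) -> SimpleModel:
        # Append `other` at the end of the fallback chain and return self,
        # so calls compose: a.then(b).then(c) tries a, then b, then c.
        tail = self
        while tail.fallback is not None:
            tail = tail.fallback
        tail.fallback = other
        return self

    def complete(self, prompt: str) -> str:
        try:
            if self._fails:  # stand-in for a real provider error
                raise RuntimeError(f"{self.name} is overloaded")
            return f"{self.name}: {prompt}"
        except RuntimeError:
            if self.fallback is None:
                raise  # end of the chain: surface the error
            return self.fallback.complete(prompt)
```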
The other problem we're having is that we want to know which model generated the output, because quality can vary wildly between models. Before, it was obvious from whichever model was in scope, but now it can be any one of the fallbacks. Ideally this new feature would also address the provenance problem. Any ideas for how we might implement this?
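One rough direction we've considered (just a sketch, all names invented): have the fallback loop wrap the output together with the name of the model that produced it, so provenance travels with the result.

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass
class Attributed:
    output: Any
    model_name: str  # which model in the chain actually answered


def complete_with_provenance(models: Sequence[Any], prompt: str) -> Attributed:
    last_exc: Exception | None = None
    for model in models:
        try:
            # Record the producing model's name alongside its output.
            return Attributed(model.complete(prompt), model.name)
        except Exception as exc:  # narrow to the exceptions you care about
            last_exc = exc
    raise RuntimeError("no model in the chain succeeded") from last_exc
```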