Cerebras Inference Integration
Adding Cerebras Inference as an API provider.
It looks like the other providers use the legacy OpenAI completions API, but we prefer the newer chat completions API. As a result, I updated llama_stack/providers/utils/inference/openai_compat.py to support the new API while maintaining backwards compatibility. Let me know if this isn't preferred and I will move the logic into the internals of our integration.
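For illustration, here is a minimal sketch of the two request shapes that openai_compat.py has to bridge, written against the openai Python client; the base URL, API key, and model name are placeholders, and the actual code path in this PR may be wired differently.
from openai import OpenAI

# Illustrative only -- placeholder endpoint, key, and model name.
client = OpenAI(base_url="https://api.cerebras.ai/v1", api_key="csk-...")

# Legacy completions API: a single flat prompt string.
legacy = client.completions.create(
    model="llama3.1-70b",
    prompt="What is the temperature in Seattle right now?",
    max_tokens=100,
)

# Chat completions API: structured role/content messages, which is what this
# integration prefers to target.
chat = client.chat.completions.create(
    model="llama3.1-70b",
    messages=[{"role": "user", "content": "What is the temperature in Seattle right now?"}],
    max_tokens=100,
)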
Testing
Build
$ llama stack build --template local-cerebras --name my-cerebras-stack
Configuration
$ llama stack configure my-cerebras-stack
Could not find my-cerebras-stack. Trying conda build name instead...
Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.
Configuring API `inference`...
> Configuring provider `(remote::cerebras)`
Enter value for base_url (default: https://api.cerebras.ai) (required):
Enter value for api_key (default: csk-abcd<redacted>) (optional):
Configuring API `memory`...
> Configuring provider `(meta-reference)`
Configuring API `safety`...
> Configuring provider `(meta-reference)`
Do you want to configure llama_guard_shield? (y/n): n
Enter value for enable_prompt_guard (default: False) (optional):
Configuring API `agents`...
> Configuring provider `(meta-reference)`
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: <home>/.llama/runtime/kvstore.db) (required):
Configuring API `telemetry`...
> Configuring provider `(meta-reference)`
> YAML configuration has been written to `<home>/.llama/builds/conda/my-cerebras-stack-run.yaml`.
You can now run `llama stack run my-cerebras-stack --port PORT`
Running the Stack
$ llama stack run my-cerebras-stack --port 3000
Testing with cURL
$ curl --location 'http://localhost:3000/inference/chat_completion' --header 'Content-Type: application/json' --data '{
    "model": "Llama3.1-70B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the temperature in Seattle right now?"
        }
    ],
    "stream": false,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 100
    },
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from, in degrees Celsius.",
                    "required": true
                }
            }
        }
    ]
}'
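The same request can also be sent from Python; the following is a sketch using the requests library, assuming the stack is running on localhost:3000 as above.
import requests

payload = {
    "model": "Llama3.1-70B-Instruct",
    "messages": [{"role": "user", "content": "What is the temperature in Seattle right now?"}],
    "stream": False,
    "sampling_params": {"strategy": "top_p", "temperature": 0.5, "max_tokens": 100},
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [{
        "tool_name": "getTemperature",
        "description": "Gets the current temperature of a location.",
        "parameters": {
            "location": {
                "param_type": "string",
                "description": "The name of the place to get the temperature from, in degrees Celsius.",
                "required": True,
            }
        },
    }],
}

response = requests.post("http://localhost:3000/inference/chat_completion", json=payload, timeout=60)
response.raise_for_status()
print(response.json())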
Non-Streaming Response
{
    "completion_message": {
        "role": "assistant",
        "content": "",
        "stop_reason": "end_of_turn",
        "tool_calls": [
            {
                "call_id": "17277c905",
                "tool_name": "getTemperature",
                "arguments": {
                    "location": "Seattle"
                }
            }
        ]
    },
    "logprobs": null
}
Streaming Response
data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":"","logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"3f0695472","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_turn"}}
Unit Tests
See tests/test_cerebras_inference.py
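For illustration only (this is not the contents of that file), a black-box check against a running stack could reuse the payload dict from the earlier sketch:
def test_chat_completion_returns_tool_call():
    # Hypothetical end-to-end check, not the real test file; reuses `payload`
    # from the earlier sketch against a stack running on port 3000.
    response = requests.post("http://localhost:3000/inference/chat_completion", json=payload, timeout=60)
    response.raise_for_status()
    tool_calls = response.json()["completion_message"]["tool_calls"]
    assert tool_calls and tool_calls[0]["tool_name"] == "getTemperature"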
Pre-Commit Checks
trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Passed
One thing I am unsure about is whether "Agent" support is available out of the box just from implementing "Inference". I did notice other API vendors advertising on README.md that they have Agent support, but I could not find the corresponding implementation in code.
@ashwinb friendly bump on this PR :) Please allow CI to run for it.
I did notice other API vendors advertising on README.md that they have Agent support
which other vendors? if you mean Fireworks and Together, that is because they both have Llama Stack distribution endpoints so they make the entirety of the Llama Stack APIs available on their ends. That includes Agents, Memory, etc.
Ah I see, I didn't know about the distribution endpoints. I'll take out the ✅ on the Agents column
@henrytwo
Thanks for sharing the instructions for reproducing. Well, here's what the server is returning:
<|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|><|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>
This is clearly a malformed message. Why is it doing that? Because you aren't stopping generation on <|eom_id|>, which should be a stop token.
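A minimal sketch of what that looks like against an OpenAI-compatible completions endpoint (not this PR's actual fix; the endpoint, model name, and whether it accepts a raw templated prompt like this are assumptions):
from openai import OpenAI

client = OpenAI(base_url="https://api.cerebras.ai/v1", api_key="csk-...")  # placeholder credentials

# Prompt already rendered with the Llama 3.1 chat template (placeholder content).
formatted_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is the weather in San Francisco, CA?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Without stop sequences the model generates past <|eom_id|> and starts a new
# assistant turn, producing the concatenated output shown above.
response = client.completions.create(
    model="llama3.1-70b",
    prompt=formatted_prompt,
    max_tokens=100,
    stop=["<|eom_id|>", "<|eot_id|>"],  # halt at end-of-message / end-of-turn
)
print(response.choices[0].text)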
@ashwinb please have another look. I just rebased to latest main.
Also, this might be a bug, but it seems like test_text_inference.py doesn't run any tests when I use latest main, even when using the example command found in the docs:
llama-stack (main) $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "(fireworks or ollama) and llama_3b"
/net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================================================== test session starts ====================================================================================
platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 128 items / 128 deselected / 0 selected
=========================================================================== 128 deselected, 5 warnings in 0.15s ============================================================================
As a result I haven't been able to reverify that the E2E tests for this integration are still passing.
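For what it's worth, -m only selects items that actually carry the named marks, so my guess is the fireworks/ollama/llama_3b marks aren't being attached to the parametrized tests anymore. A standalone illustration of the selection behaviour (not part of llama-stack):
# test_marker_demo.py -- standalone illustration of pytest -m selection.
import pytest

@pytest.mark.fireworks
@pytest.mark.llama_3b
def test_selected():
    # Picked up by: pytest -m "(fireworks or ollama) and llama_3b"
    assert True

def test_deselected():
    # Carries no marks, so the same -m expression deselects it.
    assert True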
@henrytwo We will check why tests are suddenly not getting picked up.