
Cerebras Inference Integration


Adding Cerebras Inference as an API provider.

It looks like the other providers use the legacy OpenAI completions API, but we prefer the newer chat completions API. As a result, I updated llama_stack/providers/utils/inference/openai_compat.py to support the new API while maintaining backwards compatibility. Let me know if this isn't preferred and I will move the logic into the internals of our integration.
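
For illustration only, here is a minimal sketch of the kind of normalization such a compatibility layer has to do; the helper name and structure are hypothetical, not the actual contents of openai_compat.py:

# Hypothetical sketch: accept either a legacy completions-style response
# (choices[0].text) or a chat-completions-style response (choices[0].message.content)
# and normalize both to (text, finish_reason).
from typing import Any, Dict, Tuple


def extract_text_and_finish_reason(response: Dict[str, Any]) -> Tuple[str, str]:
    choice = response["choices"][0]
    if "message" in choice:
        # New chat completions format
        return choice["message"].get("content") or "", choice.get("finish_reason", "stop")
    # Legacy completions format
    return choice.get("text", ""), choice.get("finish_reason", "stop")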

Testing

Build

$ llama stack build --template local-cerebras --name my-cerebras-stack

Configuration

$ llama stack configure my-cerebras-stack
Could not find my-cerebras-stack. Trying conda build name instead...

Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.

Configuring API `inference`...
> Configuring provider `(remote::cerebras)`
Enter value for base_url (default: https://api.cerebras.ai) (required): 
Enter value for api_key (default: csk-abcd<redacted>) (optional): 

Configuring API `memory`...
> Configuring provider `(meta-reference)`

Configuring API `safety`...
> Configuring provider `(meta-reference)`
Do you want to configure llama_guard_shield? (y/n): n
Enter value for enable_prompt_guard (default: False) (optional): 

Configuring API `agents`...
> Configuring provider `(meta-reference)`
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): 

Configuring SqliteKVStoreConfig:
Enter value for namespace (optional): 
Enter value for db_path (default: <home>/.llama/runtime/kvstore.db) (required): 

Configuring API `telemetry`...
> Configuring provider `(meta-reference)`

> YAML configuration has been written to `<home>/.llama/builds/conda/my-cerebras-stack-run.yaml`.
You can now run `llama stack run my-cerebras-stack --port PORT`

Running Stack

$ llama stack run my-cerebras-stack --port 3000

Testing with cURL

$ curl --location 'http://localhost:3000/inference/chat_completion' --header 'Content-Type: application/json' --data '{
    "model": "Llama3.1-70B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the temperature in Seattle right now?"
        }
    ],
    "stream": false,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 100
    },                   
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [                   
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {                                              
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degress celsius.",
                    "required": true                                                                       
                }                   
            }    
        }    
    ]    
}'

Non-Streaming Response

{
  "completion_message": {
    "role": "assistant",
    "content": "",
    "stop_reason": "end_of_turn",
    "tool_calls": [
      {
        "call_id": "17277c905",
        "tool_name": "getTemperature",
        "arguments": {
          "location": "Seattle"
        }
      }
    ]
  },
  "logprobs": null
}
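
For reference, here is a rough Python equivalent of the cURL request above, using the requests library against the same local endpoint and pulling the tool call out of the non-streaming response (a sketch, assuming the stack is running on port 3000 as shown):

import requests

payload = {
    "model": "Llama3.1-70B-Instruct",
    "messages": [{"role": "user", "content": "What is the temperature in Seattle right now?"}],
    "stream": False,
    "sampling_params": {"strategy": "top_p", "temperature": 0.5, "max_tokens": 100},
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from, in degrees Celsius.",
                    "required": True,
                }
            },
        }
    ],
}

resp = requests.post("http://localhost:3000/inference/chat_completion", json=payload)
resp.raise_for_status()
for tool_call in resp.json()["completion_message"]["tool_calls"]:
    print(tool_call["tool_name"], tool_call["arguments"])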

Streaming Response

data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"progress","delta":"","logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"3f0695472","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_turn"}}

Unit Tests

See tests/test_cerebras_inference.py

Pre-Commit Checks

trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Passed

henrytwo, Oct 17, 2024

One thing I am unsure about is whether "Agent" support is available out of the box just from implementing "Inference". I did notice other API vendors advertising on README.md that they have Agent support, but I could not find the corresponding implementation in code.

henrytwo, Oct 18, 2024

@ashwinb friendly bump on this PR :) Please allow the CI to run.

henrytwo, Oct 21, 2024

> I did notice other API vendors advertising on README.md that they have Agent support

Which other vendors? If you mean Fireworks and Together, that is because they both have Llama Stack distribution endpoints, so they make the entirety of the Llama Stack APIs available on their ends. That includes Agents, Memory, etc.

ashwinb, Oct 21, 2024

> I did notice other API vendors advertising on README.md that they have Agent support

> Which other vendors? If you mean Fireworks and Together, that is because they both have Llama Stack distribution endpoints, so they make the entirety of the Llama Stack APIs available on their ends. That includes Agents, Memory, etc.

Ah I see, I didn't know about the distribution endpoints. I'll take out the ✅ in the Agents column.

henrytwo, Oct 21, 2024

@henrytwo

Thanks for sharing the reproduction instructions. Here's what the server is returning:

<|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>

<|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>

<|python_tag|><|python_tag|>{"type": "function", "name": "get_weather", "parameters": {"location": "San Francisco, CA"}}<|eom_id|><|start_header_id|>assistant<|end_header_id|>

This is clearly a malformed message. Why is it doing that? Because you aren't stopping generation on <|eom_id|>, which should be a stop token.
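
For illustration, with an OpenAI-compatible chat completions client one way to enforce this is to pass the Llama 3.1 end-of-message and end-of-turn markers as stop sequences. This is a sketch, not the provider's actual fix; the base URL, model name, and environment variable are assumptions:

# Sketch: stop generation at the Llama 3.1 end-of-message / end-of-turn markers
# so they never leak into the returned text.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api.cerebras.ai/v1",  # assumed OpenAI-compatible endpoint
    api_key=os.environ["CEREBRAS_API_KEY"],  # assumed environment variable
)

response = client.chat.completions.create(
    model="llama3.1-70b",  # assumed model identifier
    messages=[{"role": "user", "content": "What is the weather in San Francisco, CA?"}],
    stop=["<|eom_id|>", "<|eot_id|>"],  # treat the special tokens as stop sequences
)
print(response.choices[0].message.content)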

ashwinb, Nov 13, 2024

@ashwinb please have another look. I just rebased to latest main.

Also, this might be a bug, but it seems like test_text_inference.py doesn't run any tests when I use the latest main, even with the example command from the docs:

llama-stack (main) $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "(fireworks or ollama) and llama_3b"
/net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================================================== test session starts ====================================================================================
platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 128 items / 128 deselected / 0 selected                                                                                                                                          

=========================================================================== 128 deselected, 5 warnings in 0.15s ============================================================================

As a result I haven't been able to reverify that the E2E tests for this integration are still passing.
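
As an aside, when a marker expression deselects everything like this, standard pytest options can show what is going on, for example (plain pytest flags, nothing llama-stack specific):

$ pytest --markers
$ pytest --collect-only -q llama_stack/providers/tests/inference/test_text_inference.py -m "(fireworks or ollama) and llama_3b"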

henrytwo, Nov 25, 2024

@henrytwo We will check why tests are suddenly not getting picked up.

ashwinb, Dec 4, 2024