
Cerebras Inference Integration

Open henrytwo opened this issue 4 months ago • 4 comments

Adding Cerebras Inference as an API provider.

It looks like the other providers use the legacy OpenAI completions API, but we prefer the newer chat completions API. As a result, I updated llama_stack/providers/utils/inference/openai_compat.py to support the new one while maintaining backwards compatibility. Let me know if this isn't preferred, and I will move the logic into the internals of our integration.
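For context, here is a minimal sketch of the kind of compatibility shim described above. The function names (`to_chat_completion_request`, `to_legacy_completion_request`) are illustrative only, not the actual helpers in openai_compat.py:

```python
# Hypothetical sketch of supporting the newer chat completions request
# shape while keeping a fallback to the legacy prompt-string shape.

def to_chat_completion_request(model, messages):
    """Build a request body for the newer chat completions endpoint."""
    return {
        "model": model,
        "messages": [
            {"role": m["role"], "content": m["content"]} for m in messages
        ],
    }


def to_legacy_completion_request(model, messages):
    """Fallback: flatten messages into one prompt for the legacy endpoint."""
    prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    return {"model": model, "prompt": prompt}
```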

Testing

Build

$ llama stack build --template local-cerebras --name my-cerebras-stack

Configuration

$ llama stack configure my-cerebras-stack
Could not find my-cerebras-stack. Trying conda build name instead...

Llama Stack is composed of several APIs working together. For each API served by the Stack,
we need to configure the providers (implementations) you want to use for these APIs.

Configuring API `inference`...
> Configuring provider `(remote::cerebras)`
Enter value for base_url (default: https://api.cerebras.ai) (required): 
Enter value for api_key (default: csk-abcd<redacted>) (optional): 

Configuring API `memory`...
> Configuring provider `(meta-reference)`

Configuring API `safety`...
> Configuring provider `(meta-reference)`
Do you want to configure llama_guard_shield? (y/n): n
Enter value for enable_prompt_guard (default: False) (optional): 

Configuring API `agents`...
> Configuring provider `(meta-reference)`
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): 

Configuring SqliteKVStoreConfig:
Enter value for namespace (optional): 
Enter value for db_path (default: <home>/.llama/runtime/kvstore.db) (required): 

Configuring API `telemetry`...
> Configuring provider `(meta-reference)`

> YAML configuration has been written to `<home>/.llama/builds/conda/my-cerebras-stack-run.yaml`.
You can now run `llama stack run my-cerebras-stack --port PORT`

Running Stack

$ llama stack run my-cerebras-stack --port 3000

Testing with cURL

$ curl --location 'http://localhost:3000/inference/chat_completion' --header 'Content-Type: application/json' --data '{
    "model": "Llama3.1-70B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the temperature in Seattle right now?"
        }
    ],
    "stream": false,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 100
    },                   
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [                   
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {                                              
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degress celsius.",
                    "required": true                                                                       
                }                   
            }    
        }    
    ]    
}'
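The same request can also be issued from Python with the standard library (a sketch that assumes the stack is running locally on port 3000, as in the run command above):

```python
import json
from urllib import request

# Assumes the stack from `llama stack run my-cerebras-stack --port 3000`.
CHAT_URL = "http://localhost:3000/inference/chat_completion"

payload = {
    "model": "Llama3.1-70B-Instruct",
    "messages": [
        {"role": "user", "content": "What is the temperature in Seattle right now?"}
    ],
    "stream": False,
    "sampling_params": {"strategy": "top_p", "temperature": 0.5, "max_tokens": 100},
    "tool_choice": "required",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degrees Celsius.",
                    "required": True,
                }
            },
        }
    ],
}


def chat_completion(url=CHAT_URL):
    """POST the payload to the running stack and return the parsed response."""
    req = request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.loads(resp.read())
```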

Non-Streaming Response

{
  "completion_message": {
    "role": "assistant",
    "content": "",
    "stop_reason": "end_of_turn",
    "tool_calls": [
      {
        "call_id": "17277c905",
        "tool_name": "getTemperature",
        "arguments": {
          "location": "Seattle"
        }
      }
    ]
  },
  "logprobs": null
}
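A client can pull the tool call out of that non-streaming response with plain dict access. This sketch operates on the literal response above; `extract_tool_calls` is an illustrative helper, not part of llama-stack:

```python
# The non-streaming response shown above, as a Python dict.
response = {
    "completion_message": {
        "role": "assistant",
        "content": "",
        "stop_reason": "end_of_turn",
        "tool_calls": [
            {
                "call_id": "17277c905",
                "tool_name": "getTemperature",
                "arguments": {"location": "Seattle"},
            }
        ],
    },
    "logprobs": None,
}


def extract_tool_calls(response):
    """Return the list of tool calls from a chat_completion response."""
    return response["completion_message"]["tool_calls"]


call = extract_tool_calls(response)[0]
```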

Streaming Response

data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"progress","delta":"","logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"3f0695472","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}

data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_turn"}}
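A streaming client has to strip the `data: ` prefix from each server-sent-event line and watch `event_type` and `stop_reason`. A minimal parsing sketch over the sample events above (`parse_sse_events` is an illustrative helper, not a llama-stack API):

```python
import json


def parse_sse_events(lines):
    """Yield decoded event dicts from 'data: {...}' SSE lines."""
    for line in lines:
        line = line.strip()
        if line.startswith("data: "):
            yield json.loads(line[len("data: "):])["event"]


# Sample lines from the streaming response shown above.
sample = [
    'data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}',
    'data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"3f0695472","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}',
    'data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_turn"}}',
]

events = list(parse_sse_events(sample))
```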

Unit Tests

See tests/test_cerebras_inference.py

Pre-Commit Checks

trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Passed

henrytwo · Oct 17 '24 23:10