Google Vertex AI provider with Region Selection support
Check for existing issues
- [X] Completed
Describe the feature
Platforms like Cursor.ai and Continue.dev allow the use of Vertex AI as a provider, which offers models like Gemini Flash and Gemini Pro. A key benefit of Vertex AI is region selection, enabling users to pick servers closer to them for lower latency.
Vertex AI also offers other advantages, such as unified billing with GCP. Many companies already use Google Cloud, so activating Vertex AI is as simple as enabling one more service in an account they already pay for.
Does Vertex AI have distinct API semantics, or is it just an alternate endpoint for Gemini Flash / Gemini Pro? If the latter, it may just require tweaking our endpoint code.
See: Assistant Configuration: Custom Endpoint in the docs.
Currently we assume the following URL structure under that endpoint: https://github.com/zed-industries/zed/blob/cdead5760a2c200179e6a51d5a396c33a3e06e3d/crates/google_ai/src/google_ai.rs#L20-L23
But looking at the Vertex AI docs, I think the endpoints instead take the form:
https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:streamGenerateContent
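A minimal Rust sketch of building these endpoints from a project, location, and model follows. The names `VertexMethod` and `vertex_model_url` are hypothetical. The sketch assumes, per the regional URL format quoted further down this thread, that a regional location is served from `{LOCATION_ID}-aiplatform.googleapis.com`, and that the `global` location uses the bare `aiplatform.googleapis.com` host.

```rust
/// Vertex AI REST methods mentioned in this thread.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VertexMethod {
    GenerateContent,
    StreamGenerateContent,
    CountTokens,
}

impl VertexMethod {
    /// The `:method` suffix appended to the model resource path.
    fn as_str(self) -> &'static str {
        match self {
            VertexMethod::GenerateContent => "generateContent",
            VertexMethod::StreamGenerateContent => "streamGenerateContent",
            VertexMethod::CountTokens => "countTokens",
        }
    }
}

/// Builds a Vertex AI URL for a Google publisher model.
/// Assumption: `global` uses the bare host; any other location uses the
/// `{location_id}-aiplatform.googleapis.com` regional host.
pub fn vertex_model_url(
    project_id: &str,
    location_id: &str,
    model_id: &str,
    method: VertexMethod,
) -> String {
    let host = if location_id == "global" {
        "aiplatform.googleapis.com".to_string()
    } else {
        format!("{location_id}-aiplatform.googleapis.com")
    };
    let mut url = format!(
        "https://{host}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:{}",
        method.as_str()
    );
    // Streaming responses are requested as server-sent events.
    if method == VertexMethod::StreamGenerateContent {
        url.push_str("?alt=sse");
    }
    url
}
```

For example, `vertex_model_url("my-project", "europe-west4", "gemini-2.5-flash", VertexMethod::StreamGenerateContent)` targets the `europe-west4` regional host. This region-to-host mapping is what region selection would hinge on.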
I find this interesting because Claude can apparently also run on Vertex AI with zero downtime.
@notpeter
- The URL format is different:
  https://${LOCATION_ID}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION_ID}/publishers/google/models/${MODEL_ID}:${GENERATE_CONTENT_API}
- The auth token comes from an IAM user with the Vertex AI User role:
  "Authorization: Bearer $(gcloud auth print-access-token)"
- Images and documents can be sent inline as base64:
    %{
      "contents" => [
        %{
          "role" => "user",
          "parts" => [
            %{
              "inlineData" => %{
                "mimeType" => "application/pdf",
                "data" => base64_content
              }
            },
            %{
              "text" => prompt
            }
          ]
        }
      ]
    }
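For comparison with the Elixir map above, here is a minimal std-only Rust sketch that builds the same `inlineData` request body by hand. The helpers `base64_encode`, `json_escape`, and `inline_data_body` are hypothetical. A real client would more likely use the `base64` and `serde_json` crates; they are avoided here only to keep the sketch self-contained.

```rust
/// Standard (RFC 4648) base64 alphabet; output is padded with `=`.
const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

pub fn base64_encode(input: &[u8]) -> String {
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
    for chunk in input.chunks(3) {
        // Pack up to three bytes into one 24-bit group.
        let b0 = chunk[0] as u32;
        let b1 = chunk.get(1).copied().unwrap_or(0) as u32;
        let b2 = chunk.get(2).copied().unwrap_or(0) as u32;
        let group = (b0 << 16) | (b1 << 8) | b2;
        // Emit one character per 6 bits, padding short final chunks with '='.
        out.push(ALPHABET[(group >> 18) as usize & 63] as char);
        out.push(ALPHABET[(group >> 12) as usize & 63] as char);
        out.push(if chunk.len() > 1 { ALPHABET[(group >> 6) as usize & 63] as char } else { '=' });
        out.push(if chunk.len() > 2 { ALPHABET[group as usize & 63] as char } else { '=' });
    }
    out
}

/// Quotes a string as a JSON string literal, escaping quotes, backslashes, and control characters.
fn json_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

/// Builds {"contents":[{"role":"user","parts":[{"inlineData":{...}},{"text":...}]}]}.
pub fn inline_data_body(mime_type: &str, data: &[u8], prompt: &str) -> String {
    format!(
        r#"{{"contents":[{{"role":"user","parts":[{{"inlineData":{{"mimeType":{},"data":"{}"}}}},{{"text":{}}}]}}]}}"#,
        json_escape(mime_type),
        base64_encode(data),
        json_escape(prompt)
    )
}
```

The base64 output uses only the alphabet characters and `=`, so it can be placed inside the JSON quotes without escaping. The MIME type and prompt go through `json_escape` because they may contain arbitrary text.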
I would love to see Vertex AI support for at least Google models in Zed
FYI: Somebody made a proxy server to enable this: https://github.com/prantlf/ovai
Did you get it to work with Zed?
+1 for support for Vertex AI Model Garden models
+1 Please add support for Vertex, specifically Anthropic models through Vertex.
Yes, and Vertex AI also uses Google ADC (Application Default Credentials). Please add!
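ADC access tokens, for example from `gcloud auth application-default print-access-token`, are short-lived (typically about an hour), so a provider should not cache one indefinitely. Below is a minimal sketch of a hypothetical `TokenCache` that fetches a new token once the cached one is older than a chosen `max_age`. The 50-minute margin used in the example is an arbitrary assumption. The `fetch` closure stands in for whatever actually runs gcloud or an ADC library.

```rust
use std::time::{Duration, Instant};

/// An access token and the instant it was fetched.
struct CachedToken {
    token: String,
    fetched_at: Instant,
}

/// Hypothetical cache that re-fetches a token once it is older than `max_age`.
pub struct TokenCache {
    max_age: Duration,
    cached: Option<CachedToken>,
}

impl TokenCache {
    pub fn new(max_age: Duration) -> Self {
        Self { max_age, cached: None }
    }

    /// Returns the cached token if it is still fresh at `now`; otherwise calls
    /// `fetch` (a stand-in for running gcloud or an ADC library) and caches the result.
    pub fn get_or_refresh<E>(
        &mut self,
        now: Instant,
        fetch: impl FnOnce() -> Result<String, E>,
    ) -> Result<&str, E> {
        let stale = match &self.cached {
            Some(cached) => now.saturating_duration_since(cached.fetched_at) >= self.max_age,
            None => true,
        };
        if stale {
            // If the refresh fails, the error is returned and the old entry is left as-is.
            let token = fetch()?;
            self.cached = Some(CachedToken { token, fetched_at: now });
        }
        Ok(self
            .cached
            .as_ref()
            .map(|cached| cached.token.as_str())
            .expect("a token is cached after a successful refresh"))
    }
}
```

`now` is passed in rather than read inside the method so that expiry can be exercised deterministically. A caller would normally pass `Instant::now()`.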
This would be really helpful in enterprise use cases.
Vertex support, please. Been using Anthropic's models via Bedrock and Vertex, and Vertex tends to be faster.
Indeed, the Claude 4 models have been very slow and heavily rate-limited on Bedrock. Switching to Vertex AI solved all my issues. Support would be greatly appreciated.
+1 Please add support for Vertex
Hi All,
I created a provider for this that is working great for inline edits, agent mode, token counting, thread summarization, etc. It's simple to add to the project:
- Clone the `google_ai` crate as `google_vertex_ai` and replace `google_ai.rs` with the following `google_vertex_ai.rs`. I've marked each change with a `MODIFICATION` comment for transparency.
use std::mem;
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
// MODIFICATION 1: Point API_URL at the Vertex AI host. For a regional endpoint,
// set `api_url` to https://{LOCATION_ID}-aiplatform.googleapis.com instead.
pub const API_URL: &str = "https://aiplatform.googleapis.com";
// MODIFICATION 2: Change function signature to accept Vertex AI parameters and remove api_key.
pub async fn stream_generate_content(
client: &dyn HttpClient,
api_url: &str,
project_id: &str,
location_id: &str,
access_token: &str,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
validate_generate_content_request(&request)?;
// The `model` field is emptied as it is provided as a path parameter.
let model_id = mem::take(&mut request.model.model_id);
// MODIFICATION 3: Update URL to the correct Vertex AI format.
let uri = format!(
"{api_url}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:streamGenerateContent?alt=sse"
);
// MODIFICATION 4: Add Authorization header for bearer token authentication.
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json");
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
if let Some(line) = line.strip_prefix("data: ") {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(
"Error parsing JSON: {error:?}\n{line:?}"
))),
}
} else {
None
}
}
Err(error) => {
Some(Err(anyhow!(error)))
},
}
})
.boxed())
} else {
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
Err(anyhow!(
"error during streamGenerateContent, status code: {:?}, body: {}",
response.status(),
text
))
}
}
// MODIFICATION 5: Change function signature to accept Vertex AI parameters and remove api_key.
pub async fn count_tokens(
client: &dyn HttpClient,
api_url: &str,
project_id: &str,
location_id: &str,
access_token: &str,
request: CountTokensRequest,
) -> Result<CountTokensResponse> {
validate_generate_content_request(&request.generate_content_request)?;
// MODIFICATION 6: Update URL to the correct Vertex AI format.
let uri = format!(
"{api_url}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:countTokens",
model_id = &request.generate_content_request.model.model_id,
);
// The countTokens request body only needs the contents, sent as {"contents": [...]}.
#[derive(Serialize)]
struct CountTokensPayload {
contents: Vec<Content>,
}
let payload = CountTokensPayload {
contents: request.generate_content_request.contents,
};
let request_body = serde_json::to_string(&payload)?;
// MODIFICATION 7: Add Authorization header for bearer token authentication.
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(&uri)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(request_body))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
anyhow::ensure!(
response.status().is_success(),
"error during countTokens, status code: {:?}, body: {}",
response.status(),
text
);
Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
}
pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
if request.model.is_empty() {
bail!("Model must be specified");
}
if request.contents.is_empty() {
bail!("Request must contain at least one content item");
}
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
Ok(())
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Task {
#[serde(rename = "generateContent")]
GenerateContent,
#[serde(rename = "streamGenerateContent")]
StreamGenerateContent,
#[serde(rename = "countTokens")]
CountTokens,
#[serde(rename = "embedContent")]
EmbedContent,
#[serde(rename = "batchEmbedContents")]
BatchEmbedContents,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "ModelName::is_empty")]
pub model: ModelName,
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<SystemInstruction>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_config: Option<ToolConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub candidates: Option<Vec<GenerateContentCandidate>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_feedback: Option<PromptFeedback>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage_metadata: Option<UsageMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
#[serde(skip_serializing_if = "Option::is_none")]
pub index: Option<usize>,
pub content: Content,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_message: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_ratings: Option<Vec<SafetyRating>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub citation_metadata: Option<CitationMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
#[serde(default)]
pub parts: Vec<Part>,
pub role: Role,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
pub parts: Vec<Part>,
}
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
Model,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
FunctionCallPart(FunctionCallPart),
FunctionResponsePart(FunctionResponsePart),
ThoughtPart(ThoughtPart),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
pub inline_data: GenerativeContentBlob,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
pub function_call: FunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
pub function_response: FunctionResponse,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
pub thought: bool,
pub thought_signature: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
#[serde(skip_serializing_if = "Option::is_none")]
pub start_index: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub end_index: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub license: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
pub citation_sources: Vec<CitationSource>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
#[serde(skip_serializing_if = "Option::is_none")]
pub block_reason: Option<String>,
pub safety_ratings: Vec<SafetyRating>,
#[serde(skip_serializing_if = "Option::is_none")]
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_content_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub candidates_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_use_prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total_token_count: Option<u64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
pub thinking_budget: u32,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum GoogleModelMode {
#[default]
Default,
Thinking {
budget_tokens: Option<u32>,
},
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub candidate_count: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_DEROGATORY")]
Derogatory,
#[serde(rename = "HARM_CATEGORY_TOXICITY")]
Toxicity,
#[serde(rename = "HARM_CATEGORY_VIOLENCE")]
Violence,
#[serde(rename = "HARM_CATEGORY_SEXUAL")]
Sexual,
#[serde(rename = "HARM_CATEGORY_MEDICAL")]
Medical,
#[serde(rename = "HARM_CATEGORY_DANGEROUS")]
Dangerous,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
BlockLowAndAbove,
BlockMediumAndAbove,
BlockOnlyHigh,
BlockNone,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
Unspecified,
Negligible,
Low,
Medium,
High,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub generate_content_request: GenerateContentRequest,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: u64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: FunctionCallingMode,
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
Auto,
Any,
None,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
// NOTE: Vertex AI takes the bare model ID as a path parameter, so the "models/"
// prefix added by this serializer is not what Vertex AI expects. It never reaches
// Vertex AI, though: `stream_generate_content` `mem::take`s `model_id` (so the now-empty
// name is skipped during serialization), and `count_tokens` sends only `contents`.
// No modification needed here.
#[derive(Debug, Default)]
pub struct ModelName {
pub model_id: String,
}
impl ModelName {
pub fn is_empty(&self) -> bool {
self.model_id.is_empty()
}
}
const MODEL_NAME_PREFIX: &str = "models/";
impl Serialize for ModelName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
}
}
impl<'de> Deserialize<'de> for ModelName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
Ok(Self {
model_id: id.to_string(),
})
} else {
// Vertex AI model names (e.g., in responses) might not have this prefix,
// so we handle that case gracefully.
Ok(Self {
model_id: string,
})
}
}
}
// MODIFICATION STARTS: Model enum updated to only include versions 2.0 and higher.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-2.0-flash")]
Gemini20Flash,
#[serde(
rename = "gemini-2.5-flash",
alias = "gemini-2.0-flash-thinking-exp",
alias = "gemini-2.5-flash-preview-04-17",
alias = "gemini-2.5-flash-preview-05-20",
alias = "gemini-2.5-flash-preview-latest"
)]
#[default]
Gemini25Flash,
#[serde(
rename = "gemini-2.5-pro",
alias = "gemini-2.0-pro-exp",
alias = "gemini-2.5-pro-preview-latest",
alias = "gemini-2.5-pro-exp-03-25",
alias = "gemini-2.5-pro-preview-03-25",
alias = "gemini-2.5-pro-preview-05-06",
alias = "gemini-2.5-pro-preview-06-05"
)]
Gemini25Pro,
#[serde(rename = "custom")]
Custom {
name: String,
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
max_tokens: u64,
#[serde(default)]
mode: GoogleModelMode,
},
}
impl Model {
pub fn default_fast() -> Self {
Self::Gemini20Flash
}
pub fn id(&self) -> &str {
match self {
Self::Gemini20Flash => "gemini-2.0-flash",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Custom { name, .. } => name,
}
}
pub fn request_id(&self) -> &str {
match self {
Self::Gemini20Flash => "gemini-2.0-flash",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gemini20Flash => "Gemini 2.0 Flash",
Self::Gemini25Flash => "Gemini 2.5 Flash",
Self::Gemini25Pro => "Gemini 2.5 Pro",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
}
}
pub fn max_token_count(&self) -> u64 {
match self {
Self::Gemini20Flash => 1_048_576,
Self::Gemini25Flash => 1_048_576,
Self::Gemini25Pro => 1_048_576,
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
pub fn max_output_tokens(&self) -> Option<u64> {
match self {
Model::Gemini20Flash => Some(8_192),
Model::Gemini25Flash => Some(65_536),
Model::Gemini25Pro => Some(65_536),
Model::Custom { .. } => None,
}
}
pub fn supports_tools(&self) -> bool {
true
}
pub fn supports_images(&self) -> bool {
true
}
pub fn mode(&self) -> GoogleModelMode {
match self {
Self::Gemini20Flash => GoogleModelMode::Default,
Self::Gemini25Flash | Self::Gemini25Pro => {
GoogleModelMode::Thinking {
// By default these models are set to "auto", so we preserve that behavior
// but indicate they are capable of thinking mode
budget_tokens: None,
}
}
Self::Custom { mode, .. } => *mode,
}
}
}
impl std::fmt::Display for Model {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.id())
}
}
- Create a provider called `google_vertex.rs` in the `language_models` crate:
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_vertex_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Subscription, Task,
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
Arc,
atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google-vertex-ai";
const PROVIDER_NAME: &str = "Google Vertex AI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleVertexSettings {
pub api_url: String,
pub project_id: String, // ADDED
pub location_id: String, // ADDED
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
#[default]
Default,
Thinking {
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
budget_tokens: Option<u32>,
},
}
impl From<ModelMode> for GoogleModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => GoogleModelMode::Default,
ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
}
}
}
impl From<GoogleModelMode> for ModelMode {
fn from(value: GoogleModelMode) -> Self {
match value {
GoogleModelMode::Default => ModelMode::Default,
GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
name: String,
display_name: Option<String>,
max_tokens: u64,
mode: Option<ModelMode>,
}
pub struct GoogleVertexLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
// The gcloud token is never written to the credentials store, so deleting by `api_url` is a best-effort cleanup.
let settings = AllLanguageModelSettings::get_global(cx)
.google_vertex
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&settings.api_url, &cx) // Use api_url
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
log::info!("Authenticating Google Vertex AI...");
if self.is_authenticated() {
return Task::ready(Ok(()));
}
// gpui's executors are not a Tokio runtime, so `tokio::process` and
// `tokio::task::spawn_blocking` cannot be used here. Instead, run the blocking
// `gcloud` command on a dedicated std thread and use a oneshot channel to
// send its result back to this async task.
cx.spawn(async move |this, cx| {
let (tx, rx) = futures::channel::oneshot::channel();
std::thread::spawn(move || {
let result = std::process::Command::new("gcloud")
.args(&["auth", "application-default", "print-access-token"])
.output()
.map_err(|e| AuthenticateError::Other(anyhow!("Failed to execute gcloud command: {}", e)));
// Send the result back to the async task, ignoring if the receiver was dropped.
let _ = tx.send(result);
});
// Await the result from the channel.
// First, explicitly handle the channel's `Canceled` error.
// Then, use `?` to propagate the `AuthenticateError` from the command execution.
let token_output = rx.await
.map_err(|_cancelled| AuthenticateError::Other(anyhow!("Authentication task was cancelled")))?
?;
// Check the exit status of the gcloud command before trusting its stdout.
if !token_output.status.success() {
let stderr = String::from_utf8_lossy(&token_output.stderr).into_owned();
return Err(AuthenticateError::Other(anyhow!("gcloud command failed: {}", stderr)));
}
// Retrieve the access token from stdout, decoding it as UTF-8 and trimming whitespace.
let access_token = String::from_utf8(token_output.stdout)
.map_err(|e| AuthenticateError::Other(anyhow!("Invalid UTF-8 in gcloud output: {}", e)))?
.trim()
.to_string();
if access_token.is_empty() {
return Err(AuthenticateError::Other(anyhow!("gcloud returned an empty access token")));
}
let api_key = access_token; // Use the retrieved token as the API key.
let from_env = false; // This token is dynamically fetched, not from env or keychain.
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
impl GoogleVertexLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: google_vertex_ai::Model) -> Arc<dyn LanguageModel> {
Arc::new(GoogleVertexLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
}
impl LanguageModelProviderState for GoogleVertexLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for GoogleVertexLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiGoogle
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(google_vertex_ai::Model::default()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(google_vertex_ai::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from google_vertex_ai::Model::iter()
for model in google_vertex_ai::Model::iter() {
if !matches!(model, google_vertex_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.google_vertex
.available_models
{
models.insert(
model.name.clone(),
google_vertex_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
mode: model.mode.unwrap_or_default().into(),
},
);
}
models
.into_values()
.map(|model| {
Arc::new(GoogleVertexLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
}
}
pub struct GoogleVertexLanguageModel {
id: LanguageModelId,
model: google_vertex_ai::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl GoogleVertexLanguageModel {
fn stream_completion(
&self,
request: google_vertex_ai::GenerateContentRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
> {
let http_client = self.http_client.clone();
let Ok((access_token_option, api_url, project_id, location_id)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google_vertex;
(
state.api_key.clone(), // This is the access token for Vertex AI
settings.api_url.clone(),
settings.project_id.clone(), // ADDED
settings.location_id.clone(), // ADDED
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let access_token = access_token_option.context("Missing Google API key (access token)")?;
let request = google_vertex_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&project_id, // ADDED
&location_id, // ADDED
&access_token,
request,
);
request.await.context("failed to stream completion")
}
.boxed()
}
}
impl LanguageModel for GoogleVertexLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn supports_tools(&self) -> bool {
self.model.supports_tools()
}
fn supports_images(&self) -> bool {
self.model.supports_images()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
fn telemetry_id(&self) -> String {
format!("google_vertex/{}", self.model.request_id())
}
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
let model_id = self.model.request_id().to_string();
let request = into_vertex_ai(request, model_id.clone(), self.model.mode());
let http_client = self.http_client.clone();
// Synchronously read the state and settings.
// `read_entity` executes the closure and returns its result directly.
let (access_token_option, api_url, project_id, location_id) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google_vertex;
(
state.api_key.clone(), // This is the access token for Vertex AI (Option<String>)
settings.api_url.clone(), // String
settings.project_id.clone(), // String
settings.location_id.clone(), // String
)
}); // No .unwrap_or_default() here, as read_entity directly returns the tuple
async move {
// Check if the access token is present. If not, return an error.
let access_token = access_token_option
.context("Missing Google API key (access token). Please authenticate.")?;
let response = google_vertex_ai::count_tokens(
http_client.as_ref(),
&api_url,
&project_id,
&location_id,
&access_token,
google_vertex_ai::CountTokensRequest {
generate_content_request: request,
},
)
.await?;
Ok(response.total_tokens)
}
.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_vertex_ai(
request,
self.model.request_id().to_string(),
self.model.mode(),
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(GoogleVertexEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
pub fn into_vertex_ai(
mut request: LanguageModelRequest,
model_id: String,
mode: GoogleModelMode,
) -> google_vertex_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
.into_iter()
.flat_map(|content| match content {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
vec![Part::TextPart(google_vertex_ai::TextPart { text })]
} else {
vec![]
}
}
language_model::MessageContent::Thinking {
text: _,
signature: Some(signature),
} => {
if !signature.is_empty() {
vec![Part::ThoughtPart(google_vertex_ai::ThoughtPart {
thought: true,
thought_signature: signature,
})]
} else {
vec![]
}
}
language_model::MessageContent::Thinking { .. } => {
vec![]
}
language_model::MessageContent::RedactedThinking(_) => vec![],
language_model::MessageContent::Image(image) => {
vec![Part::InlineDataPart(google_vertex_ai::InlineDataPart {
inline_data: google_vertex_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(), // Assuming PNG for simplicity, could derive from format
data: image.source.to_string(), // Assuming base64 encoded for simplicity
},
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
vec![Part::FunctionCallPart(google_vertex_ai::FunctionCallPart {
function_call: google_vertex_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
})]
}
language_model::MessageContent::ToolResult(tool_result) => {
match tool_result.content {
language_model::LanguageModelToolResultContent::Text(text) => {
vec![Part::FunctionResponsePart(
google_vertex_ai::FunctionResponsePart {
function_response: google_vertex_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
},
},
)]
}
language_model::LanguageModelToolResultContent::Image(image) => {
vec![
Part::FunctionResponsePart(google_vertex_ai::FunctionResponsePart {
function_response: google_vertex_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
},
}),
Part::InlineDataPart(google_vertex_ai::InlineDataPart {
inline_data: google_vertex_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
}),
]
}
}
}
})
.collect()
}
let system_instructions = if request
.messages
.first()
.is_some_and(|msg| matches!(msg.role, Role::System))
{
let message = request.messages.remove(0);
Some(SystemInstruction {
parts: map_content(message.content),
})
} else {
None
};
google_vertex_ai::GenerateContentRequest {
model: google_vertex_ai::ModelName { model_id },
system_instruction: system_instructions,
contents: request
.messages
.into_iter()
.filter_map(|message| {
let parts = map_content(message.content);
if parts.is_empty() {
None
} else {
Some(google_vertex_ai::Content {
parts,
role: match message.role {
Role::User => google_vertex_ai::Role::User,
Role::Assistant => google_vertex_ai::Role::Model,
Role::System => google_vertex_ai::Role::User, // Google AI doesn't have a distinct system role; often maps to user for initial context
},
})
}
})
.collect(),
generation_config: Some(google_vertex_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match mode {
GoogleModelMode::Thinking { budget_tokens } => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
GoogleModelMode::Default => None,
},
top_p: None,
top_k: None,
}),
safety_settings: None, // Omitted, so the API's default safety settings apply.
tools: (!request.tools.is_empty()).then(|| {
vec![google_vertex_ai::Tool {
function_declarations: request
.tools
.into_iter()
.map(|tool| FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
})
.collect(),
}]
}),
tool_config: request.tool_choice.map(|choice| google_vertex_ai::ToolConfig {
function_calling_config: google_vertex_ai::FunctionCallingConfig {
mode: match choice {
LanguageModelToolChoice::Auto => google_vertex_ai::FunctionCallingMode::Auto,
LanguageModelToolChoice::Any => google_vertex_ai::FunctionCallingMode::Any,
LanguageModelToolChoice::None => google_vertex_ai::FunctionCallingMode::None,
},
allowed_function_names: None,
},
}),
}
}
pub struct GoogleVertexEventMapper {
usage: UsageMetadata,
stop_reason: StopReason,
}
impl GoogleVertexEventMapper {
pub fn new() -> Self {
Self {
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events
.map(Some)
.chain(futures::stream::once(async { None }))
.flat_map(move |event| {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})
})
}
pub fn map_event(
&mut self,
event: GenerateContentResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
self.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!("Unexpected google_vertex finish_reason: {finish_reason}");
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => {
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
}
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> = function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
is_input_complete: true,
raw_input: function_call_part.function_call.args.to_string(),
input: function_call_part.function_call.args,
},
)));
}
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
events
}
}
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
// We can't use GoogleLanguageModelProvider to count tokens here because GitHub Copilot doesn't have direct access to google_ai.
// Instead, we approximate the count with the tiktoken_rs tokenizer.
cx.background_spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
}
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
if let Some(prompt_token_count) = new.prompt_token_count {
usage.prompt_token_count = Some(prompt_token_count);
}
if let Some(cached_content_token_count) = new.cached_content_token_count {
usage.cached_content_token_count = Some(cached_content_token_count);
}
if let Some(candidates_token_count) = new.candidates_token_count {
usage.candidates_token_count = Some(candidates_token_count);
}
if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
}
if let Some(thoughts_token_count) = new.thoughts_token_count {
usage.thoughts_token_count = Some(thoughts_token_count);
}
if let Some(total_token_count) = new.total_token_count {
usage.total_token_count = Some(total_token_count);
}
}
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
// Saturate so a cached count larger than the prompt count can't underflow.
let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
let output_tokens = usage.candidates_token_count.unwrap_or(0);
language_model::TokenUsage {
input_tokens,
output_tokens,
cache_read_input_tokens: cached_tokens,
cache_creation_input_tokens: 0,
}
}
struct ConfigurationView {
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
let load_credentials_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
state,
load_credentials_task,
}
}
fn authenticate_gcloud(&mut self, window: &mut Window, cx: &mut Context<Self>) {
log::info!("Authenticating with gcloud...");
let state = self.state.clone();
self.load_credentials_task = Some(cx.spawn_in(window, {
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
cx.notify();
}
fn reset_gcloud_auth(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
})
.detach_and_log_err(cx);
cx.notify();
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
if self.load_credentials_task.is_some() {
div().child(Label::new("Attempting to authenticate with gcloud...")).into_any()
} else if !is_authenticated {
v_flex()
.size_full()
.child(Label::new("Please authenticate with Google Cloud to use this provider."))
.child(
List::new()
.child(InstructionListItem::text_only(
"1. Ensure Google Cloud SDK is installed and configured.",
))
.child(InstructionListItem::text_only(
"2. Run 'gcloud auth application-default login' in your terminal.",
))
.child(InstructionListItem::text_only(
"3. Configure your desired Google Cloud Project ID and Location ID in Zed's settings.json file under 'language_models.google_vertex'.",
))
)
.child(
h_flex()
.w_full()
.my_2()
.child(
Button::new("authenticate-gcloud", "Authenticate with gcloud")
.label_size(LabelSize::Small)
.icon_size(IconSize::Small)
.on_click(cx.listener(|this, _, window, cx| this.authenticate_gcloud(window, cx))),
),
)
.child(
Label::new(
"This will attempt to acquire an access token using your \
gcloud application-default credentials. You might need to run \
'gcloud auth application-default login' manually first."
)
.size(LabelSize::Small).color(Color::Muted),
)
.into_any()
} else {
h_flex()
.mt_1()
.p_1()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new("Authenticated with gcloud.")),
)
.child(
Button::new("reset-gcloud-auth", "Clear Token")
.label_size(LabelSize::Small)
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.tooltip(Tooltip::text("Clear the in-memory access token. You will need to re-authenticate to use the provider."))
.on_click(cx.listener(|this, _, window, cx| this.reset_gcloud_auth(window, cx))),
)
.into_any()
}
}
}
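The configuration view relies on State::authenticate, which is not part of this excerpt, to obtain an access token from gcloud application-default credentials. As a minimal sketch, assuming the token is fetched by shelling out to the gcloud CLI, the lookup could look like the following. The helper names are hypothetical, not Zed APIs. Earlier in the thread the command was gcloud auth print-access-token; since the UI points users at application-default login, this sketch uses the application-default variant.

```rust
use std::process::Command;

/// Hypothetical helper: turns raw gcloud stdout into a bearer token.
/// The CLI prints the token followed by a newline, so trim it and treat
/// empty or non-UTF-8 output as "no token available".
pub fn parse_access_token(stdout: &[u8]) -> Option<String> {
    let text = std::str::from_utf8(stdout).ok()?.trim();
    if text.is_empty() {
        None
    } else {
        Some(text.to_string())
    }
}

/// Hypothetical helper: asks gcloud for a token from the credentials created
/// by `gcloud auth application-default login`.
pub fn fetch_gcloud_access_token() -> Result<String, String> {
    let output = Command::new("gcloud")
        .args(["auth", "application-default", "print-access-token"])
        .output()
        .map_err(|err| format!("failed to run gcloud: {err}"))?;
    if !output.status.success() {
        // gcloud reports problems such as "not logged in" on stderr.
        return Err(String::from_utf8_lossy(&output.stderr).trim().to_string());
    }
    parse_access_token(&output.stdout)
        .ok_or_else(|| "gcloud returned an empty access token".to_string())
}

/// Formats the Authorization header value, mirroring
/// `Authorization: Bearer $(gcloud auth print-access-token)`.
pub fn bearer_header(token: &str) -> String {
    format!("Bearer {token}")
}
```

Tokens obtained this way are short-lived (typically around an hour), so a long-running session may need to fetch a fresh one rather than caching the first token indefinitely.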
- Update language_models.rs (in the language_models crate) to register the new provider:
use std::sync::Arc;
use client::{Client, UserStore};
use fs::Fs;
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
mod settings;
pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::google_vertex::GoogleVertexLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
crate::settings::init(fs, cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client, cx);
});
}
fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
cx,
);
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OpenAiLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OllamaLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
LmStudioLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
DeepSeekLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
GoogleLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider( // New: Google Vertex AI provider
GoogleVertexLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
MistralLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
BedrockLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OpenRouterLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
- Add a google_vertex module to provider.rs, so that crate::provider::google_vertex resolves.
- Edit settings.rs like so inside the load function:
// Google Vertex AI has api_url, project_id and location_id
merge(
&mut settings.google_vertex.api_url,
value.google_vertex.as_ref().and_then(|s| s.api_url.clone()),
);
merge(
&mut settings.google_vertex.project_id,
value.google_vertex.as_ref().and_then(|s| s.project_id.clone()),
);
merge(
&mut settings.google_vertex.location_id,
value.google_vertex.as_ref().and_then(|s| s.location_id.clone()),
);
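The merge calls above follow a simple rule: a value from the user's settings file overrides the default only when it is present. A self-contained sketch of that rule, assuming simplified stand-ins for Zed's real settings types (GoogleVertexSettings, GoogleVertexSettingsContent, merge and load here are illustrative, not the actual definitions):

```rust
/// Resolved settings used by the provider (hypothetical stand-in).
#[derive(Debug, Clone, PartialEq)]
pub struct GoogleVertexSettings {
    pub api_url: String,
    pub project_id: String,
    pub location_id: String,
}

/// What the user wrote in settings.json; every field is optional
/// (hypothetical stand-in).
#[derive(Debug, Default)]
pub struct GoogleVertexSettingsContent {
    pub api_url: Option<String>,
    pub project_id: Option<String>,
    pub location_id: Option<String>,
}

/// Overwrite `target` only when the user supplied a value.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

/// Start from defaults and layer the user's values on top.
pub fn load(
    defaults: GoogleVertexSettings,
    user: Option<&GoogleVertexSettingsContent>,
) -> GoogleVertexSettings {
    let mut settings = defaults;
    merge(&mut settings.api_url, user.and_then(|s| s.api_url.clone()));
    merge(&mut settings.project_id, user.and_then(|s| s.project_id.clone()));
    merge(&mut settings.location_id, user.and_then(|s| s.location_id.clone()));
    settings
}
```

Fields the user leaves out keep their defaults, which is why only project_id and location_id typically need to be set.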
- Edit your settings to add project_id and location_id. For authentication, all you need to do is run gcloud auth application-default login; if you have already done so, the provider connects automatically. You can also add additional models via settings, as with Ollama.
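project_id and location_id both end up in the request URL. Following the URL format quoted earlier in this thread (https://${LOCATION_ID}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION_ID}/publishers/google/models/${MODEL_ID}:${GENERATE_CONTENT_API}), a sketch of how the settings could be combined is below. vertex_endpoint is a hypothetical helper, and the method names in the example (streamGenerateContent, countTokens) are assumed to be the Vertex AI REST methods the provider calls.

```rust
/// Hypothetical helper: builds a Vertex AI publisher-model endpoint.
/// The region appears twice: once in the hostname, which is what routes
/// the request to servers near the user, and once in the resource path.
pub fn vertex_endpoint(
    project_id: &str,
    location_id: &str,
    model_id: &str,
    method: &str,
) -> Result<String, String> {
    // An empty project or region would silently produce a malformed URL,
    // so report it as a configuration error instead.
    if project_id.trim().is_empty() {
        return Err("google_vertex.project_id is not set".to_string());
    }
    if location_id.trim().is_empty() {
        return Err("google_vertex.location_id is not set".to_string());
    }
    Ok(format!(
        "https://{location_id}-aiplatform.googleapis.com/v1/projects/{project_id}\
         /locations/{location_id}/publishers/google/models/{model_id}:{method}"
    ))
}
```

For example, vertex_endpoint("my-project", "europe-west4", "gemini-2.0-flash", "streamGenerateContent") targets the europe-west4 servers, which is the region selection the original feature request asked for.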
I have tested for the last couple of days and all seems well with inline edit, token counting and agent mode. The only thing I have not confirmed is that ALL the additional models work, as I just wanted to get the Google models running for my purposes.
Best Wishes, Devin
Woah, strong work, @DevInABoxLLC ... @zed-intelligence, you may have a new contributor here... I'd strike while the iron is hot! Let's get this legitimized and merged in!
@DevInABoxLLC I was able to get your suggestion working locally and am using the Vertex AI Google LLMs now. I'm curious why you didn't create a PR for it. I'm happy to start a PR if you don't want to, but I don't want to take any credit for all your work, since I really didn't do anything other than implement your comment in code. I don't even really know Rust lol.
Hi there! 👋 We're working to clean up our issue tracker by closing older bugs that might not be relevant anymore. If you are able to reproduce this issue in the latest version of Zed, please let us know by commenting on this issue, and it will be kept open. If you can't reproduce it, feel free to close the issue yourself. Otherwise, it will close automatically in 14 days. Thanks for your help!
Issue is not stale and @joeyboey has an active PR https://github.com/zed-industries/zed/pull/40023
@rupesh1 also updated the repo again: https://github.com/joeyboey/zed/commit/2b5891a44a159b8a49f12b26cb8b7ee163dad146