diff --git a/Cargo.toml b/Cargo.toml index e67ee8c462..383ee0d142 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -285,6 +285,7 @@ git_ui = { path = "crates/git_ui" } go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } google_vertex_ai = { path = "crates/google_vertex_ai" } +anthropic_vertex_ai = { path = "crates/anthropic_vertex_ai" } gpui = { path = "crates/gpui", default-features = false, features = [ "http_client", ] } diff --git a/crates/anthropic_vertex_ai/Cargo.toml b/crates/anthropic_vertex_ai/Cargo.toml new file mode 100644 index 0000000000..61b3b52a1c --- /dev/null +++ b/crates/anthropic_vertex_ai/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "anthropic_vertex_ai" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[features] +default = [] +schemars = ["dep:schemars"] + +[lints] +workspace = true + +[lib] +path = "src/anthropic_vertex_ai.rs" + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +anthropic.workspace = true +serde.workspace = true +serde_json.workspace = true +strum.workspace = true +thiserror.workspace = true +workspace-hack.workspace = true diff --git a/crates/anthropic_vertex_ai/src/anthropic_vertex_ai.rs b/crates/anthropic_vertex_ai/src/anthropic_vertex_ai.rs new file mode 100644 index 0000000000..7611b3a5e6 --- /dev/null +++ b/crates/anthropic_vertex_ai/src/anthropic_vertex_ai.rs @@ -0,0 +1,744 @@ +use std::time::Duration; + +use anthropic::{AnthropicError, ApiError}; +use anyhow::{Context as _, Result, anyhow}; +use chrono::{DateTime, Utc}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; +use http_client::http::{HeaderMap, HeaderValue}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +#[derive(Clone, Debug, Default, Deserialize)] +pub struct AnthropicVertexAISettings { + pub project_id: Option, + pub location: Option, +} + +pub const ANTHROPIC_API_URL: &str = "https://aiplatform.googleapis.com"; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct AnthropicVertexModelCacheConfiguration { + pub min_total_token: u64, + pub should_speculate: bool, + pub max_cache_anchors: usize, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum ModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option, + }, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[serde(rename = "claude-opus-4", alias = "claude-opus-4@20250514")] + ClaudeOpus4, + #[serde(rename = "claude-opus-4-thinking")] + ClaudeOpus4Thinking, + #[default] + #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4@20250514")] + ClaudeSonnet4, + #[serde(rename = "claude-sonnet-4-thinking")] + ClaudeSonnet4Thinking, + #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet@20250219")] + Claude3_7Sonnet, + #[serde(rename = "claude-3-7-sonnet-thinking")] + Claude3_7SonnetThinking, + #[serde(rename = "custom")] + Custom { + name: String, + max_tokens: u64, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. 
+        display_name: Option<String>,
+        /// Override this model with a different Anthropic model for tool calls.
+        tool_override: Option<String>,
+        /// Indicates whether this custom model supports caching.
+        cache_configuration: Option<AnthropicVertexModelCacheConfiguration>,
+        max_output_tokens: Option<u64>,
+        default_temperature: Option<f32>,
+        #[serde(default)]
+        mode: ModelMode,
+    },
+}
+
+impl Model {
+    pub fn default_fast() -> Self {
+        Self::ClaudeSonnet4
+    }
+
+    pub fn from_id(id: &str) -> Result<Self> {
+        if id.starts_with("claude-opus-4-thinking") {
+            return Ok(Self::ClaudeOpus4Thinking);
+        }
+
+        if id.starts_with("claude-opus-4") {
+            return Ok(Self::ClaudeOpus4);
+        }
+
+        if id.starts_with("claude-sonnet-4-thinking") {
+            return Ok(Self::ClaudeSonnet4Thinking);
+        }
+
+        if id.starts_with("claude-sonnet-4") {
+            return Ok(Self::ClaudeSonnet4);
+        }
+
+        if id.starts_with("claude-3-7-sonnet-thinking") {
+            return Ok(Self::Claude3_7SonnetThinking);
+        }
+
+        if id.starts_with("claude-3-7-sonnet") {
+            return Ok(Self::Claude3_7Sonnet);
+        }
+
+        Err(anyhow!("invalid model ID: {id}"))
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            Self::ClaudeOpus4 => "claude-opus-4@20250514",
+            Self::ClaudeOpus4Thinking => "claude-opus-4@20250514",
+            Self::ClaudeSonnet4 => "claude-sonnet-4@20250514",
+            Self::ClaudeSonnet4Thinking => "claude-sonnet-4@20250514",
+            Self::Claude3_7Sonnet => "claude-3-7-sonnet@20250219",
+            Self::Claude3_7SonnetThinking => "claude-3-7-sonnet@20250219",
+            Self::Custom { name, .. } => name,
+        }
+    }
+
+    /// The id of the model that should be used for making API requests
+    pub fn request_id(&self) -> &str {
+        match self {
+            Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking => "claude-opus-4@20250514",
+            Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => "claude-sonnet-4@20250514",
+            Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => "claude-3-7-sonnet@20250219",
+            Self::Custom { name, .. } => name,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            Self::ClaudeOpus4 => "Claude Opus 4",
+            Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking",
+            Self::ClaudeSonnet4 => "Claude Sonnet 4",
+            Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking",
+            Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
+            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
+            Self::Custom {
+                name, display_name, ..
+            } => display_name.as_ref().unwrap_or(name),
+        }
+    }
+
+    pub fn cache_configuration(&self) -> Option<AnthropicVertexModelCacheConfiguration> {
+        match self {
+            Self::ClaudeOpus4
+            | Self::ClaudeOpus4Thinking
+            | Self::ClaudeSonnet4
+            | Self::ClaudeSonnet4Thinking
+            | Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking => Some(AnthropicVertexModelCacheConfiguration {
+                min_total_token: 2_048,
+                should_speculate: true,
+                max_cache_anchors: 4,
+            }),
+            Self::Custom {
+                cache_configuration,
+                ..
+            } => cache_configuration.clone(),
+        }
+    }
+
+    pub fn max_token_count(&self) -> u64 {
+        match self {
+            Self::ClaudeOpus4
+            | Self::ClaudeOpus4Thinking
+            | Self::ClaudeSonnet4
+            | Self::ClaudeSonnet4Thinking
+            | Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking => 200_000,
+            Self::Custom { max_tokens, .. } => *max_tokens,
+        }
+    }
+
+    pub fn max_output_tokens(&self) -> u64 {
+        match self {
+            Self::ClaudeOpus4
+            | Self::ClaudeOpus4Thinking
+            | Self::ClaudeSonnet4
+            | Self::ClaudeSonnet4Thinking
+            | Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking => 8_192,
+            Self::Custom {
+                max_output_tokens, ..
+ } => max_output_tokens.unwrap_or(4_096), + } + } + + pub fn default_temperature(&self) -> f32 { + match self { + Self::ClaudeOpus4 + | Self::ClaudeOpus4Thinking + | Self::ClaudeSonnet4 + | Self::ClaudeSonnet4Thinking + | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking => 1.0, + Self::Custom { + default_temperature, + .. + } => default_temperature.unwrap_or(1.0), + } + } + + pub fn mode(&self) -> ModelMode { + match self { + Self::ClaudeOpus4 | Self::ClaudeSonnet4 | Self::Claude3_7Sonnet => ModelMode::Default, + Self::ClaudeOpus4Thinking + | Self::ClaudeSonnet4Thinking + | Self::Claude3_7SonnetThinking => ModelMode::Thinking { + budget_tokens: Some(4_096), + }, + Self::Custom { mode, .. } => mode.clone(), + } + } + + pub fn tool_model_id(&self) -> &str { + if let Self::Custom { + tool_override: Some(tool_override), + .. + } = self + { + tool_override + } else { + self.request_id() + } + } +} + +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + request: Request, +) -> Result { + let uri = format!("{api_url}/v1/messages"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(request) + .await + .map_err(AnthropicError::HttpSend)?; + let status_code = response.status(); + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + if status_code.is_success() { + Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?) + } else { + Err(AnthropicError::HttpResponseError { + status_code, + message: body, + }) + } +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + project_id: &str, + location_id: &str, + access_token: &str, + request: Request, +) -> Result>, AnthropicError> { + stream_completion_with_rate_limit_info( + client, + api_url, + project_id, + location_id, + access_token, + request, + ) + .await + .map(|output| output.0) +} + +/// An individual rate limit. +#[derive(Debug)] +pub struct RateLimit { + pub limit: usize, + pub remaining: usize, + pub reset: DateTime, +} + +impl RateLimit { + fn from_headers(resource: &str, headers: &HeaderMap) -> Result { + let limit = + get_header(&format!("anthropic-ratelimit-{resource}-limit"), headers)?.parse()?; + let remaining = get_header( + &format!("anthropic-ratelimit-{resource}-remaining"), + headers, + )? + .parse()?; + let reset = DateTime::parse_from_rfc3339(get_header( + &format!("anthropic-ratelimit-{resource}-reset"), + headers, + )?)? 
+        .to_utc();
+
+        Ok(Self {
+            limit,
+            remaining,
+            reset,
+        })
+    }
+}
+
+/// Rate limit information reported through the response headers.
+#[derive(Debug)]
+pub struct RateLimitInfo {
+    pub retry_after: Option<Duration>,
+    pub requests: Option<RateLimit>,
+    pub tokens: Option<RateLimit>,
+    pub input_tokens: Option<RateLimit>,
+    pub output_tokens: Option<RateLimit>,
+}
+
+impl RateLimitInfo {
+    fn from_headers(headers: &HeaderMap) -> Self {
+        // Check if any rate limit headers exist
+        let has_rate_limit_headers = headers
+            .keys()
+            .any(|k| k == "retry-after" || k.as_str().starts_with("anthropic-ratelimit-"));
+
+        if !has_rate_limit_headers {
+            return Self {
+                retry_after: None,
+                requests: None,
+                tokens: None,
+                input_tokens: None,
+                output_tokens: None,
+            };
+        }
+
+        Self {
+            retry_after: parse_retry_after(headers),
+            requests: RateLimit::from_headers("requests", headers).ok(),
+            tokens: RateLimit::from_headers("tokens", headers).ok(),
+            input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
+            output_tokens: RateLimit::from_headers("output-tokens", headers).ok(),
+        }
+    }
+}
+
+/// Parses the Retry-After header value as an integer number of seconds (Anthropic always uses
+/// seconds). Note that other services might specify an HTTP date or some other format for this
+/// header. Returns `None` if the header is not present or cannot be parsed.
+pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
+    headers
+        .get("retry-after")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.parse::<u64>().ok())
+        .map(Duration::from_secs)
+}
+
+fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
+    Ok(headers
+        .get(key)
+        .with_context(|| format!("missing header `{key}`"))?
+        .to_str()?)
+}
+
+pub async fn stream_completion_with_rate_limit_info(
+    client: &dyn HttpClient,
+    api_url: &str,
+    project_id: &str,
+    location_id: &str,
+    access_token: &str,
+    request: Request,
+) -> Result<
+    (
+        BoxStream<'static, Result<Event, AnthropicError>>,
+        Option<RateLimitInfo>,
+    ),
+    AnthropicError,
+> {
+    let model_id = request.model.clone();
+    let request = StreamingRequest {
+        base: request,
+        stream: true,
+    };
+
+    // The global endpoint has no location prefix; regional endpoints are prefixed with the location.
+    let endpoint = if location_id == "global" {
+        format!("https://{api_url}")
+    } else {
+        format!("https://{location_id}-{api_url}")
+    };
+
+    let uri = format!(
+        "{endpoint}/v1/projects/{project_id}/locations/{location_id}/publishers/anthropic/models/{model_id}:streamRawPredict"
+    );
+
+    // Add an Authorization header carrying the bearer token obtained from Application Default Credentials.
+ let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Authorization", format!("Bearer {}", access_token)) + .header("Content-Type", "application/json"); + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(request) + .await + .map_err(AnthropicError::HttpSend)?; + let rate_limits = RateLimitInfo::from_headers(response.headers()); + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + let stream = reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(AnthropicError::DeserializeResponse(error))), + } + } + Err(error) => Some(Err(AnthropicError::ReadResponse(error))), + } + }) + .boxed(); + Ok((stream, Some(rate_limits))) + } else if response.status().as_u16() == 529 { + Err(AnthropicError::ServerOverloaded { + retry_after: rate_limits.retry_after, + }) + } else if let Some(retry_after) = rate_limits.retry_after { + Err(AnthropicError::RateLimit { retry_after }) + } else { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + match serde_json::from_str::(&body) { + Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)), + Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError { + status_code: response.status(), + message: body, + }), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[serde(rename_all = "lowercase")] +pub enum CacheControlType { + Ephemeral, +} + +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +pub struct CacheControl { + #[serde(rename = "type")] + pub cache_type: CacheControlType, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum RequestContent { + #[serde(rename = "text")] + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "thinking")] + Thinking { + thinking: String, + signature: String, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, + #[serde(rename = "image")] + Image { + source: ImageSource, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + is_error: bool, + content: ToolResultContent, + #[serde(skip_serializing_if = "Option::is_none")] + cache_control: Option, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolResultContent { + Plain(String), + Multipart(Vec), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolResultPart { + Text { text: String }, + 
Image { source: ImageSource }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ResponseContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "thinking")] + Thinking { thinking: String }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ImageSource { + #[serde(rename = "type")] + pub source_type: String, + pub media_type: String, + pub data: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolChoice { + Auto, + Any, + Tool { name: String }, + None, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum Thinking { + Enabled { budget_tokens: Option }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StringOrContents { + String(String), + Content(Vec), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + #[serde(skip)] + pub model: String, + pub anthropic_version: String, + pub max_tokens: u64, + pub messages: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thinking: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub stop_sequences: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_p: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct StreamingRequest { + #[serde(flatten)] + pub base: Request, + pub stream: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Metadata { + pub user_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Usage { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_creation_input_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cache_read_input_tokens: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub id: String, + #[serde(rename = "type")] + pub response_type: String, + pub role: Role, + pub content: Vec, + pub model: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub stop_reason: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub stop_sequence: Option, + pub usage: Usage, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Event { + #[serde(rename = "message_start")] + MessageStart { message: Response }, + #[serde(rename = "content_block_start")] + ContentBlockStart { + index: usize, + content_block: ResponseContent, + }, + #[serde(rename = 
"content_block_delta")] + ContentBlockDelta { index: usize, delta: ContentDelta }, + #[serde(rename = "content_block_stop")] + ContentBlockStop { index: usize }, + #[serde(rename = "message_delta")] + MessageDelta { delta: MessageDelta, usage: Usage }, + #[serde(rename = "message_stop")] + MessageStop, + #[serde(rename = "ping")] + Ping, + #[serde(rename = "error")] + Error { error: ApiError }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "thinking_delta")] + ThinkingDelta { thinking: String }, + #[serde(rename = "signature_delta")] + SignatureDelta { signature: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageDelta { + pub stop_reason: Option, + pub stop_sequence: Option, +} + +pub fn parse_prompt_too_long(message: &str) -> Option { + message + .strip_prefix("prompt is too long: ")? + .split_once(" tokens")? + .0 + .parse() + .ok() +} + +#[test] +fn test_match_window_exceeded() { + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: 220000 tokens > 200000".to_string(), + }; + assert_eq!(error.match_window_exceeded(), Some(220_000)); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: 1234953 tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), Some(1234953)); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "not a prompt length error".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); + + let error = ApiError { + error_type: "rate_limit_error".to_string(), + message: "prompt is too long: 12345 tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: invalid tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); +} diff --git a/crates/anthropic_vertex_ai/task.md b/crates/anthropic_vertex_ai/task.md new file mode 100644 index 0000000000..5e159bf6f1 --- /dev/null +++ b/crates/anthropic_vertex_ai/task.md @@ -0,0 +1,79 @@ +Task ID: TA001 - Integrate Anthropic Models via Google Vertex AI** + +**Objective:** +To develop a new language model provider, `anthropic_vertex_ai`, that seamlessly integrates Anthropic's models (e.g., Claude) into the Zed editor via the Google Cloud Vertex AI platform. + +**Background:** +While Zed has a direct integration with Anthropic's API, many users operate within the Google Cloud ecosystem. Vertex AI provides access to third-party models like Anthropic's through its own endpoint. This task involves creating a new provider that bridges the existing `anthropic` API logic with the authentication and endpoint requirements of Google Cloud. + +This integration will not use explicit API keys. Instead, it will leverage Google's Application Default Credentials (ADC), a standard mechanism for authenticating GCP services, ensuring a secure and streamlined user experience. Configuration will be provided through `settings.json` to specify the required `project_id` and `location` for the Vertex AI endpoint. + +**Key Requirements:** +- **Authentication:** Must use Google Cloud's Application Default Credentials (ADC) for all API requests. The implementation should not handle manual tokens. 
+- **Configuration:** The provider must be configurable via `settings.json`, allowing the user to specify their Google Cloud `project_id` and `location`. +- **Endpoint Construction:** Must dynamically construct the correct Vertex AI endpoint URL for each request, in the format: `https://$LOCATION-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/$LOCATION/publishers/anthropic/models/$MODEL:streamRawPredict`. +- **Payload Adaptation:** The JSON payload sent to the endpoint must be modified to: + - Include the mandatory field: `"anthropic_version": "vertex-2023-10-16"`. + - Exclude the `model` field, as it is specified in the URL. +- **Integration:** The new provider must be a first-class citizen within Zed, appearing in the model selection list and functioning identically to other integrated providers. + +**Implementation Plan:** + +**Step 1: Foundational Analysis & Crate Setup** + +* **Action 1.1: Analyze `google_vertex` Crate:** Thoroughly examine `crates/google_vertex/src/google_vertex.rs` to understand its implementation of ADC-based authentication and how it reads settings like `project_id` and `location`. This will serve as the template for our authentication logic. +* **Action 1.2: Define Configuration Struct:** In a new file, `crates/anthropic_vertex_ai/src/lib.rs`, define the `AnthropicVertexAISettings` struct. This struct will deserialize the `project_id` and `location` from the user's `settings.json`. +* **Action 1.3: Update `Cargo.toml`:** Create/update the `Cargo.toml` file for the `anthropic_vertex_ai` crate. It should include dependencies from both `anthropic` (for serde structs) and `google_vertex` (for GCP-related dependencies like `gcp_auth`). +* **Action 1.4: Create `lib.rs`:** Ensure `crates/anthropic_vertex_ai/src/lib.rs` exists to house the `LanguageModelProvider` implementation and serve as the crate's entry point. + +**Step 2: Adapt Core Anthropic Logic** + +* **Action 2.1: Modify `Request` Struct:** In `crates/anthropic_vertex_ai/src/anthropic_vertex_ai.rs`, modify the main `Request` struct: + - Add a new field: `pub anthropic_version: &'static str`. + - Remove the existing `pub model: String` field. +* **Action 2.2: Refactor Completion Functions:** Refactor the `stream_completion_with_rate_limit_info` function to be more generic. + - It will now accept the fully-constructed Vertex AI endpoint URL as a parameter. + - It will accept an ADC-aware `HttpClient` instance instead of a simple API key. + - The logic for setting the `Authorization` header will be updated to use a `Bearer` token provided by the `HttpClient`. + +**Step 3: Implement the `LanguageModelProvider`** + +* **Action 3.1: Define Provider Struct:** In `crates/anthropic_vertex_ai/src/lib.rs`, define the main `AnthropicVertexAIProvider` struct. It will store the settings defined in Action 1.2. +* **Action 3.2: Implement `LanguageModelProvider` Trait:** Implement the `language_model::LanguageModelProvider` trait for `AnthropicVertexAIProvider`. +* **Action 3.3: Implement Core Logic:** The trait methods will contain the central logic: + 1. On initialization, the provider will create an `HttpClient` configured to use Google's ADC, following the pattern in the `google_vertex` crate. + 2. For each completion request, it will dynamically construct the full, model-specific Vertex AI URL using the configured `project_id`, `location`, and the requested model name. + 3. 
It will create an instance of the modified `Request` struct from `anthropic_vertex_ai.rs`, setting the `anthropic_version` field correctly.
+    4. Finally, it will call the refactored `stream_completion_with_rate_limit_info` function, passing the authenticated client and the constructed request.
+
+**Step 4: Final Integration**
+
+* **Action 4.1: Workspace Integration:** Add `anthropic_vertex_ai` to the main workspace `Cargo.toml` to link the new crate.
+* **Action 4.2: Module Declaration:** Add `pub mod anthropic_vertex_ai;` to `crates/language_models/src/provider.rs` to make the module visible.
+* **Action 4.3: Provider Registration:** In `crates/language_models/src/lib.rs`, update the central list of language model providers to include an instance of `AnthropicVertexAIProvider`.
+
+**Verification Plan:**
+
+* **Compile-Time Verification:** At each major step, ask the human to review the code for compilation errors and adherence to project standards.
+* **Configuration Verification:** The implementation will be tested against a `settings.json` file configured as follows:
+  ```json
+  "language_models": {
+    "anthropic-vertex": {
+      "enabled": true,
+      "project_id": "your-gcp-project-id",
+      "location": "europe-west1"
+    }
+  },
+  "assistant": {
+    "default_model": {
+      "provider": "anthropic-vertex",
+      "name": "claude-sonnet-4@20250514"
+    }
+  }
+  ```
+* **Runtime Verification:**
+  1. Launch Zed with the above configuration.
+  2. Ensure the local environment is authenticated with GCP (e.g., via `gcloud auth application-default login`).
+  3. Open the assistant panel and confirm that `"anthropic-vertex/claude-sonnet-4@20250514"` is the selected model.
+  4. Send a test prompt to the assistant.
+  5. **Success Condition:** A valid, streamed response is received from the assistant, confirming that the entire chain—from configuration and authentication to request execution and response parsing—is working correctly.
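+
+For reference, the endpoint-construction and payload requirements above can be summarized in a small sketch. This is illustrative only; the function name and values are placeholders, not the final provider API:
+
+```rust
+/// Builds the Vertex AI URL described in the requirements above (illustrative).
+fn vertex_url(location: &str, project_id: &str, model_id: &str) -> String {
+    format!(
+        "https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/anthropic/models/{model_id}:streamRawPredict"
+    )
+}
+
+fn main() {
+    let url = vertex_url("europe-west1", "your-gcp-project-id", "claude-sonnet-4@20250514");
+    println!("{url}");
+    // The JSON body then carries "anthropic_version": "vertex-2023-10-16" and
+    // omits the "model" field, since the model is already encoded in the URL.
+}
+```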
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 33c91880b5..991d72525b 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -33,6 +33,7 @@ fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } google_vertex_ai = { workspace = true, features = ["schemars"] } +anthropic_vertex_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 854d528ef4..b4ceefb2b3 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -10,6 +10,7 @@ mod settings; pub mod ui; use crate::provider::anthropic::AnthropicLanguageModelProvider; +use crate::provider::anthropic_vertex::AnthropicVertexLanguageModelProvider; use crate::provider::bedrock::BedrockLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider; use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; @@ -72,6 +73,11 @@ fn register_language_model_providers( GoogleVertexLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + // NEW REGISTRATION BY DIAB + AnthropicVertexLanguageModelProvider::new(client.http_client(), cx), + cx, + ); registry.register_provider( MistralLanguageModelProvider::new(client.http_client(), cx), cx, diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6c1bd602c0..a1445edf0a 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -1,4 +1,5 @@ pub mod anthropic; +pub mod anthropic_vertex; pub mod bedrock; pub mod cloud; pub mod copilot_chat; diff --git a/crates/language_models/src/provider/anthropic_vertex.rs b/crates/language_models/src/provider/anthropic_vertex.rs new file mode 100644 index 0000000000..638fa77f88 --- /dev/null +++ b/crates/language_models/src/provider/anthropic_vertex.rs @@ -0,0 +1,1119 @@ +use crate::AllLanguageModelSettings; +use crate::ui::InstructionListItem; +use anthropic::AnthropicError; +use anthropic_vertex_ai::{ + ContentDelta, Event, ModelMode, ResponseContent, ToolResultContent, ToolResultPart, Usage, +}; +use anyhow::{Result, anyhow}; +use collections::{BTreeMap, HashMap}; +use credentials_provider::CredentialsProvider; +use futures::Stream; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; +use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, + LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, + RateLimiter, Role, +}; +use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use strum::IntoEnumIterator; +use ui::{Icon, IconName, List, Tooltip, prelude::*}; +use util::ResultExt; + +const PROVIDER_ID: &str = "anthropic-vertex-ai"; +const PROVIDER_NAME: &str = "Anthropic Vertex AI"; + +#[derive(Default, Clone, 
Debug, PartialEq)] +pub struct AnthropicVertexSettings { + pub api_url: String, + pub project_id: String, // ADDED + pub location_id: String, // ADDED + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc + pub name: String, + /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel. + pub display_name: Option, + /// The model's context window size. + pub max_tokens: u64, + /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling. + pub tool_override: Option, + /// Configuration of Anthropic's caching API. + pub cache_configuration: Option, + pub max_output_tokens: Option, + pub default_temperature: Option, + #[serde(default)] + pub extra_beta_headers: Vec, + /// The model's mode (e.g. thinking) + pub mode: Option, +} + +pub struct AnthropicVertexLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + _subscription: Subscription, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + // Ensure api_url, project_id, and location_id are available for credentials deletion + let settings = AllLanguageModelSettings::get_global(cx) + .google_vertex + .clone(); + + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&settings.api_url, &cx) // Use api_url + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + log::info!("Authenticating Google Vertex AI..."); + + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + // The Tokio runtime provided by `gpui::spawn` is not sufficient for `tokio::process` + // or `tokio::task::spawn_blocking`. We must fall back to the standard library's threading + // to run the synchronous `gcloud` command, and use a channel to communicate the + // result back to our async context. + cx.spawn(async move |this, cx| { + let (tx, rx) = futures::channel::oneshot::channel(); + + std::thread::spawn(move || { + let result = std::process::Command::new("gcloud") + .args(&["auth", "application-default", "print-access-token"]) + .output() + .map_err(|e| { + AuthenticateError::Other(anyhow!("Failed to execute gcloud command: {}", e)) + }); + + // Send the result back to the async task, ignoring if the receiver was dropped. + let _ = tx.send(result); + }); + + // Await the result from the channel. + // First, explicitly handle the channel's `Canceled` error. + // Then, use `?` to propagate the `AuthenticateError` from the command execution. + let token_output = rx.await.map_err(|_cancelled| { + AuthenticateError::Other(anyhow!("Authentication task was cancelled")) + })??; + + // Retrieve the access token from the gcloud command output. + // Ensure UTF-8 decoding and trim whitespace. + let access_token = String::from_utf8(token_output.stdout) + .map_err(|e| { + AuthenticateError::Other(anyhow!("Invalid UTF-8 in gcloud output: {}", e)) + })? + .trim() + .to_string(); + + // Check the exit status of the gcloud command. 
+ if !token_output.status.success() { + let stderr = String::from_utf8_lossy(&token_output.stderr).into_owned(); + return Err(AuthenticateError::Other(anyhow!( + "gcloud command failed: {}", + stderr + ))); + } + + let api_key = access_token; // Use the retrieved token as the API key. + let from_env = false; // This token is dynamically fetched, not from env or keychain. + + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } +} + +impl AnthropicVertexLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: anthropic_vertex_ai::Model) -> Arc { + Arc::new(AnthropicVertexModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for AnthropicVertexLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for AnthropicVertexLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiAnthropic + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(anthropic_vertex_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(anthropic_vertex_ai::Model::default_fast())) + } + + fn recommended_models(&self, _cx: &App) -> Vec> { + [ + anthropic_vertex_ai::Model::ClaudeSonnet4, + anthropic_vertex_ai::Model::ClaudeSonnet4Thinking, + ] + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from anthropic_vertex_ai::Model::iter() + for model in anthropic_vertex_ai::Model::iter() { + if !matches!(model, anthropic_vertex_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in AllLanguageModelSettings::get_global(cx) + .anthropic_vertex + .available_models + .iter() + { + models.insert( + model.name.clone(), + anthropic_vertex_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + tool_override: model.tool_override.clone(), + cache_configuration: model.cache_configuration.as_ref().map(|config| { + anthropic_vertex_ai::AnthropicVertexModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + } + }), + max_output_tokens: model.max_output_tokens, + default_temperature: model.default_temperature, + mode: model.mode.clone().unwrap(), + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct AnthropicVertexModel { + id: LanguageModelId, + model: anthropic_vertex_ai::Model, + state: gpui::Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +pub fn count_anthropic_tokens( + request: LanguageModelRequest, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + use language_model::MessageContent; + + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. + } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. 
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) + }) + .boxed() +} + +impl AnthropicVertexModel { + fn stream_completion( + &self, + request: anthropic_vertex_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let http_client = self.http_client.clone(); + + let Ok((access_token_option, api_url, project_id, location_id)) = + cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).anthropic_vertex; + ( + state.api_key.clone(), // This is the access token for Vertex AI + settings.api_url.clone(), + settings.project_id.clone(), // ADDED + settings.location_id.clone(), // ADDED + ) + }) + else { + return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); + }; + + async move { + let access_token = access_token_option.unwrap(); + let request = anthropic_vertex_ai::stream_completion( + http_client.as_ref(), + &api_url, + &project_id, // ADDED + &location_id, // ADDED + &access_token, + request, + ); + request.await.map_err(Into::into) + } + .boxed() + } +} + +impl LanguageModel for AnthropicVertexModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + true + } + + fn supports_images(&self) -> bool { + true + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn telemetry_id(&self) -> String { + format!("anthropic/{}", self.model.id()) + } + + fn api_key(&self, cx: &App) -> Option { + self.state.read(cx).api_key.clone() + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens()) + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_anthropic_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let request = into_anthropic( + request, + self.model.request_id().into(), + self.model.default_temperature(), + self.model.max_output_tokens(), + self.model.mode(), + ); + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let response = request.await?; + Ok(AnthropicVertexEventMapper::new().map_stream(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + + fn cache_configuration(&self) -> Option { + self.model + .cache_configuration() + .map(|config| LanguageModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + }) + } +} + +pub fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: ModelMode, +) -> anthropic_vertex_ai::Request { + let mut 
new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec = + message + .content + .into_iter() + .filter_map(|content| match content { + MessageContent::Text(text) => { + let text = + if text.chars().last().map_or(false, |c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + Some(anthropic_vertex_ai::RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if !thinking.is_empty() { + Some(anthropic_vertex_ai::RequestContent::Thinking { + thinking, + signature: signature.unwrap_or_default(), + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(anthropic_vertex_ai::RequestContent::RedactedThinking { + data, + }) + } else { + None + } + } + MessageContent::Image(image) => { + Some(anthropic_vertex_ai::RequestContent::Image { + source: anthropic_vertex_ai::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }) + } + MessageContent::ToolUse(tool_use) => { + Some(anthropic_vertex_ai::RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }) + } + MessageContent::ToolResult(tool_result) => { + Some(anthropic_vertex_ai::RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ + ToolResultPart::Image { + source: anthropic_vertex_ai::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }, + ]) + } + }, + cache_control: None, + }) + } + }) + .collect(); + let anthropic_role = match message.role { + Role::User => anthropic_vertex_ai::Role::User, + Role::Assistant => anthropic_vertex_ai::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == anthropic_role { + last_message.content.extend(anthropic_message_content); + continue; + } + } + + // Mark the last segment of the message as cached + if message.cache { + let cache_control_value = Some(anthropic_vertex_ai::CacheControl { + cache_type: anthropic_vertex_ai::CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + anthropic_vertex_ai::RequestContent::RedactedThinking { .. } => { + // Caching is not possible, fallback to next message + } + anthropic_vertex_ai::RequestContent::Text { cache_control, .. } + | anthropic_vertex_ai::RequestContent::Thinking { + cache_control, .. + } + | anthropic_vertex_ai::RequestContent::Image { + cache_control, .. + } + | anthropic_vertex_ai::RequestContent::ToolUse { + cache_control, .. + } + | anthropic_vertex_ai::RequestContent::ToolResult { + cache_control, + .. 
+ } => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(anthropic_vertex_ai::Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + anthropic_vertex_ai::Request { + model: model, + anthropic_version: "vertex-2023-10-16".to_string(), + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(anthropic_vertex_ai::StringOrContents::String( + system_message, + )) + }, + thinking: if request.thinking_allowed + && let ModelMode::Thinking { budget_tokens } = mode + { + Some(anthropic_vertex_ai::Thinking::Enabled { budget_tokens }) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| anthropic_vertex_ai::Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => anthropic_vertex_ai::ToolChoice::Auto, + LanguageModelToolChoice::Any => anthropic_vertex_ai::ToolChoice::Any, + LanguageModelToolChoice::None => anthropic_vertex_ai::ToolChoice::None, + }), + metadata: None, + stop_sequences: Vec::new(), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +pub struct AnthropicVertexEventMapper { + tool_uses_by_index: HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicVertexEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, .. } => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. 
This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = serde_json::Value::from_str( + &partial_json_fixer::fix_json(&tool_use.input_json), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + }, + ))]; + } + } + return vec![]; + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let input_value = if input_json.is_empty() { + Ok(serde_json::Value::Object(serde_json::Map::default())) + } else { + serde_json::Value::from_str(input_json) + }; + let event_result = match input_value { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +/// Updates usage data by preferring counts from `new`. 
+fn update_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_usage(usage: &Usage) -> language_model::TokenUsage { + language_model::TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} + +struct ConfigurationView { + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + state, + load_credentials_task, + } + } + + fn authenticate_gcloud(&mut self, window: &mut Window, cx: &mut Context) { + println!("Authenticating with gcloud..."); + + let state = self.state.clone(); + self.load_credentials_task = Some(cx.spawn_in(window, { + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + cx.notify(); + } + + fn reset_gcloud_auth(&mut self, window: &mut Window, cx: &mut Context) { + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_authenticated = self.state.read(cx).is_authenticated(); + + if self.load_credentials_task.is_some() { + div() + .child(Label::new("Attempting to authenticate with gcloud...")) + .into_any() + } else if !is_authenticated { + v_flex() + .size_full() + .child(Label::new("Please authenticate with Google Cloud to use this provider.")) + .child( + List::new() + .child(InstructionListItem::text_only( + "1. Ensure Google Cloud SDK is installed and configured.", + )) + .child(InstructionListItem::text_only( + "2. Run 'gcloud auth application-default login' in your terminal.", + )) + .child(InstructionListItem::text_only( + "3. 
Configure your desired Google Cloud Project ID and Location ID in Zed's settings.json file under 'language_models.google_vertex'.", + )) + ) + .child( + h_flex() + .w_full() + .my_2() + .child( + Button::new("authenticate-gcloud", "Authenticate with gcloud") + .label_size(LabelSize::Small) + .icon_size(IconSize::Small) + .on_click(cx.listener(|this, _, window, cx| this.authenticate_gcloud(window, cx))), + ), + ) + .child( + Label::new( + "This will attempt to acquire an access token using your + gcloud application-default credentials. You might need to run + 'gcloud auth application-default login' manually first." + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + // .justify_between() // Removed, button is handled separately + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new("Authenticated with gcloud.")), + ) + .child( + Button::new("reset-gcloud-auth", "Clear Token") + .label_size(LabelSize::Small) + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .tooltip(Tooltip::text("Clear the in-memory access token. You will need to re-authenticate to use the provider.")) + .on_click(cx.listener(|this, _, window, cx| this.reset_gcloud_auth(window, cx))), + ) + .into_any() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anthropic_vertex_ai::ModelMode; + use language_model::{LanguageModelRequestMessage, MessageContent}; + + #[test] + fn test_cache_control_only_on_last_segment() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Some prompt".to_string()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + ], + cache: true, + }], + thread_id: None, + prompt_id: None, + intent: None, + mode: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + }; + + let anthropic_request = into_anthropic( + request, + "claude-3-5-sonnet".to_string(), + 0.7, + 4096, + ModelMode::Default, + ); + + assert_eq!(anthropic_request.messages.len(), 1); + + let message = &anthropic_request.messages[0]; + assert_eq!(message.content.len(), 5); + + assert!(matches!( + message.content[0], + anthropic_vertex_ai::RequestContent::Text { + cache_control: None, + .. + } + )); + for i in 1..3 { + assert!(matches!( + message.content[i], + anthropic_vertex_ai::RequestContent::Image { + cache_control: None, + .. + } + )); + } + + assert!(matches!( + message.content[4], + anthropic_vertex_ai::RequestContent::Image { + cache_control: Some(anthropic_vertex_ai::CacheControl { + cache_type: anthropic_vertex_ai::CacheControlType::Ephemeral, + }), + .. 
+ } + )); + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index 3e5ba26cc0..792ee77064 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -7,6 +7,7 @@ use settings::{Settings, SettingsSources}; use crate::provider::{ self, anthropic::AnthropicSettings, + anthropic_vertex::AnthropicVertexSettings, bedrock::AmazonBedrockSettings, cloud::{self, ZedDotDevSettings}, deepseek::DeepSeekSettings, @@ -33,6 +34,7 @@ pub struct AllLanguageModelSettings { pub deepseek: DeepSeekSettings, pub google: GoogleSettings, pub google_vertex: GoogleVertexSettings, + pub anthropic_vertex: AnthropicVertexSettings, pub lmstudio: LmStudioSettings, pub mistral: MistralSettings, pub ollama: OllamaSettings, @@ -50,6 +52,7 @@ pub struct AllLanguageModelSettingsContent { pub deepseek: Option, pub google: Option, pub google_vertex: Option, + pub anthropic_vertex: Option, pub lmstudio: Option, pub mistral: Option, pub ollama: Option, @@ -126,6 +129,14 @@ pub struct GoogleVertexSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct AnthropicVertexSettingsContent { + pub api_url: Option, + pub project_id: Option, // ADDED + pub location_id: Option, // ADDED + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct XAiSettingsContent { pub api_url: Option, @@ -322,6 +333,29 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.location_id.clone()), ); + + // Anthropic Vertex AI + merge( + &mut settings.anthropic_vertex.api_url, + value + .anthropic_vertex + .as_ref() + .and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.anthropic_vertex.project_id, + value + .anthropic_vertex + .as_ref() + .and_then(|s| s.project_id.clone()), + ); + merge( + &mut settings.anthropic_vertex.location_id, + value + .anthropic_vertex + .as_ref() + .and_then(|s| s.location_id.clone()), + ); } Ok(settings)
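For reference, the fields merged above correspond to a user configuration along these lines. This is a sketch only: it assumes the provider settings live under the top-level `language_models` key alongside the other providers, and that `api_url` holds the host that gets prefixed with the location when the endpoint is built; key names follow `AnthropicVertexSettingsContent`, values are placeholders.

```json
{
  "language_models": {
    "anthropic_vertex": {
      "api_url": "aiplatform.googleapis.com",
      "project_id": "your-gcp-project-id",
      "location_id": "europe-west1",
      "available_models": []
    }
  }
}
```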