From b8f3a9101c77ade0e3e44f06fb339af177031e56 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 29 Jul 2025 19:30:45 -0400 Subject: [PATCH] Add `cloud_llm_client` crate (#35307) This PR adds a `cloud_llm_client` crate to take the place of the `zed_llm_client`. Release Notes: - N/A --- Cargo.lock | 13 + Cargo.toml | 3 +- crates/cloud_llm_client/Cargo.toml | 23 ++ crates/cloud_llm_client/LICENSE-APACHE | 1 + .../cloud_llm_client/src/cloud_llm_client.rs | 370 ++++++++++++++++++ 5 files changed, 409 insertions(+), 1 deletion(-) create mode 100644 crates/cloud_llm_client/Cargo.toml create mode 120000 crates/cloud_llm_client/LICENSE-APACHE create mode 100644 crates/cloud_llm_client/src/cloud_llm_client.rs diff --git a/Cargo.lock b/Cargo.lock index d91c5d5eca..527b99f3c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3031,6 +3031,19 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_llm_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "pretty_assertions", + "serde", + "serde_json", + "strum 0.27.1", + "uuid", + "workspace-hack", +] + [[package]] name = "clru" version = "0.6.2" diff --git a/Cargo.toml b/Cargo.toml index 16ace7dee0..e08736e38e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ members = [ "crates/cli", "crates/client", "crates/clock", + "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", "crates/collections", @@ -70,7 +71,6 @@ members = [ "crates/gpui", "crates/gpui_macros", "crates/gpui_tokio", - "crates/html_to_markdown", "crates/http_client", "crates/http_client_tls", @@ -251,6 +251,7 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } +cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml new file mode 100644 index 0000000000..6f090d3c6e --- /dev/null +++ b/crates/cloud_llm_client/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "cloud_llm_client" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_llm_client.rs" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive", "rc"] } +serde_json.workspace = true +strum = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["serde"] } +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/LICENSE-APACHE b/crates/cloud_llm_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_llm_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs new file mode 100644 index 0000000000..2488088a49 --- /dev/null +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -0,0 +1,370 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::Context as _; +use serde::{Deserialize, Serialize}; +use strum::{Display, EnumIter, EnumString}; +use uuid::Uuid; + +/// The name of the header used to indicate which version of Zed the client is running. +pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version"; + +/// The name of the header used to indicate when a request failed due to an +/// expired LLM token. +/// +/// The client may use this as a signal to refresh the token. +pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; + +/// The name of the header used to indicate what plan the user is currently on. +pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan"; + +/// The name of the header used to indicate the usage limit for model requests. +pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit"; + +/// The name of the header used to indicate the usage amount for model requests. +pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount"; + +/// The name of the header used to indicate the usage limit for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit"; + +/// The name of the header used to indicate the usage amount for edit predictions. +pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount"; + +/// The name of the header used to indicate the resource for which the subscription limit has been reached. +pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource"; + +pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests"; +pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; + +/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. +pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; + +/// The name of the header used to indicate the the minimum required Zed version. +/// +/// This can be used to force a Zed upgrade in order to continue communicating +/// with the LLM service. +pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version"; + +/// The name of the header used by the client to indicate to the server that it supports receiving status messages. +pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-client-supports-status-messages"; + +/// The name of the header used by the server to indicate to the client that it supports sending status messages. +pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str = + "x-zed-server-supports-status-messages"; + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum UsageLimit { + Limited(i32), + Unlimited, +} + +impl FromStr for UsageLimit { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "unlimited" => Ok(Self::Unlimited), + limit => limit + .parse::() + .map(Self::Limited) + .context("failed to parse limit"), + } + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Plan { + #[default] + #[serde(alias = "Free")] + ZedFree, + #[serde(alias = "ZedPro")] + ZedPro, + #[serde(alias = "ZedProTrial")] + ZedProTrial, +} + +impl Plan { + pub fn as_str(&self) -> &'static str { + match self { + Plan::ZedFree => "zed_free", + Plan::ZedPro => "zed_pro", + Plan::ZedProTrial => "zed_pro_trial", + } + } + + pub fn model_requests_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Limited(500), + Plan::ZedProTrial => UsageLimit::Limited(150), + Plan::ZedFree => UsageLimit::Limited(50), + } + } + + pub fn edit_predictions_limit(&self) -> UsageLimit { + match self { + Plan::ZedPro => UsageLimit::Unlimited, + Plan::ZedProTrial => UsageLimit::Unlimited, + Plan::ZedFree => UsageLimit::Limited(2_000), + } + } +} + +impl FromStr for Plan { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + match value { + "zed_free" => Ok(Plan::ZedFree), + "zed_pro" => Ok(Plan::ZedPro), + "zed_pro_trial" => Ok(Plan::ZedProTrial), + plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")), + } + } +} + +#[derive( + Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum LanguageModelProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub outline: Option, + pub input_events: String, + pub input_excerpt: String, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub speculated_output: Option, + /// Whether the user provided consent for sampling this interaction. + #[serde(default, alias = "data_collection_permission")] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub diagnostic_groups: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub output_excerpt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcceptEditPredictionBody { + pub request_id: Uuid, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionMode { + Normal, + Max, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionBody { + #[serde(skip_serializing_if = "Option::is_none", default)] + pub thread_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub prompt_id: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub intent: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub mode: Option, + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionRequestStatus { + Queued { + position: usize, + }, + Started, + Failed { + code: String, + message: String, + request_id: Uuid, + /// Retry duration in seconds. + retry_after: Option, + }, + UsageUpdated { + amount: usize, + limit: UsageLimit, + }, + ToolUseLimitReached, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionEvent { + Status(CompletionRequestStatus), + Event(T), +} + +impl CompletionEvent { + pub fn into_status(self) -> Option { + match self { + Self::Status(status) => Some(status), + Self::Event(_) => None, + } + } + + pub fn into_event(self) -> Option { + match self { + Self::Event(event) => Some(event), + Self::Status(_) => None, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct WebSearchBody { + pub query: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResponse { + pub results: Vec, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct WebSearchResult { + pub title: String, + pub url: String, + pub text: String, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensBody { + pub provider: LanguageModelProvider, + pub model: String, + pub provider_request: serde_json::Value, +} + +#[derive(Serialize, Deserialize)] +pub struct CountTokensResponse { + pub tokens: usize, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelId(pub Arc); + +impl std::fmt::Display for LanguageModelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub id: LanguageModelId, + pub display_name: String, + pub max_token_count: usize, + pub max_token_count_in_max_mode: Option, + pub max_output_tokens: usize, + pub supports_tools: bool, + pub supports_images: bool, + pub supports_thinking: bool, + pub supports_max_mode: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, + pub default_model: LanguageModelId, + pub default_fast_model: LanguageModelId, + pub recommended_models: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetSubscriptionResponse { + pub plan: Plan, + pub usage: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CurrentUsage { + pub model_requests: UsageData, + pub edit_predictions: UsageData, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UsageData { + pub used: u32, + pub limit: UsageLimit, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + #[test] + fn test_plan_deserialize_snake_case() { + let plan = serde_json::from_value::(json!("zed_free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("zed_pro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("zed_pro_trial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_plan_deserialize_aliases() { + let plan = serde_json::from_value::(json!("Free")).unwrap(); + assert_eq!(plan, Plan::ZedFree); + + let plan = serde_json::from_value::(json!("ZedPro")).unwrap(); + assert_eq!(plan, Plan::ZedPro); + + let plan = serde_json::from_value::(json!("ZedProTrial")).unwrap(); + assert_eq!(plan, Plan::ZedProTrial); + } + + #[test] + fn test_usage_limit_from_str() { + let limit = UsageLimit::from_str("unlimited").unwrap(); + assert!(matches!(limit, UsageLimit::Unlimited)); + + let limit = UsageLimit::from_str(&0.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(0))); + + let limit = UsageLimit::from_str(&50.to_string()).unwrap(); + assert!(matches!(limit, UsageLimit::Limited(50))); + + for value in ["not_a_number", "50xyz"] { + let limit = UsageLimit::from_str(value); + assert!(limit.is_err()); + } + } +}