rpc: Remove llm module in favor of zed_llm_client (#28900)

This PR removes the `llm` module of the `rpc` crate in favor of using
the types from the `zed_llm_client`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-16 16:22:44 -04:00 committed by GitHub
parent 54b46fdfaa
commit fcb1efdf21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 45 additions and 85 deletions

43
Cargo.lock generated
View file

@ -324,7 +324,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.26.3", "strum 0.27.1",
"thiserror 2.0.12", "thiserror 2.0.12",
"workspace-hack", "workspace-hack",
] ]
@ -567,7 +567,7 @@ dependencies = [
"settings", "settings",
"smallvec", "smallvec",
"smol", "smol",
"strum 0.26.3", "strum 0.27.1",
"telemetry_events", "telemetry_events",
"text", "text",
"theme", "theme",
@ -1884,7 +1884,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.26.3", "strum 0.27.1",
"thiserror 2.0.12", "thiserror 2.0.12",
"tokio", "tokio",
"workspace-hack", "workspace-hack",
@ -3031,7 +3031,7 @@ dependencies = [
"settings", "settings",
"sha2", "sha2",
"sqlx", "sqlx",
"strum 0.26.3", "strum 0.27.1",
"subtle", "subtle",
"supermaven_api", "supermaven_api",
"telemetry_events", "telemetry_events",
@ -3051,6 +3051,7 @@ dependencies = [
"workspace", "workspace",
"workspace-hack", "workspace-hack",
"worktree", "worktree",
"zed_llm_client",
] ]
[[package]] [[package]]
@ -3363,7 +3364,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
"strum 0.26.3", "strum 0.27.1",
"task", "task",
"theme", "theme",
"ui", "ui",
@ -5125,7 +5126,7 @@ dependencies = [
"serde", "serde",
"settings", "settings",
"smallvec", "smallvec",
"strum 0.26.3", "strum 0.27.1",
"telemetry", "telemetry",
"theme", "theme",
"ui", "ui",
@ -5976,7 +5977,7 @@ dependencies = [
"serde_derive", "serde_derive",
"serde_json", "serde_json",
"settings", "settings",
"strum 0.26.3", "strum 0.27.1",
"telemetry", "telemetry",
"theme", "theme",
"time", "time",
@ -6069,7 +6070,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.26.3", "strum 0.27.1",
"workspace-hack", "workspace-hack",
] ]
@ -6175,7 +6176,7 @@ dependencies = [
"slotmap", "slotmap",
"smallvec", "smallvec",
"smol", "smol",
"strum 0.26.3", "strum 0.27.1",
"sum_tree", "sum_tree",
"taffy", "taffy",
"thiserror 2.0.12", "thiserror 2.0.12",
@ -6823,7 +6824,7 @@ name = "icons"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"serde", "serde",
"strum 0.26.3", "strum 0.27.1",
"workspace-hack", "workspace-hack",
] ]
@ -7091,7 +7092,7 @@ dependencies = [
"paths", "paths",
"pretty_assertions", "pretty_assertions",
"serde", "serde",
"strum 0.26.3", "strum 0.27.1",
"util", "util",
"workspace-hack", "workspace-hack",
] ]
@ -7677,7 +7678,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"smol", "smol",
"strum 0.26.3", "strum 0.27.1",
"telemetry_events", "telemetry_events",
"thiserror 2.0.12", "thiserror 2.0.12",
"util", "util",
@ -7737,7 +7738,7 @@ dependencies = [
"serde_json", "serde_json",
"settings", "settings",
"smol", "smol",
"strum 0.26.3", "strum 0.27.1",
"theme", "theme",
"thiserror 2.0.12", "thiserror 2.0.12",
"tiktoken-rs", "tiktoken-rs",
@ -8710,7 +8711,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.26.3", "strum 0.27.1",
"workspace-hack", "workspace-hack",
] ]
@ -9557,7 +9558,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.26.3", "strum 0.27.1",
"workspace-hack", "workspace-hack",
] ]
@ -12136,7 +12137,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sha2", "sha2",
"strum 0.26.3", "strum 0.27.1",
"tracing", "tracing",
"util", "util",
"workspace-hack", "workspace-hack",
@ -13709,7 +13710,7 @@ dependencies = [
"settings", "settings",
"simplelog", "simplelog",
"story", "story",
"strum 0.26.3", "strum 0.27.1",
"theme", "theme",
"title_bar", "title_bar",
"ui", "ui",
@ -14444,7 +14445,7 @@ dependencies = [
"serde_json_lenient", "serde_json_lenient",
"serde_repr", "serde_repr",
"settings", "settings",
"strum 0.26.3", "strum 0.27.1",
"thiserror 2.0.12", "thiserror 2.0.12",
"util", "util",
"uuid", "uuid",
@ -14478,7 +14479,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_json_lenient", "serde_json_lenient",
"simplelog", "simplelog",
"strum 0.26.3", "strum 0.27.1",
"theme", "theme",
"vscode_theme", "vscode_theme",
"workspace-hack", "workspace-hack",
@ -15479,7 +15480,7 @@ dependencies = [
"settings", "settings",
"smallvec", "smallvec",
"story", "story",
"strum 0.26.3", "strum 0.27.1",
"theme", "theme",
"ui_macros", "ui_macros",
"util", "util",
@ -17680,7 +17681,7 @@ dependencies = [
"settings", "settings",
"smallvec", "smallvec",
"sqlez", "sqlez",
"strum 0.26.3", "strum 0.27.1",
"task", "task",
"telemetry", "telemetry",
"tempfile", "tempfile",

View file

@ -540,7 +540,7 @@ smol = "2.0"
sqlformat = "0.2" sqlformat = "0.2"
streaming-iterator = "0.1" streaming-iterator = "0.1"
strsim = "0.11" strsim = "0.11"
strum = { version = "0.26.0", features = ["derive"] } strum = { version = "0.27.0", features = ["derive"] }
subtle = "2.5.0" subtle = "2.5.0"
syn = { version = "1.0.72", features = ["full", "extra-traits"] } syn = { version = "1.0.72", features = ["full", "extra-traits"] }
sys-locale = "0.3.1" sys-locale = "0.3.1"

View file

@ -75,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re
util.workspace = true util.workspace = true
uuid.workspace = true uuid.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies] [dev-dependencies]
assistant = { workspace = true, features = ["test-support"] } assistant = { workspace = true, features = ["test-support"] }

View file

@ -330,8 +330,10 @@ async fn create_billing_subscription(
.await? .await?
} }
None => { None => {
let default_model = let default_model = llm_db.model(
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; zed_llm_client::LanguageModelProvider::Anthropic,
"claude-3-7-sonnet",
)?;
let stripe_model = stripe_billing.register_model(default_model).await?; let stripe_model = stripe_billing.register_model(default_model).await?;
stripe_billing stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url) .checkout(customer_id, &user.github_login, &stripe_model, &success_url)

View file

@ -8,9 +8,9 @@ mod tests;
use collections::HashMap; use collections::HashMap;
pub use ids::*; pub use ids::*;
use rpc::LanguageModelProvider;
pub use seed::*; pub use seed::*;
pub use tables::*; pub use tables::*;
use zed_llm_client::LanguageModelProvider;
#[cfg(test)] #[cfg(test)]
pub use tests::TestLlmDb; pub use tests::TestLlmDb;

View file

@ -1,5 +1,5 @@
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider; use zed_llm_client::LanguageModelProvider;
use crate::llm::db::LlmDatabase; use crate::llm::db::LlmDatabase;
use crate::test_llm_db; use crate::test_llm_db;

View file

@ -1,9 +1,6 @@
use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long}; use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use client::{ use client::{Client, UserStore, zed_urls};
Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
PerformCompletionParams, UserStore, zed_urls,
};
use collections::BTreeMap; use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro}; use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{ use futures::{
@ -26,7 +23,6 @@ use language_model::{
use proto::Plan; use proto::Plan;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use smol::Timer; use smol::Timer;
use smol::io::{AsyncReadExt, BufReader}; use smol::io::{AsyncReadExt, BufReader};
@ -38,7 +34,10 @@ use std::{
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use thiserror::Error; use thiserror::Error;
use ui::{TintColor, prelude::*}; use ui::{TintColor, prelude::*};
use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME}; use zed_llm_client::{
CURRENT_PLAN_HEADER_NAME, CompletionBody, EXPIRED_LLM_TOKEN_HEADER_NAME,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
};
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic}; use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
@ -517,7 +516,7 @@ impl CloudLanguageModel {
async fn perform_llm_completion( async fn perform_llm_completion(
client: Arc<Client>, client: Arc<Client>,
llm_api_token: LlmApiToken, llm_api_token: LlmApiToken,
body: PerformCompletionParams, body: CompletionBody,
) -> Result<Response<AsyncBody>> { ) -> Result<Response<AsyncBody>> {
let http_client = &client.http_client(); let http_client = &client.http_client();
@ -724,12 +723,10 @@ impl LanguageModel for CloudLanguageModel {
let response = Self::perform_llm_completion( let response = Self::perform_llm_completion(
client.clone(), client.clone(),
llm_api_token, llm_api_token,
PerformCompletionParams { CompletionBody {
provider: client::LanguageModelProvider::Anthropic, provider: zed_llm_client::LanguageModelProvider::Anthropic,
model: request.model.clone(), model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string( provider_request: serde_json::to_value(&request)?,
&request,
)?)?,
}, },
) )
.await .await
@ -765,12 +762,10 @@ impl LanguageModel for CloudLanguageModel {
let response = Self::perform_llm_completion( let response = Self::perform_llm_completion(
client.clone(), client.clone(),
llm_api_token, llm_api_token,
PerformCompletionParams { CompletionBody {
provider: client::LanguageModelProvider::OpenAi, provider: zed_llm_client::LanguageModelProvider::OpenAi,
model: request.model.clone(), model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string( provider_request: serde_json::to_value(&request)?,
&request,
)?)?,
}, },
) )
.await?; .await?;
@ -790,12 +785,10 @@ impl LanguageModel for CloudLanguageModel {
let response = Self::perform_llm_completion( let response = Self::perform_llm_completion(
client.clone(), client.clone(),
llm_api_token, llm_api_token,
PerformCompletionParams { CompletionBody {
provider: client::LanguageModelProvider::Google, provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.clone(), model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string( provider_request: serde_json::to_value(&request)?,
&request,
)?)?,
}, },
) )
.await?; .await?;

View file

@ -1,35 +0,0 @@
use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
pub const MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME: &str = "x-zed-llm-max-monthly-spend-reached";
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
Anthropic,
OpenAi,
Google,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LanguageModel {
pub provider: LanguageModelProvider,
pub name: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
pub models: Vec<LanguageModel>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PerformCompletionParams {
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: Box<serde_json::value::RawValue>,
}

View file

@ -1,14 +1,12 @@
pub mod auth; pub mod auth;
mod conn; mod conn;
mod extension; mod extension;
mod llm;
mod message_stream; mod message_stream;
mod notification; mod notification;
mod peer; mod peer;
pub use conn::Connection; pub use conn::Connection;
pub use extension::*; pub use extension::*;
pub use llm::*;
pub use notification::*; pub use notification::*;
pub use peer::*; pub use peer::*;
pub use proto; pub use proto;