From 29bfb56739bfb546c85dfbd40ea4a80cfb46a546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=BB=E4=BA=8C=E6=B0=AE=E6=9D=82=E8=8F=B2?= <40173605+Cupnfish@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:40:59 +0800 Subject: [PATCH] Add DeepSeek support (#23551) - Added support for DeepSeek as a new language model provider in Zed Assistant - Implemented streaming API support for real-time responses from DeepSeek models. - Added a configuration UI for DeepSeek API key management and settings. - Updated documentation with detailed setup instructions for DeepSeek integration. - Added DeepSeek-specific icons and model definitions for seamless integration into the Zed UI. - Integrated DeepSeek into the language model registry, making it available alongside other providers like OpenAI and Anthropic. Release Notes: - Added support for DeepSeek to the Assistant. --------- Co-authored-by: Marshall Bowers --- Cargo.lock | 15 + Cargo.toml | 2 + assets/icons/ai_deep_seek.svg | 1 + assets/settings/default.json | 3 + crates/assistant_settings/Cargo.toml | 1 + .../src/assistant_settings.rs | 27 +- crates/deepseek/Cargo.toml | 24 + crates/deepseek/LICENSE-GPL | 1 + crates/deepseek/src/deepseek.rs | 301 ++++++++++ crates/language_model/Cargo.toml | 1 + crates/language_model/src/request.rs | 78 +++ crates/language_model/src/role.rs | 10 + crates/language_models/Cargo.toml | 1 + crates/language_models/src/language_models.rs | 5 + crates/language_models/src/provider.rs | 1 + .../language_models/src/provider/deepseek.rs | 559 ++++++++++++++++++ crates/language_models/src/settings.rs | 21 + crates/ui/src/components/icon.rs | 1 + docs/src/assistant/configuration.md | 39 ++ 19 files changed, 1090 insertions(+), 1 deletion(-) create mode 100644 assets/icons/ai_deep_seek.svg create mode 100644 crates/deepseek/Cargo.toml create mode 120000 crates/deepseek/LICENSE-GPL create mode 100644 crates/deepseek/src/deepseek.rs create mode 100644 crates/language_models/src/provider/deepseek.rs diff 
--git a/Cargo.lock b/Cargo.lock index 349325e068..ec6d1dd38c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -560,6 +560,7 @@ version = "0.1.0" dependencies = [ "anthropic", "anyhow", + "deepseek", "feature_flags", "fs", "gpui", @@ -3685,6 +3686,18 @@ dependencies = [ "winapi", ] +[[package]] +name = "deepseek" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "http_client", + "schemars", + "serde", + "serde_json", +] + [[package]] name = "deflate64" version = "0.1.9" @@ -6808,6 +6821,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "collections", + "deepseek", "futures 0.3.31", "google_ai", "gpui", @@ -6851,6 +6865,7 @@ dependencies = [ "client", "collections", "copilot", + "deepseek", "editor", "feature_flags", "fs", diff --git a/Cargo.toml b/Cargo.toml index 6c949d927d..0926bacf9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ members = [ "crates/copilot", "crates/db", "crates/diagnostics", + "crates/deepseek", "crates/docs_preprocessor", "crates/editor", "crates/evals", @@ -229,6 +230,7 @@ context_server = { path = "crates/context_server" } context_server_settings = { path = "crates/context_server_settings" } copilot = { path = "crates/copilot" } db = { path = "crates/db" } +deepseek = { path = "crates/deepseek" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } extension = { path = "crates/extension" } diff --git a/assets/icons/ai_deep_seek.svg b/assets/icons/ai_deep_seek.svg new file mode 100644 index 0000000000..cf480c834c --- /dev/null +++ b/assets/icons/ai_deep_seek.svg @@ -0,0 +1 @@ +DeepSeek diff --git a/assets/settings/default.json b/assets/settings/default.json index 04b9bdc29e..ad982a7179 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1166,6 +1166,9 @@ }, "lmstudio": { "api_url": "http://localhost:1234/api/v0" + }, + "deepseek": { + "api_url": "https://api.deepseek.com" } }, // Zed's Prettier integration settings. 
diff --git a/crates/assistant_settings/Cargo.toml b/crates/assistant_settings/Cargo.toml index 32ebb6a959..4398f75ef9 100644 --- a/crates/assistant_settings/Cargo.toml +++ b/crates/assistant_settings/Cargo.toml @@ -21,6 +21,7 @@ lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +deepseek = { workspace = true, features = ["schemars"] } schemars.workspace = true serde.workspace = true settings.workspace = true diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 12be65b90f..3193e09ae5 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use ::open_ai::Model as OpenAiModel; use anthropic::Model as AnthropicModel; +use deepseek::Model as DeepseekModel; use feature_flags::FeatureFlagAppExt; use gpui::{App, Pixels}; use language_model::{CloudModel, LanguageModel}; @@ -46,6 +47,11 @@ pub enum AssistantProviderContentV1 { default_model: Option, api_url: Option, }, + #[serde(rename = "deepseek")] + DeepSeek { + default_model: Option, + api_url: Option, + }, } #[derive(Debug, Default)] @@ -149,6 +155,12 @@ impl AssistantSettingsContent { model: model.id().to_string(), }) } + AssistantProviderContentV1::DeepSeek { default_model, .. } => { + default_model.map(|model| LanguageModelSelection { + provider: "deepseek".to_string(), + model: model.id().to_string(), + }) + } }), inline_alternatives: None, enable_experimental_live_diffs: None, @@ -253,6 +265,18 @@ impl AssistantSettingsContent { available_models, }); } + "deepseek" => { + let api_url = match &settings.provider { + Some(AssistantProviderContentV1::DeepSeek { api_url, .. 
}) => { + api_url.clone() + } + _ => None, + }; + settings.provider = Some(AssistantProviderContentV1::DeepSeek { + default_model: DeepseekModel::from_id(&model).ok(), + api_url, + }); + } _ => {} }, VersionedAssistantSettingsContent::V2(settings) => { @@ -341,6 +365,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema: "openai".into(), "zed.dev".into(), "copilot_chat".into(), + "deepseek".into(), ]), ..Default::default() } @@ -380,7 +405,7 @@ pub struct AssistantSettingsContentV1 { default_height: Option, /// The provider of the assistant service. /// - /// This can be "openai", "anthropic", "ollama", "lmstudio", "zed.dev" + /// This can be "openai", "anthropic", "ollama", "lmstudio", "deepseek", "zed.dev" /// each with their respective default models and configurations. provider: Option, } diff --git a/crates/deepseek/Cargo.toml b/crates/deepseek/Cargo.toml new file mode 100644 index 0000000000..25e8f2f25c --- /dev/null +++ b/crates/deepseek/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "deepseek" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/deepseek.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true diff --git a/crates/deepseek/LICENSE-GPL b/crates/deepseek/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/deepseek/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs new file mode 100644 index 0000000000..777cf696d8 --- /dev/null +++ b/crates/deepseek/src/deepseek.rs @@ -0,0 +1,301 @@ +use anyhow::{anyhow, Result}; +use futures::{ + io::BufReader, + stream::{BoxStream, StreamExt}, 
+ AsyncBufReadExt, AsyncReadExt, +}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::convert::TryFrom; + +pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com"; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, + Tool, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + "tool" => Ok(Self::Tool), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + Role::Tool => "tool".to_owned(), + } + } +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum Model { + #[serde(rename = "deepseek-chat")] + #[default] + Chat, + #[serde(rename = "deepseek-reasoner")] + Reasoner, + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + max_tokens: usize, + max_output_tokens: Option, + }, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + match id { + "deepseek-chat" => Ok(Self::Chat), + "deepseek-reasoner" => Ok(Self::Reasoner), + _ => Err(anyhow!("invalid model id")), + } + } + + pub fn id(&self) -> &str { + match self { + Self::Chat => "deepseek-chat", + Self::Reasoner => "deepseek-reasoner", + Self::Custom { name, .. 
} => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Chat => "DeepSeek Chat", + Self::Reasoner => "DeepSeek Reasoner", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name).as_str(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Chat | Self::Reasoner => 64_000, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> Option { + match self { + Self::Chat => Some(8_192), + Self::Reasoner => Some(8_192), + Self::Custom { + max_output_tokens, .. + } => *max_output_tokens, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormat { + Text, + #[serde(rename = "json_object")] + JsonObject, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolDefinition { + Function { function: FunctionDefinition }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum RequestMessage { + Assistant { + content: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + tool_calls: Vec, + }, + User { + content: String, + }, + System { + content: String, + }, + Tool { + content: String, + tool_call_id: String, + }, +} + +#[derive(Serialize, Deserialize, 
Debug, Eq, PartialEq)] +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Response { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(default)] + pub prompt_cache_hit_tokens: u32, + #[serde(default)] + pub prompt_cache_miss_tokens: u32, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Choice { + pub index: u32, + pub message: RequestMessage, + pub finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamChoice { + pub index: u32, + pub delta: StreamDelta, + pub finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct StreamDelta { + pub role: Option, + pub content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ToolCallChunk { + pub index: usize, + pub id: Option, + pub function: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct FunctionChunk { + pub name: Option, + pub arguments: 
Option, +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result>> { + let uri = format!("{api_url}/v1/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + if line == "[DONE]" { + None + } else { + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(error))), + } + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Err(anyhow!( + "Failed to connect to DeepSeek API: {} {}", + response.status(), + body, + )) + } +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 0842e18752..74505b1780 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -29,6 +29,7 @@ log.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } lmstudio = { workspace = true, features = ["schemars"] } +deepseek = { workspace = true, features = ["schemars"] } parking_lot.workspace = true proto.workspace = true schemars.workspace = true diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index ea26b53021..19ceea7a53 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -410,6 +410,84 @@ impl LanguageModelRequest { top_p: None, 
} } + + pub fn into_deepseek(self, model: String, max_output_tokens: Option) -> deepseek::Request { + let is_reasoner = model == "deepseek-reasoner"; + + let len = self.messages.len(); + let merged_messages = + self.messages + .into_iter() + .fold(Vec::with_capacity(len), |mut acc, msg| { + let role = msg.role; + let content = msg.string_contents(); + + if is_reasoner { + if let Some(last_msg) = acc.last_mut() { + match (last_msg, role) { + (deepseek::RequestMessage::User { content: last }, Role::User) => { + last.push(' '); + last.push_str(&content); + return acc; + } + + ( + deepseek::RequestMessage::Assistant { + content: last_content, + .. + }, + Role::Assistant, + ) => { + *last_content = last_content + .take() + .map(|c| { + let mut s = + String::with_capacity(c.len() + content.len() + 1); + s.push_str(&c); + s.push(' '); + s.push_str(&content); + s + }) + .or(Some(content)); + + return acc; + } + _ => {} + } + } + } + + acc.push(match role { + Role::User => deepseek::RequestMessage::User { content }, + Role::Assistant => deepseek::RequestMessage::Assistant { + content: Some(content), + tool_calls: Vec::new(), + }, + Role::System => deepseek::RequestMessage::System { content }, + }); + acc + }); + + deepseek::Request { + model, + messages: merged_messages, + stream: true, + max_tokens: max_output_tokens, + temperature: if is_reasoner { None } else { self.temperature }, + response_format: None, + tools: self + .tools + .into_iter() + .map(|tool| deepseek::ToolDefinition::Function { + function: deepseek::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + } + } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs index 17366fa3ec..fa56a2a88b 100644 --- a/crates/language_model/src/role.rs +++ b/crates/language_model/src/role.rs @@ -66,6 +66,16 @@ impl From for open_ai::Role { } } +impl 
From for deepseek::Role { + fn from(val: Role) -> Self { + match val { + Role::User => deepseek::Role::User, + Role::Assistant => deepseek::Role::Assistant, + Role::System => deepseek::Role::System, + } + } +} + impl From for lmstudio::Role { fn from(val: Role) -> Self { match val { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b66447124b..4d7590e40e 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -29,6 +29,7 @@ menu.workspace = true ollama = { workspace = true, features = ["schemars"] } lmstudio = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +deepseek = { workspace = true, features = ["schemars"] } project.workspace = true proto.workspace = true schemars.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 645d02b978..99e5c36d61 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -4,6 +4,7 @@ use client::{Client, UserStore}; use fs::Fs; use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use provider::deepseek::DeepSeekLanguageModelProvider; mod logging; pub mod provider; @@ -55,6 +56,10 @@ fn register_language_model_providers( LmStudioLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + DeepSeekLanguageModelProvider::new(client.http_client(), cx), + cx, + ); registry.register_provider( GoogleLanguageModelProvider::new(client.http_client(), cx), cx, diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 09fb975fc6..a7738563e7 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -1,6 +1,7 @@ pub mod anthropic; pub mod cloud; pub mod copilot_chat; +pub mod deepseek; pub mod google; pub 
mod lmstudio; pub mod ollama; diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs new file mode 100644 index 0000000000..1d32706bf2 --- /dev/null +++ b/crates/language_models/src/provider/deepseek.rs @@ -0,0 +1,559 @@ +use anyhow::{anyhow, Result}; +use collections::BTreeMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{ + AnyView, AppContext, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, +}; +use http_client::HttpClient; +use language_model::{ + LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use theme::ThemeSettings; +use ui::{prelude::*, Icon, IconName}; +use util::ResultExt; + +use crate::AllLanguageModelSettings; + +const PROVIDER_ID: &str = "deepseek"; +const PROVIDER_NAME: &str = "DeepSeek"; +const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct DeepSeekSettings { + pub api_url: String, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub max_output_tokens: Option, +} + +pub struct DeepSeekLanguageModelProvider { + http_client: Arc, + state: Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + _subscription: Subscription, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).deepseek; 
+ let delete_credentials = cx.delete_credentials(&settings.api_url); + cx.spawn(|this, mut cx| async move { + delete_credentials.await.log_err(); + this.update(&mut cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).deepseek; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); + + cx.spawn(|this, mut cx| async move { + write_credentials.await?; + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + let api_url = AllLanguageModelSettings::get_global(cx) + .deepseek + .api_url + .clone(); + + cx.spawn(|this, mut cx| async move { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? 
+ .ok_or_else(|| anyhow!("credentials not found"))?; + (String::from_utf8(api_key)?, false) + }; + + this.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + }) + }) + } + } +} + +impl DeepSeekLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_this: &mut State, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for DeepSeekLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for DeepSeekLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiDeepSeek + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + models.insert("deepseek-chat", deepseek::Model::Chat); + models.insert("deepseek-reasoner", deepseek::Model::Reasoner); + + for available_model in AllLanguageModelSettings::get_global(cx) + .deepseek + .available_models + .iter() + { + models.insert( + &available_model.name, + deepseek::Model::Custom { + name: available_model.name.clone(), + display_name: available_model.display_name.clone(), + max_tokens: available_model.max_tokens, + max_output_tokens: available_model.max_output_tokens, + }, + ); + } + + models + .into_values() + .map(|model| { + Arc::new(DeepSeekLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + 
self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct DeepSeekLanguageModel { + id: LanguageModelId, + model: deepseek::Model, + state: Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl DeepSeekLanguageModel { + fn stream_completion( + &self, + request: deepseek::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).deepseek; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.ok_or_else(|| anyhow!("Missing DeepSeek API Key"))?; + let request = + deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for DeepSeekLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("deepseek/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize 
{ + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + cx.background_executor() + .spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) + }) + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens()); + let stream = self.stream_completion(request, cx); + + async move { + let stream = stream.await?; + Ok(stream + .map(|result| { + result.and_then(|response| { + response + .choices + .first() + .ok_or_else(|| anyhow!("Empty response")) + .map(|choice| { + choice + .delta + .content + .clone() + .unwrap_or_default() + .map(LanguageModelCompletionEvent::Text) + }) + }) + }) + .boxed()) + } + .boxed() + } + fn use_any_tool( + &self, + request: LanguageModelRequest, + name: String, + description: String, + schema: serde_json::Value, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> { + let mut deepseek_request = + request.into_deepseek(self.model.id().to_string(), self.max_output_tokens()); + + deepseek_request.tools = vec![deepseek::ToolDefinition::Function { + function: deepseek::FunctionDefinition { + name: name.clone(), + description: Some(description), + parameters: Some(schema), + }, + }]; + + let response_stream = self.stream_completion(deepseek_request, cx); + + self.request_limiter + .run(async move { + let stream = response_stream.await?; 
+ + let tool_args_stream = stream + .filter_map(move |response| async move { + match response { + Ok(response) => { + for choice in response.choices { + if let Some(tool_calls) = choice.delta.tool_calls { + for tool_call in tool_calls { + if let Some(function) = tool_call.function { + if let Some(args) = function.arguments { + return Some(Ok(args)); + } + } + } + } + } + None + } + Err(e) => Some(Err(e)), + } + }) + .boxed(); + + Ok(tool_args_stream) + }) + .boxed() + } +} + +struct ConfigurationView { + api_key_editor: Entity, + state: Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("sk-00000000000000000000000000000000", cx); + editor + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn({ + let state = state.clone(); + |this, mut cx| async move { + if let Some(task) = state + .update(&mut cx, |state, cx| state.authenticate(cx)) + .log_err() + { + let _ = task.await; + } + + this.update(&mut cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + state + .update(&mut cx, |state, cx| state.reset_api_key(cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + truncate: None, + }; + EditorElement::new( + &self.api_key_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + const DEEPSEEK_CONSOLE_URL: &str = "https://platform.deepseek.com/api_keys"; + const INSTRUCTIONS: [&str; 3] = [ + "To use DeepSeek in Zed, you need an API key:", + "- Get your API key from:", + "- Paste it below and press enter:", + ]; + + let env_var_set = self.state.read(cx).api_key_from_env; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) 
+ .child(Label::new(INSTRUCTIONS[0])) + .child( + h_flex().child(Label::new(INSTRUCTIONS[1])).child( + Button::new("deepseek_console", DEEPSEEK_CONSOLE_URL) + .style(ButtonStyle::Subtle) + .icon(IconName::ExternalLink) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _window, cx| cx.open_url(DEEPSEEK_CONSOLE_URL)), + ), + ) + .child(Label::new(INSTRUCTIONS[2])) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new(format!( + "Or set {} environment variable", + DEEPSEEK_API_KEY_VAR + )) + .size(LabelSize::Small), + ) + .into_any() + } else { + h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {}", DEEPSEEK_API_KEY_VAR) + } else { + "API key configured".to_string() + })), + ) + .child( + Button::new("reset-key", "Reset") + .icon(IconName::Trash) + .disabled(env_var_set) + .on_click( + cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)), + ), + ) + .into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index d071185618..eb3afb8f5e 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -13,6 +13,7 @@ use crate::provider::{ anthropic::AnthropicSettings, cloud::{self, ZedDotDevSettings}, copilot_chat::CopilotChatSettings, + deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings, ollama::OllamaSettings, @@ -61,6 +62,7 @@ pub struct AllLanguageModelSettings { pub google: GoogleSettings, pub copilot_chat: CopilotChatSettings, pub lmstudio: LmStudioSettings, + pub deepseek: DeepSeekSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -72,6 +74,7 @@ pub struct 
AllLanguageModelSettingsContent { #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, pub google: Option, + pub deepseek: Option, pub copilot_chat: Option, } @@ -162,6 +165,12 @@ pub struct LmStudioSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct DeepseekSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[serde(untagged)] pub enum OpenAiSettingsContent { @@ -299,6 +308,18 @@ impl settings::Settings for AllLanguageModelSettings { lmstudio.as_ref().and_then(|s| s.available_models.clone()), ); + // DeepSeek + let deepseek = value.deepseek.clone(); + + merge( + &mut settings.deepseek.api_url, + value.deepseek.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.deepseek.available_models, + deepseek.as_ref().and_then(|s| s.available_models.clone()), + ); + // OpenAI let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) { Some((content, upgraded)) => (Some(content), upgraded), diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index 0dbb1dbaa2..c1aea34371 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -128,6 +128,7 @@ pub enum IconName { Ai, AiAnthropic, AiAnthropicHosted, + AiDeepSeek, AiGoogle, AiLmStudio, AiOllama, diff --git a/docs/src/assistant/configuration.md b/docs/src/assistant/configuration.md index 35c08b20f5..aa317f8830 100644 --- a/docs/src/assistant/configuration.md +++ b/docs/src/assistant/configuration.md @@ -194,6 +194,45 @@ The Zed Assistant comes pre-configured to use the latest version for common mode You must provide the model's Context Window in the `max_tokens` parameter, this can be found [OpenAI Model Docs](https://platform.openai.com/docs/models). 
OpenAI `o1` models should set `max_completion_tokens` as well to avoid incurring high reasoning token costs. Custom models will be listed in the model dropdown in the assistant panel. +### DeepSeek {#deepseek} + +1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys). +2. Open the configuration view (`assistant: show configuration`) and navigate to the DeepSeek section. +3. Enter your DeepSeek API key. + +The DeepSeek API key will be saved in your keychain. + +Zed will also use the `DEEPSEEK_API_KEY` environment variable if it's defined. + +#### DeepSeek Custom Models {#deepseek-custom-models} + +The Zed Assistant comes pre-configured to use the latest version for common models (DeepSeek Chat, DeepSeek Reasoner). If you wish to use alternate models or customize the API endpoint, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "deepseek": { + "api_url": "https://api.deepseek.com", + "available_models": [ + { + "name": "deepseek-chat", + "display_name": "DeepSeek Chat", + "max_tokens": 64000 + }, + { + "name": "deepseek-reasoner", + "display_name": "DeepSeek Reasoner", + "max_tokens": 64000, + "max_output_tokens": 4096 + } + ] + } + } +} +``` + +Custom models will be listed in the model dropdown in the assistant panel. You can also modify the `api_url` to use a custom endpoint if needed. + ### OpenAI API Compatible Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider.