diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 938e9a5b48..d52a233f78 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,7 +1,8 @@ use anyhow::{Context as _, Result, anyhow}; -use collections::BTreeMap; +use collections::{BTreeMap, HashMap}; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; +use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle, @@ -12,11 +13,14 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, RateLimiter, Role, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + RateLimiter, Role, StopReason, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; use theme::ThemeSettings; use ui::{Icon, IconName, List, prelude::*}; @@ -28,6 +32,13 @@ const PROVIDER_ID: &str = "deepseek"; const PROVIDER_NAME: &str = "DeepSeek"; const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, +} + #[derive(Default, Clone, Debug, PartialEq)] pub struct DeepSeekSettings { pub api_url: String, @@ -280,11 +291,11 @@ impl LanguageModel for DeepSeekLanguageModel { } fn supports_tools(&self) -> bool { - false + true } fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { - false + true } fn supports_images(&self) -> bool { @@ -339,35 +350,12 @@ impl LanguageModel for DeepSeekLanguageModel { BoxStream<'static, Result>, >, > { - let request = into_deepseek( - request, - self.model.id().to_string(), - self.max_output_tokens(), - ); + let request = into_deepseek(request, &self.model, self.max_output_tokens()); let stream = self.stream_completion(request, cx); async move { - let stream = stream.await?; - Ok(stream - .map(|result| { - result - .and_then(|response| { - response - .choices - .first() - .context("Empty response") - .map(|choice| { - choice - .delta - .content - .clone() - .unwrap_or_default() - .map(LanguageModelCompletionEvent::Text) - }) - }) - .map_err(LanguageModelCompletionError::Other) - }) - .boxed()) + let mapper = DeepSeekEventMapper::new(); + Ok(mapper.map_stream(stream.await?).boxed()) } .boxed() } @@ -375,69 +363,67 @@ impl LanguageModel for DeepSeekLanguageModel { pub fn into_deepseek( request: LanguageModelRequest, - model: String, + model: &deepseek::Model, max_output_tokens: Option, ) -> deepseek::Request { - let is_reasoner = model == "deepseek-reasoner"; + let is_reasoner = *model == deepseek::Model::Reasoner; - let len = request.messages.len(); - let merged_messages = - request - .messages - .into_iter() - .fold(Vec::with_capacity(len), |mut acc, msg| { - let role = msg.role; - let content = msg.string_contents(); + let mut messages = Vec::new(); + for message in request.messages { + for content in message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages + .push(match message.role { + Role::User => deepseek::RequestMessage::User { content: text }, + Role::Assistant => deepseek::RequestMessage::Assistant { + content: Some(text), + tool_calls: Vec::new(), + }, + Role::System => deepseek::RequestMessage::System { content: text }, + }), + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) => {} + MessageContent::ToolUse(tool_use) => { + let tool_call = deepseek::ToolCall { + id: tool_use.id.to_string(), + content: deepseek::ToolCallContent::Function { + function: deepseek::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; - if is_reasoner { - if let Some(last_msg) = acc.last_mut() { - match (last_msg, role) { - (deepseek::RequestMessage::User { content: last }, Role::User) => { - last.push(' '); - last.push_str(&content); - return acc; - } - - ( - deepseek::RequestMessage::Assistant { - content: last_content, - .. - }, - Role::Assistant, - ) => { - *last_content = last_content - .take() - .map(|c| { - let mut s = - String::with_capacity(c.len() + content.len() + 1); - s.push_str(&c); - s.push(' '); - s.push_str(&content); - s - }) - .or(Some(content)); - - return acc; - } - _ => {} - } + if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(deepseek::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); } } - - acc.push(match role { - Role::User => deepseek::RequestMessage::User { content }, - Role::Assistant => deepseek::RequestMessage::Assistant { - content: Some(content), - tool_calls: Vec::new(), - }, - Role::System => deepseek::RequestMessage::System { content }, - }); - acc - }); + MessageContent::ToolResult(tool_result) => { + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + messages.push(deepseek::RequestMessage::Tool { + content: text.to_string(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + LanguageModelToolResultContent::Image(_) => {} + }; + } + } + } + } deepseek::Request { - model, - messages: merged_messages, + model: model.id().to_string(), + messages, stream: true, max_tokens: max_output_tokens, temperature: if is_reasoner { @@ -460,6 +446,103 @@ pub fn into_deepseek( } } +pub struct DeepSeekEventMapper { + tool_calls_by_index: HashMap, +} + +impl DeepSeekEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: deepseek::StreamResponse, + ) -> Vec> { + let Some(choice) = event.choices.first() else { + return vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))]; + }; + + let mut events = Vec::new(); + if let Some(content) = choice.delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + struct ConfigurationView { api_key_editor: Entity, state: Entity,