Add tool support for DeepSeek (#30223)
[DeepSeek's function calling API](https://api-docs.deepseek.com/guides/function_calling) has been released, and it follows the same schema as OpenAI's.

Release Notes:

- Added tool calling support for DeepSeek models

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent 55d91bce53
commit b820aa1fcd
1 changed file with 168 additions and 85 deletions
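Since the endpoint mirrors OpenAI's function-calling schema, a request that exposes tools to DeepSeek carries the familiar `tools` and `tool_choice` fields. A minimal sketch of that request body, assuming the field names documented in the guide linked above; the `get_weather` tool and its parameters are illustrative only:

```rust
use serde_json::json;

fn main() {
    // Hypothetical request body for the OpenAI-compatible endpoint;
    // field names follow DeepSeek's function-calling guide.
    let body = json!({
        "model": "deepseek-chat",
        "messages": [{ "role": "user", "content": "What's the weather in Berlin?" }],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }],
        "tool_choice": "auto",
        "stream": true
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```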
@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
     AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
@@ -12,11 +13,14 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::sync::Arc;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, prelude::*};
@@ -28,6 +32,13 @@ const PROVIDER_ID: &str = "deepseek";
 const PROVIDER_NAME: &str = "DeepSeek";
 const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
 
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct DeepSeekSettings {
     pub api_url: String,
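The added `RawToolCall` struct in the hunk above is an accumulator for streamed tool calls: `id`, `name`, and `arguments` arrive split across chunks and are stitched together before parsing. A standalone sketch of that accumulation, with made-up chunk values:

```rust
#[derive(Default, Debug)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

fn main() {
    // Fragments as they might arrive across streamed chunks (illustrative values).
    let fragments = [
        (Some("call_1"), Some("get_weather"), Some(r#"{"ci"#)),
        (None, None, Some(r#"ty":"Berlin"}"#)),
    ];

    let mut call = RawToolCall::default();
    for (id, name, args) in fragments {
        // Later chunks usually omit id/name and only extend the arguments JSON.
        if let Some(id) = id {
            call.id = id.to_string();
        }
        if let Some(name) = name {
            call.name = name.to_string();
        }
        if let Some(args) = args {
            call.arguments.push_str(args);
        }
    }

    assert_eq!(call.arguments, r#"{"city":"Berlin"}"#);
    println!("{call:?}");
}
```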
@@ -280,11 +291,11 @@ impl LanguageModel for DeepSeekLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
+        true
     }
 
     fn supports_images(&self) -> bool {
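With `supports_tool_choice` now returning `true` for every variant, the provider is expected to forward Zed's tool-choice setting as the OpenAI-style `tool_choice` value. A hedged sketch of such a mapping; the variant names and string values below are assumptions for illustration, not code from this diff:

```rust
// Simplified stand-in for Zed's LanguageModelToolChoice; the variant names
// and the OpenAI-style strings below are assumptions for illustration.
enum ToolChoice {
    Auto,
    Any,
    None,
}

fn to_deepseek_tool_choice(choice: ToolChoice) -> &'static str {
    match choice {
        ToolChoice::Auto => "auto",    // model decides whether to call a tool
        ToolChoice::Any => "required", // force some tool call
        ToolChoice::None => "none",    // forbid tool calls
    }
}

fn main() {
    for choice in [ToolChoice::Auto, ToolChoice::Any, ToolChoice::None] {
        println!("{}", to_deepseek_tool_choice(choice));
    }
}
```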
@@ -339,35 +350,12 @@ impl LanguageModel for DeepSeekLanguageModel {
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
     > {
-        let request = into_deepseek(
-            request,
-            self.model.id().to_string(),
-            self.max_output_tokens(),
-        );
+        let request = into_deepseek(request, &self.model, self.max_output_tokens());
         let stream = self.stream_completion(request, cx);
 
         async move {
-            let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .context("Empty response")
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                })
-                        })
-                        .map(LanguageModelCompletionEvent::Text)
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = DeepSeekEventMapper::new();
+            Ok(mapper.map_stream(stream.await?).boxed())
         }
         .boxed()
     }
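The replacement body above hands the raw stream to `DeepSeekEventMapper`, whose `map_stream` (added further down) turns each streamed response into zero or more completion events via `flat_map` over `futures::stream::iter`. A self-contained sketch of that pattern, with integers standing in for responses and events:

```rust
use futures::{StreamExt, executor::block_on, stream};

fn main() {
    block_on(async {
        // Each "response" expands to zero or more "events", mirroring how
        // DeepSeekEventMapper::map_event returns a Vec per streamed chunk.
        let responses = stream::iter(vec![1, 2, 3]);
        let events: Vec<_> = responses
            .flat_map(|n| stream::iter((0..n).map(move |i| (n, i))))
            .collect()
            .await;
        println!("{events:?}");
    });
}
```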
@@ -375,69 +363,67 @@ impl LanguageModel for DeepSeekLanguageModel {
 
 pub fn into_deepseek(
     request: LanguageModelRequest,
-    model: String,
+    model: &deepseek::Model,
     max_output_tokens: Option<u32>,
 ) -> deepseek::Request {
-    let is_reasoner = model == "deepseek-reasoner";
+    let is_reasoner = *model == deepseek::Model::Reasoner;
 
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
-
-                if is_reasoner {
-                    if let Some(last_msg) = acc.last_mut() {
-                        match (last_msg, role) {
-                            (deepseek::RequestMessage::User { content: last }, Role::User) => {
-                                last.push(' ');
-                                last.push_str(&content);
-                                return acc;
-                            }
-
-                            (
-                                deepseek::RequestMessage::Assistant {
-                                    content: last_content,
-                                    ..
-                                },
-                                Role::Assistant,
-                            ) => {
-                                *last_content = last_content
-                                    .take()
-                                    .map(|c| {
-                                        let mut s =
-                                            String::with_capacity(c.len() + content.len() + 1);
-                                        s.push_str(&c);
-                                        s.push(' ');
-                                        s.push_str(&content);
-                                        s
-                                    })
-                                    .or(Some(content));
-
-                                return acc;
-                            }
-                            _ => {}
-                        }
-                    }
-                }
-
-                acc.push(match role {
-                    Role::User => deepseek::RequestMessage::User { content },
-                    Role::Assistant => deepseek::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => deepseek::RequestMessage::System { content },
-                });
-                acc
-            });
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => deepseek::RequestMessage::User { content: text },
+                        Role::Assistant => deepseek::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => deepseek::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = deepseek::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: deepseek::ToolCallContent::Function {
+                            function: deepseek::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(deepseek::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => {
+                            messages.push(deepseek::RequestMessage::Tool {
+                                content: text.to_string(),
+                                tool_call_id: tool_result.tool_use_id.to_string(),
+                            });
+                        }
+                        LanguageModelToolResultContent::Image(_) => {}
+                    };
+                }
+            }
+        }
+    }
 
     deepseek::Request {
-        model,
-        messages: merged_messages,
+        model: model.id().to_string(),
+        messages,
         stream: true,
         max_tokens: max_output_tokens,
         temperature: if is_reasoner {
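One behavior worth noting in the rewritten `into_deepseek`: a `ToolUse` content item is appended to the preceding assistant message's `tool_calls` when one exists, and only otherwise starts a new assistant message with empty content. A simplified, self-contained sketch of that grouping logic (local types, not the real `deepseek::RequestMessage`):

```rust
#[derive(Debug)]
enum Message {
    Assistant { content: Option<String>, tool_calls: Vec<String> },
    User { content: String },
}

fn push_tool_call(messages: &mut Vec<Message>, tool_call: String) {
    // Mirror the diff: extend the last assistant message if there is one,
    // otherwise open a fresh assistant message that only carries tool calls.
    if let Some(Message::Assistant { tool_calls, .. }) = messages.last_mut() {
        tool_calls.push(tool_call);
    } else {
        messages.push(Message::Assistant {
            content: None,
            tool_calls: vec![tool_call],
        });
    }
}

fn main() {
    let mut messages = vec![Message::User { content: "hi".into() }];
    push_tool_call(&mut messages, "call_1".into());
    push_tool_call(&mut messages, "call_2".into()); // lands on the same assistant message
    println!("{messages:#?}");
}
```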
@@ -460,6 +446,103 @@ pub fn into_deepseek(
     }
 }
 
+pub struct DeepSeekEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl DeepSeekEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: deepseek::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.clone().into(),
+                                name: tool_call.name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments.clone(),
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.as_str().into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
+        }
+
+        events
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: Entity<State>,
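When the model reports `finish_reason == "tool_calls"`, the mapper parses each accumulated argument string once; malformed JSON becomes a `BadInputJson` event instead of a `ToolUse`. A small sketch of the two outcomes of that parse, using `serde_json::Value::from_str` as in the diff (the argument strings are made up):

```rust
use std::str::FromStr;

fn main() {
    // Accumulated arguments for a finished tool call (made-up JSON).
    let good = r#"{"city":"Berlin"}"#;
    let truncated = r#"{"city":"Ber"#; // stream cut off mid-arguments

    // Well-formed arguments become the ToolUse input.
    match serde_json::Value::from_str(good) {
        Ok(input) => println!("tool input: {input}"),
        Err(error) => println!("would surface BadInputJson: {error}"),
    }

    // A malformed payload takes the error branch instead of emitting a ToolUse event.
    match serde_json::Value::from_str(truncated) {
        Ok(input) => println!("tool input: {input}"),
        Err(error) => println!("would surface BadInputJson: {error}"),
    }
}
```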