Add tool support for DeepSeek (#30223)
The [DeepSeek function calling API](https://api-docs.deepseek.com/guides/function_calling) has been released, and it is the same as OpenAI's.

Release Notes:

- Added tool calling support for DeepSeek models

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
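For reference, a minimal sketch (not part of this commit) of the OpenAI-compatible request shape that guide describes; the `get_weather` tool and its parameters are purely hypothetical, and this is an illustration of the wire format rather than the code changed below:

```rust
// Sketch only: builds an OpenAI-compatible tool-calling payload of the kind the
// DeepSeek function-calling guide documents. Uses serde_json; the "get_weather"
// tool is a made-up example, not something introduced by this PR.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "deepseek-chat",
        "messages": [
            { "role": "user", "content": "How's the weather in Hangzhou?" }
        ],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }],
        "stream": true
    });

    // Streaming responses then carry incremental `tool_calls` deltas whose
    // `function.arguments` fragments must be accumulated before parsing, which
    // is what the DeepSeekEventMapper added in the diff below takes care of.
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```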
This commit is contained in:
parent
55d91bce53
commit
b820aa1fcd
1 changed file with 168 additions and 85 deletions
@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
     AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
@@ -12,11 +13,14 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::sync::Arc;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, prelude::*};
@@ -28,6 +32,13 @@ const PROVIDER_ID: &str = "deepseek";
 const PROVIDER_NAME: &str = "DeepSeek";
 const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
 
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct DeepSeekSettings {
     pub api_url: String,
@@ -280,11 +291,11 @@ impl LanguageModel for DeepSeekLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
+        true
     }
 
     fn supports_images(&self) -> bool {
@@ -339,35 +350,12 @@ impl LanguageModel for DeepSeekLanguageModel {
             BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
         >,
     > {
-        let request = into_deepseek(
-            request,
-            self.model.id().to_string(),
-            self.max_output_tokens(),
-        );
+        let request = into_deepseek(request, &self.model, self.max_output_tokens());
         let stream = self.stream_completion(request, cx);
 
         async move {
-            let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .context("Empty response")
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                        .map(LanguageModelCompletionEvent::Text)
-                                })
-                        })
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = DeepSeekEventMapper::new();
+            Ok(mapper.map_stream(stream.await?).boxed())
         }
         .boxed()
     }
@@ -375,69 +363,67 @@ impl LanguageModel for DeepSeekLanguageModel {
 
 pub fn into_deepseek(
     request: LanguageModelRequest,
-    model: String,
+    model: &deepseek::Model,
     max_output_tokens: Option<u32>,
 ) -> deepseek::Request {
-    let is_reasoner = model == "deepseek-reasoner";
+    let is_reasoner = *model == deepseek::Model::Reasoner;
 
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
-
-                if is_reasoner {
-                    if let Some(last_msg) = acc.last_mut() {
-                        match (last_msg, role) {
-                            (deepseek::RequestMessage::User { content: last }, Role::User) => {
-                                last.push(' ');
-                                last.push_str(&content);
-                                return acc;
-                            }
-
-                            (
-                                deepseek::RequestMessage::Assistant {
-                                    content: last_content,
-                                    ..
-                                },
-                                Role::Assistant,
-                            ) => {
-                                *last_content = last_content
-                                    .take()
-                                    .map(|c| {
-                                        let mut s =
-                                            String::with_capacity(c.len() + content.len() + 1);
-                                        s.push_str(&c);
-                                        s.push(' ');
-                                        s.push_str(&content);
-                                        s
-                                    })
-                                    .or(Some(content));
-
-                                return acc;
-                            }
-                            _ => {}
-                        }
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => deepseek::RequestMessage::User { content: text },
+                        Role::Assistant => deepseek::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => deepseek::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = deepseek::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: deepseek::ToolCallContent::Function {
+                            function: deepseek::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(deepseek::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(deepseek::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
                     }
                 }
-
-                acc.push(match role {
-                    Role::User => deepseek::RequestMessage::User { content },
-                    Role::Assistant => deepseek::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => deepseek::RequestMessage::System { content },
-                });
-                acc
-            });
+                MessageContent::ToolResult(tool_result) => {
+                    match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => {
+                            messages.push(deepseek::RequestMessage::Tool {
+                                content: text.to_string(),
+                                tool_call_id: tool_result.tool_use_id.to_string(),
+                            });
+                        }
+                        LanguageModelToolResultContent::Image(_) => {}
+                    };
+                }
+            }
+        }
+    }
 
     deepseek::Request {
-        model,
-        messages: merged_messages,
+        model: model.id().to_string(),
+        messages,
         stream: true,
         max_tokens: max_output_tokens,
         temperature: if is_reasoner {
@@ -460,6 +446,103 @@ pub fn into_deepseek(
     }
 }
 
+pub struct DeepSeekEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl DeepSeekEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<deepseek::StreamResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: deepseek::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.clone().into(),
+                                name: tool_call.name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments.clone(),
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.as_str().into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected DeepSeek stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
+        }
+
+        events
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: Entity<State>,