language_model: Remove dependencies on individual model provider crates (#25503)
This PR removes the dependencies on the individual model provider crates from the `language_model` crate. The various conversion methods for converting a `LanguageModelRequest` into its provider-specific request type have been inlined into the various provider modules in the `language_models` crate. The model providers we provide via Zed's cloud offering get to stay, for now. Release Notes: - N/A
This commit is contained in:
parent
2f7a62780a
commit
0acd556106
11 changed files with 347 additions and 366 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -7015,16 +7015,12 @@ dependencies = [
|
|||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
"collections",
|
||||
"deepseek",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"image",
|
||||
"lmstudio",
|
||||
"log",
|
||||
"mistral",
|
||||
"ollama",
|
||||
"open_ai",
|
||||
"parking_lot",
|
||||
"proto",
|
||||
|
|
|
@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
|
|||
anyhow.workspace = true
|
||||
base64.workspace = true
|
||||
collections.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
futures.workspace = true
|
||||
google_ai = { workspace = true, features = ["schemars"] }
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
image.workspace = true
|
||||
lmstudio = { workspace = true, features = ["schemars"] }
|
||||
log.workspace = true
|
||||
mistral = { workspace = true, features = ["schemars"] }
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
parking_lot.workspace = true
|
||||
proto.workspace = true
|
||||
|
|
|
@ -1,7 +1,3 @@
|
|||
pub mod cloud_model;
|
||||
|
||||
pub use anthropic::Model as AnthropicModel;
|
||||
pub use cloud_model::*;
|
||||
pub use lmstudio::Model as LmStudioModel;
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
|
|
|
@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
|
|||
pub temperature: Option<f32>,
|
||||
}
|
||||
|
||||
impl LanguageModelRequest {
|
||||
pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
|
||||
let stream = !model.starts_with("o1-");
|
||||
open_ai::Request {
|
||||
model,
|
||||
messages: self
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => open_ai::RequestMessage::User {
|
||||
content: msg.string_contents(),
|
||||
},
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: Some(msg.string_contents()),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => open_ai::RequestMessage::System {
|
||||
content: msg.string_contents(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream,
|
||||
stop: self.stop,
|
||||
temperature: self.temperature.unwrap_or(1.0),
|
||||
max_tokens: max_output_tokens,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
|
||||
let len = self.messages.len();
|
||||
let merged_messages =
|
||||
self.messages
|
||||
.into_iter()
|
||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||
let role = msg.role;
|
||||
let content = msg.string_contents();
|
||||
|
||||
acc.push(match role {
|
||||
Role::User => mistral::RequestMessage::User { content },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content },
|
||||
});
|
||||
acc
|
||||
});
|
||||
|
||||
mistral::Request {
|
||||
model,
|
||||
messages: merged_messages,
|
||||
stream: true,
|
||||
max_tokens: max_output_tokens,
|
||||
temperature: self.temperature,
|
||||
response_format: None,
|
||||
tools: self
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| mistral::ToolDefinition::Function {
|
||||
function: mistral::FunctionDefinition {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: Some(tool.input_schema),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
|
||||
google_ai::GenerateContentRequest {
|
||||
model,
|
||||
contents: self
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| google_ai::Content {
|
||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||
text: msg.string_contents(),
|
||||
})],
|
||||
role: match msg.role {
|
||||
Role::User => google_ai::Role::User,
|
||||
Role::Assistant => google_ai::Role::Model,
|
||||
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
generation_config: Some(google_ai::GenerationConfig {
|
||||
candidate_count: Some(1),
|
||||
stop_sequences: Some(self.stop),
|
||||
max_output_tokens: None,
|
||||
temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
}),
|
||||
safety_settings: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_anthropic(
|
||||
self,
|
||||
model: String,
|
||||
default_temperature: f32,
|
||||
max_output_tokens: u32,
|
||||
) -> anthropic::Request {
|
||||
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in self.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let cache_control = if message.cache {
|
||||
Some(anthropic::CacheControl {
|
||||
cache_type: anthropic::CacheControlType::Ephemeral,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
Some(anthropic::RequestContent::Text {
|
||||
text,
|
||||
cache_control,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => {
|
||||
Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
Some(anthropic::RequestContent::ToolUse {
|
||||
id: tool_use.id.to_string(),
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
Some(anthropic::RequestContent::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id,
|
||||
is_error: tool_result.is_error,
|
||||
content: tool_result.content,
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if let Some(last_message) = new_messages.last_mut() {
|
||||
if last_message.role == anthropic_role {
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
new_messages.push(anthropic::Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anthropic::Request {
|
||||
model,
|
||||
messages: new_messages,
|
||||
max_tokens: max_output_tokens,
|
||||
system: Some(system_message),
|
||||
tools: self
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| anthropic::Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: self.temperature.or(Some(default_temperature)),
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
|
||||
let is_reasoner = model == "deepseek-reasoner";
|
||||
|
||||
let len = self.messages.len();
|
||||
let merged_messages =
|
||||
self.messages
|
||||
.into_iter()
|
||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||
let role = msg.role;
|
||||
let content = msg.string_contents();
|
||||
|
||||
if is_reasoner {
|
||||
if let Some(last_msg) = acc.last_mut() {
|
||||
match (last_msg, role) {
|
||||
(deepseek::RequestMessage::User { content: last }, Role::User) => {
|
||||
last.push(' ');
|
||||
last.push_str(&content);
|
||||
return acc;
|
||||
}
|
||||
|
||||
(
|
||||
deepseek::RequestMessage::Assistant {
|
||||
content: last_content,
|
||||
..
|
||||
},
|
||||
Role::Assistant,
|
||||
) => {
|
||||
*last_content = last_content
|
||||
.take()
|
||||
.map(|c| {
|
||||
let mut s =
|
||||
String::with_capacity(c.len() + content.len() + 1);
|
||||
s.push_str(&c);
|
||||
s.push(' ');
|
||||
s.push_str(&content);
|
||||
s
|
||||
})
|
||||
.or(Some(content));
|
||||
|
||||
return acc;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc.push(match role {
|
||||
Role::User => deepseek::RequestMessage::User { content },
|
||||
Role::Assistant => deepseek::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => deepseek::RequestMessage::System { content },
|
||||
});
|
||||
acc
|
||||
});
|
||||
|
||||
deepseek::Request {
|
||||
model,
|
||||
messages: merged_messages,
|
||||
stream: true,
|
||||
max_tokens: max_output_tokens,
|
||||
temperature: if is_reasoner { None } else { self.temperature },
|
||||
response_format: None,
|
||||
tools: self
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| deepseek::ToolDefinition::Function {
|
||||
function: deepseek::FunctionDefinition {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: Some(tool.input_schema),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
|
|
|
@ -45,43 +45,3 @@ impl Display for Role {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for ollama::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => ollama::Role::User,
|
||||
Role::Assistant => ollama::Role::Assistant,
|
||||
Role::System => ollama::Role::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for open_ai::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => open_ai::Role::User,
|
||||
Role::Assistant => open_ai::Role::Assistant,
|
||||
Role::System => open_ai::Role::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for deepseek::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => deepseek::Role::User,
|
||||
Role::Assistant => deepseek::Role::Assistant,
|
||||
Role::System => deepseek::Role::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for lmstudio::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => lmstudio::Role::User,
|
||||
Role::Assistant => lmstudio::Role::Assistant,
|
||||
Role::System => lmstudio::Role::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ use http_client::HttpClient;
|
|||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
|
||||
};
|
||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use schemars::JsonSchema;
|
||||
|
@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
let request = request.into_anthropic(
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
self.model.id().into(),
|
||||
self.model.default_temperature(),
|
||||
self.model.max_output_tokens(),
|
||||
|
@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
|
|||
input_schema: serde_json::Value,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let mut request = request.into_anthropic(
|
||||
let mut request = into_anthropic(
|
||||
request,
|
||||
self.model.tool_model_id().into(),
|
||||
self.model.default_temperature(),
|
||||
self.model.max_output_tokens(),
|
||||
|
@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_anthropic(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
default_temperature: f32,
|
||||
max_output_tokens: u32,
|
||||
) -> anthropic::Request {
|
||||
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages {
|
||||
if message.contents_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let cache_control = if message.cache {
|
||||
Some(anthropic::CacheControl {
|
||||
cache_type: anthropic::CacheControlType::Ephemeral,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
Some(anthropic::RequestContent::Text {
|
||||
text,
|
||||
cache_control,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
|
||||
source: anthropic::ImageSource {
|
||||
source_type: "base64".to_string(),
|
||||
media_type: "image/png".to_string(),
|
||||
data: image.source.to_string(),
|
||||
},
|
||||
cache_control,
|
||||
}),
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
Some(anthropic::RequestContent::ToolUse {
|
||||
id: tool_use.id.to_string(),
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
Some(anthropic::RequestContent::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id,
|
||||
is_error: tool_result.is_error,
|
||||
content: tool_result.content,
|
||||
cache_control,
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let anthropic_role = match message.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("System role should never occur here"),
|
||||
};
|
||||
if let Some(last_message) = new_messages.last_mut() {
|
||||
if last_message.role == anthropic_role {
|
||||
last_message.content.extend(anthropic_message_content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
new_messages.push(anthropic::Message {
|
||||
role: anthropic_role,
|
||||
content: anthropic_message_content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.string_contents());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anthropic::Request {
|
||||
model,
|
||||
messages: new_messages,
|
||||
max_tokens: max_output_tokens,
|
||||
system: Some(system_message),
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| anthropic::Tool {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: request.temperature.or(Some(default_temperature)),
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use super::open_ai::count_open_ai_tokens;
|
||||
use anthropic::AnthropicError;
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{
|
||||
|
@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
|
|||
use thiserror::Error;
|
||||
use ui::{prelude::*, TintColor};
|
||||
|
||||
use crate::provider::anthropic::map_to_language_model_completion_events;
|
||||
use crate::provider::anthropic::{
|
||||
count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
|
||||
};
|
||||
use crate::provider::google::into_google;
|
||||
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
|
||||
use crate::AllLanguageModelSettings;
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
|
||||
pub const PROVIDER_NAME: &str = "Zed";
|
||||
|
||||
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
|
||||
|
@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
let request = into_google(request, model.id().into());
|
||||
let request = google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
};
|
||||
|
@ -638,7 +639,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let request = request.into_anthropic(
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
model.id().into(),
|
||||
model.default_temperature(),
|
||||
model.max_output_tokens(),
|
||||
|
@ -666,7 +668,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
|
||||
let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
|
@ -693,7 +695,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
let request = into_google(request, model.id().into());
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
|
@ -736,7 +738,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let mut request = request.into_anthropic(
|
||||
let mut request = into_anthropic(
|
||||
request,
|
||||
model.tool_model_id().into(),
|
||||
model.default_temperature(),
|
||||
model.max_output_tokens(),
|
||||
|
@ -776,7 +779,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let mut request =
|
||||
request.into_open_ai(model.id().into(), model.max_output_tokens());
|
||||
into_open_ai(request, model.id().into(), model.max_output_tokens());
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||
open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
|
|
|
@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
|
||||
let request = into_deepseek(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
self.max_output_tokens(),
|
||||
);
|
||||
let stream = self.stream_completion(request, cx);
|
||||
|
||||
async move {
|
||||
|
@ -357,8 +361,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
|||
schema: serde_json::Value,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
let mut deepseek_request =
|
||||
request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
|
||||
let mut deepseek_request = into_deepseek(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
self.max_output_tokens(),
|
||||
);
|
||||
|
||||
deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
|
||||
function: deepseek::FunctionDefinition {
|
||||
|
@ -402,6 +409,93 @@ impl LanguageModel for DeepSeekLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_deepseek(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
max_output_tokens: Option<u32>,
|
||||
) -> deepseek::Request {
|
||||
let is_reasoner = model == "deepseek-reasoner";
|
||||
|
||||
let len = request.messages.len();
|
||||
let merged_messages =
|
||||
request
|
||||
.messages
|
||||
.into_iter()
|
||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||
let role = msg.role;
|
||||
let content = msg.string_contents();
|
||||
|
||||
if is_reasoner {
|
||||
if let Some(last_msg) = acc.last_mut() {
|
||||
match (last_msg, role) {
|
||||
(deepseek::RequestMessage::User { content: last }, Role::User) => {
|
||||
last.push(' ');
|
||||
last.push_str(&content);
|
||||
return acc;
|
||||
}
|
||||
|
||||
(
|
||||
deepseek::RequestMessage::Assistant {
|
||||
content: last_content,
|
||||
..
|
||||
},
|
||||
Role::Assistant,
|
||||
) => {
|
||||
*last_content = last_content
|
||||
.take()
|
||||
.map(|c| {
|
||||
let mut s =
|
||||
String::with_capacity(c.len() + content.len() + 1);
|
||||
s.push_str(&c);
|
||||
s.push(' ');
|
||||
s.push_str(&content);
|
||||
s
|
||||
})
|
||||
.or(Some(content));
|
||||
|
||||
return acc;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc.push(match role {
|
||||
Role::User => deepseek::RequestMessage::User { content },
|
||||
Role::Assistant => deepseek::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => deepseek::RequestMessage::System { content },
|
||||
});
|
||||
acc
|
||||
});
|
||||
|
||||
deepseek::Request {
|
||||
model,
|
||||
messages: merged_messages,
|
||||
stream: true,
|
||||
max_tokens: max_output_tokens,
|
||||
temperature: if is_reasoner {
|
||||
None
|
||||
} else {
|
||||
request.temperature
|
||||
},
|
||||
response_format: None,
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| deepseek::ToolDefinition::Function {
|
||||
function: deepseek::FunctionDefinition {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: Some(tool.input_schema),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<Editor>,
|
||||
state: Entity<State>,
|
||||
|
|
|
@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
let request = request.into_google(self.model.id().to_string());
|
||||
let request = into_google(request, self.model.id().to_string());
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.state.read(cx).api_key.clone();
|
||||
|
||||
|
@ -303,7 +303,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
> {
|
||||
let request = request.into_google(self.model.id().to_string());
|
||||
let request = into_google(request, self.model.id().to_string());
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
|
@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_google(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
) -> google_ai::GenerateContentRequest {
|
||||
google_ai::GenerateContentRequest {
|
||||
model,
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| google_ai::Content {
|
||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||
text: msg.string_contents(),
|
||||
})],
|
||||
role: match msg.role {
|
||||
Role::User => google_ai::Role::User,
|
||||
Role::Assistant => google_ai::Role::Model,
|
||||
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
generation_config: Some(google_ai::GenerationConfig {
|
||||
candidate_count: Some(1),
|
||||
stop_sequences: Some(request.stop),
|
||||
max_output_tokens: None,
|
||||
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
}),
|
||||
safety_settings: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_google_tokens(
|
||||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
|
|
|
@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
|
||||
let request = into_mistral(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
self.max_output_tokens(),
|
||||
);
|
||||
let stream = self.stream_completion(request, cx);
|
||||
|
||||
async move {
|
||||
|
@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
|
|||
schema: serde_json::Value,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
|
||||
let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
|
||||
request.tools = vec![mistral::ToolDefinition::Function {
|
||||
function: mistral::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
|
@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_mistral(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
max_output_tokens: Option<u32>,
|
||||
) -> mistral::Request {
|
||||
let len = request.messages.len();
|
||||
let merged_messages =
|
||||
request
|
||||
.messages
|
||||
.into_iter()
|
||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||
let role = msg.role;
|
||||
let content = msg.string_contents();
|
||||
|
||||
acc.push(match role {
|
||||
Role::User => mistral::RequestMessage::User { content },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content },
|
||||
});
|
||||
acc
|
||||
});
|
||||
|
||||
mistral::Request {
|
||||
model,
|
||||
messages: merged_messages,
|
||||
stream: true,
|
||||
max_tokens: max_output_tokens,
|
||||
temperature: request.temperature,
|
||||
response_format: None,
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| mistral::ToolDefinition::Function {
|
||||
function: mistral::FunctionDefinition {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: Some(tool.input_schema),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<Editor>,
|
||||
state: gpui::Entity<State>,
|
||||
|
|
|
@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
> {
|
||||
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
|
||||
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
Ok(open_ai::extract_text_from_events(completions.await?)
|
||||
|
@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
schema: serde_json::Value,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
|
||||
let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
|
||||
function: FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
|
@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn into_open_ai(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
max_output_tokens: Option<u32>,
|
||||
) -> open_ai::Request {
|
||||
let stream = !model.starts_with("o1-");
|
||||
open_ai::Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => open_ai::RequestMessage::User {
|
||||
content: msg.string_contents(),
|
||||
},
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: Some(msg.string_contents()),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => open_ai::RequestMessage::System {
|
||||
content: msg.string_contents(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
max_tokens: max_output_tokens,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: open_ai::Model,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue