language_model: Remove dependencies on individual model provider crates (#25503)
This PR removes the dependencies on the individual model provider crates from the `language_model` crate. The various conversion methods for converting a `LanguageModelRequest` into its provider-specific request type have been inlined into the various provider modules in the `language_models` crate. The model providers we provide via Zed's cloud offering get to stay, for now. Release Notes: - N/A
This commit is contained in:
parent
2f7a62780a
commit
0acd556106
11 changed files with 347 additions and 366 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -7015,16 +7015,12 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"collections",
|
"collections",
|
||||||
"deepseek",
|
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"google_ai",
|
"google_ai",
|
||||||
"gpui",
|
"gpui",
|
||||||
"http_client",
|
"http_client",
|
||||||
"image",
|
"image",
|
||||||
"lmstudio",
|
|
||||||
"log",
|
"log",
|
||||||
"mistral",
|
|
||||||
"ollama",
|
|
||||||
"open_ai",
|
"open_ai",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"proto",
|
"proto",
|
||||||
|
|
|
@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
base64.workspace = true
|
base64.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
deepseek = { workspace = true, features = ["schemars"] }
|
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
google_ai = { workspace = true, features = ["schemars"] }
|
google_ai = { workspace = true, features = ["schemars"] }
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
image.workspace = true
|
image.workspace = true
|
||||||
lmstudio = { workspace = true, features = ["schemars"] }
|
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
mistral = { workspace = true, features = ["schemars"] }
|
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
|
||||||
open_ai = { workspace = true, features = ["schemars"] }
|
open_ai = { workspace = true, features = ["schemars"] }
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
proto.workspace = true
|
proto.workspace = true
|
||||||
|
|
|
@ -1,7 +1,3 @@
|
||||||
pub mod cloud_model;
|
pub mod cloud_model;
|
||||||
|
|
||||||
pub use anthropic::Model as AnthropicModel;
|
|
||||||
pub use cloud_model::*;
|
pub use cloud_model::*;
|
||||||
pub use lmstudio::Model as LmStudioModel;
|
|
||||||
pub use ollama::Model as OllamaModel;
|
|
||||||
pub use open_ai::Model as OpenAiModel;
|
|
||||||
|
|
|
@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelRequest {
|
|
||||||
pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
|
|
||||||
let stream = !model.starts_with("o1-");
|
|
||||||
open_ai::Request {
|
|
||||||
model,
|
|
||||||
messages: self
|
|
||||||
.messages
|
|
||||||
.into_iter()
|
|
||||||
.map(|msg| match msg.role {
|
|
||||||
Role::User => open_ai::RequestMessage::User {
|
|
||||||
content: msg.string_contents(),
|
|
||||||
},
|
|
||||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
|
||||||
content: Some(msg.string_contents()),
|
|
||||||
tool_calls: Vec::new(),
|
|
||||||
},
|
|
||||||
Role::System => open_ai::RequestMessage::System {
|
|
||||||
content: msg.string_contents(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
stream,
|
|
||||||
stop: self.stop,
|
|
||||||
temperature: self.temperature.unwrap_or(1.0),
|
|
||||||
max_tokens: max_output_tokens,
|
|
||||||
tools: Vec::new(),
|
|
||||||
tool_choice: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
|
|
||||||
let len = self.messages.len();
|
|
||||||
let merged_messages =
|
|
||||||
self.messages
|
|
||||||
.into_iter()
|
|
||||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
|
||||||
let role = msg.role;
|
|
||||||
let content = msg.string_contents();
|
|
||||||
|
|
||||||
acc.push(match role {
|
|
||||||
Role::User => mistral::RequestMessage::User { content },
|
|
||||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
|
||||||
content: Some(content),
|
|
||||||
tool_calls: Vec::new(),
|
|
||||||
},
|
|
||||||
Role::System => mistral::RequestMessage::System { content },
|
|
||||||
});
|
|
||||||
acc
|
|
||||||
});
|
|
||||||
|
|
||||||
mistral::Request {
|
|
||||||
model,
|
|
||||||
messages: merged_messages,
|
|
||||||
stream: true,
|
|
||||||
max_tokens: max_output_tokens,
|
|
||||||
temperature: self.temperature,
|
|
||||||
response_format: None,
|
|
||||||
tools: self
|
|
||||||
.tools
|
|
||||||
.into_iter()
|
|
||||||
.map(|tool| mistral::ToolDefinition::Function {
|
|
||||||
function: mistral::FunctionDefinition {
|
|
||||||
name: tool.name,
|
|
||||||
description: Some(tool.description),
|
|
||||||
parameters: Some(tool.input_schema),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
|
|
||||||
google_ai::GenerateContentRequest {
|
|
||||||
model,
|
|
||||||
contents: self
|
|
||||||
.messages
|
|
||||||
.into_iter()
|
|
||||||
.map(|msg| google_ai::Content {
|
|
||||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
|
||||||
text: msg.string_contents(),
|
|
||||||
})],
|
|
||||||
role: match msg.role {
|
|
||||||
Role::User => google_ai::Role::User,
|
|
||||||
Role::Assistant => google_ai::Role::Model,
|
|
||||||
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
generation_config: Some(google_ai::GenerationConfig {
|
|
||||||
candidate_count: Some(1),
|
|
||||||
stop_sequences: Some(self.stop),
|
|
||||||
max_output_tokens: None,
|
|
||||||
temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
|
|
||||||
top_p: None,
|
|
||||||
top_k: None,
|
|
||||||
}),
|
|
||||||
safety_settings: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_anthropic(
|
|
||||||
self,
|
|
||||||
model: String,
|
|
||||||
default_temperature: f32,
|
|
||||||
max_output_tokens: u32,
|
|
||||||
) -> anthropic::Request {
|
|
||||||
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
|
||||||
let mut system_message = String::new();
|
|
||||||
|
|
||||||
for message in self.messages {
|
|
||||||
if message.contents_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match message.role {
|
|
||||||
Role::User | Role::Assistant => {
|
|
||||||
let cache_control = if message.cache {
|
|
||||||
Some(anthropic::CacheControl {
|
|
||||||
cache_type: anthropic::CacheControlType::Ephemeral,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
|
||||||
.content
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|content| match content {
|
|
||||||
MessageContent::Text(text) => {
|
|
||||||
if !text.is_empty() {
|
|
||||||
Some(anthropic::RequestContent::Text {
|
|
||||||
text,
|
|
||||||
cache_control,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MessageContent::Image(image) => {
|
|
||||||
Some(anthropic::RequestContent::Image {
|
|
||||||
source: anthropic::ImageSource {
|
|
||||||
source_type: "base64".to_string(),
|
|
||||||
media_type: "image/png".to_string(),
|
|
||||||
data: image.source.to_string(),
|
|
||||||
},
|
|
||||||
cache_control,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
MessageContent::ToolUse(tool_use) => {
|
|
||||||
Some(anthropic::RequestContent::ToolUse {
|
|
||||||
id: tool_use.id.to_string(),
|
|
||||||
name: tool_use.name,
|
|
||||||
input: tool_use.input,
|
|
||||||
cache_control,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
MessageContent::ToolResult(tool_result) => {
|
|
||||||
Some(anthropic::RequestContent::ToolResult {
|
|
||||||
tool_use_id: tool_result.tool_use_id,
|
|
||||||
is_error: tool_result.is_error,
|
|
||||||
content: tool_result.content,
|
|
||||||
cache_control,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let anthropic_role = match message.role {
|
|
||||||
Role::User => anthropic::Role::User,
|
|
||||||
Role::Assistant => anthropic::Role::Assistant,
|
|
||||||
Role::System => unreachable!("System role should never occur here"),
|
|
||||||
};
|
|
||||||
if let Some(last_message) = new_messages.last_mut() {
|
|
||||||
if last_message.role == anthropic_role {
|
|
||||||
last_message.content.extend(anthropic_message_content);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
new_messages.push(anthropic::Message {
|
|
||||||
role: anthropic_role,
|
|
||||||
content: anthropic_message_content,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Role::System => {
|
|
||||||
if !system_message.is_empty() {
|
|
||||||
system_message.push_str("\n\n");
|
|
||||||
}
|
|
||||||
system_message.push_str(&message.string_contents());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anthropic::Request {
|
|
||||||
model,
|
|
||||||
messages: new_messages,
|
|
||||||
max_tokens: max_output_tokens,
|
|
||||||
system: Some(system_message),
|
|
||||||
tools: self
|
|
||||||
.tools
|
|
||||||
.into_iter()
|
|
||||||
.map(|tool| anthropic::Tool {
|
|
||||||
name: tool.name,
|
|
||||||
description: tool.description,
|
|
||||||
input_schema: tool.input_schema,
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
tool_choice: None,
|
|
||||||
metadata: None,
|
|
||||||
stop_sequences: Vec::new(),
|
|
||||||
temperature: self.temperature.or(Some(default_temperature)),
|
|
||||||
top_k: None,
|
|
||||||
top_p: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
|
|
||||||
let is_reasoner = model == "deepseek-reasoner";
|
|
||||||
|
|
||||||
let len = self.messages.len();
|
|
||||||
let merged_messages =
|
|
||||||
self.messages
|
|
||||||
.into_iter()
|
|
||||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
|
||||||
let role = msg.role;
|
|
||||||
let content = msg.string_contents();
|
|
||||||
|
|
||||||
if is_reasoner {
|
|
||||||
if let Some(last_msg) = acc.last_mut() {
|
|
||||||
match (last_msg, role) {
|
|
||||||
(deepseek::RequestMessage::User { content: last }, Role::User) => {
|
|
||||||
last.push(' ');
|
|
||||||
last.push_str(&content);
|
|
||||||
return acc;
|
|
||||||
}
|
|
||||||
|
|
||||||
(
|
|
||||||
deepseek::RequestMessage::Assistant {
|
|
||||||
content: last_content,
|
|
||||||
..
|
|
||||||
},
|
|
||||||
Role::Assistant,
|
|
||||||
) => {
|
|
||||||
*last_content = last_content
|
|
||||||
.take()
|
|
||||||
.map(|c| {
|
|
||||||
let mut s =
|
|
||||||
String::with_capacity(c.len() + content.len() + 1);
|
|
||||||
s.push_str(&c);
|
|
||||||
s.push(' ');
|
|
||||||
s.push_str(&content);
|
|
||||||
s
|
|
||||||
})
|
|
||||||
.or(Some(content));
|
|
||||||
|
|
||||||
return acc;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acc.push(match role {
|
|
||||||
Role::User => deepseek::RequestMessage::User { content },
|
|
||||||
Role::Assistant => deepseek::RequestMessage::Assistant {
|
|
||||||
content: Some(content),
|
|
||||||
tool_calls: Vec::new(),
|
|
||||||
},
|
|
||||||
Role::System => deepseek::RequestMessage::System { content },
|
|
||||||
});
|
|
||||||
acc
|
|
||||||
});
|
|
||||||
|
|
||||||
deepseek::Request {
|
|
||||||
model,
|
|
||||||
messages: merged_messages,
|
|
||||||
stream: true,
|
|
||||||
max_tokens: max_output_tokens,
|
|
||||||
temperature: if is_reasoner { None } else { self.temperature },
|
|
||||||
response_format: None,
|
|
||||||
tools: self
|
|
||||||
.tools
|
|
||||||
.into_iter()
|
|
||||||
.map(|tool| deepseek::ToolDefinition::Function {
|
|
||||||
function: deepseek::FunctionDefinition {
|
|
||||||
name: tool.name,
|
|
||||||
description: Some(tool.description),
|
|
||||||
parameters: Some(tool.input_schema),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
pub struct LanguageModelResponseMessage {
|
pub struct LanguageModelResponseMessage {
|
||||||
pub role: Option<Role>,
|
pub role: Option<Role>,
|
||||||
|
|
|
@ -45,43 +45,3 @@ impl Display for Role {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Role> for ollama::Role {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => ollama::Role::User,
|
|
||||||
Role::Assistant => ollama::Role::Assistant,
|
|
||||||
Role::System => ollama::Role::System,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Role> for open_ai::Role {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => open_ai::Role::User,
|
|
||||||
Role::Assistant => open_ai::Role::Assistant,
|
|
||||||
Role::System => open_ai::Role::System,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Role> for deepseek::Role {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => deepseek::Role::User,
|
|
||||||
Role::Assistant => deepseek::Role::Assistant,
|
|
||||||
Role::System => deepseek::Role::System,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Role> for lmstudio::Role {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => lmstudio::Role::User,
|
|
||||||
Role::Assistant => lmstudio::Role::Assistant,
|
|
||||||
Role::System => lmstudio::Role::System,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ use http_client::HttpClient;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
|
@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||||
let request = request.into_anthropic(
|
let request = into_anthropic(
|
||||||
|
request,
|
||||||
self.model.id().into(),
|
self.model.id().into(),
|
||||||
self.model.default_temperature(),
|
self.model.default_temperature(),
|
||||||
self.model.max_output_tokens(),
|
self.model.max_output_tokens(),
|
||||||
|
@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let mut request = request.into_anthropic(
|
let mut request = into_anthropic(
|
||||||
|
request,
|
||||||
self.model.tool_model_id().into(),
|
self.model.tool_model_id().into(),
|
||||||
self.model.default_temperature(),
|
self.model.default_temperature(),
|
||||||
self.model.max_output_tokens(),
|
self.model.max_output_tokens(),
|
||||||
|
@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_anthropic(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
model: String,
|
||||||
|
default_temperature: f32,
|
||||||
|
max_output_tokens: u32,
|
||||||
|
) -> anthropic::Request {
|
||||||
|
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
||||||
|
let mut system_message = String::new();
|
||||||
|
|
||||||
|
for message in request.messages {
|
||||||
|
if message.contents_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match message.role {
|
||||||
|
Role::User | Role::Assistant => {
|
||||||
|
let cache_control = if message.cache {
|
||||||
|
Some(anthropic::CacheControl {
|
||||||
|
cache_type: anthropic::CacheControlType::Ephemeral,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
||||||
|
.content
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|content| match content {
|
||||||
|
MessageContent::Text(text) => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
Some(anthropic::RequestContent::Text {
|
||||||
|
text,
|
||||||
|
cache_control,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
|
||||||
|
source: anthropic::ImageSource {
|
||||||
|
source_type: "base64".to_string(),
|
||||||
|
media_type: "image/png".to_string(),
|
||||||
|
data: image.source.to_string(),
|
||||||
|
},
|
||||||
|
cache_control,
|
||||||
|
}),
|
||||||
|
MessageContent::ToolUse(tool_use) => {
|
||||||
|
Some(anthropic::RequestContent::ToolUse {
|
||||||
|
id: tool_use.id.to_string(),
|
||||||
|
name: tool_use.name,
|
||||||
|
input: tool_use.input,
|
||||||
|
cache_control,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
MessageContent::ToolResult(tool_result) => {
|
||||||
|
Some(anthropic::RequestContent::ToolResult {
|
||||||
|
tool_use_id: tool_result.tool_use_id,
|
||||||
|
is_error: tool_result.is_error,
|
||||||
|
content: tool_result.content,
|
||||||
|
cache_control,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let anthropic_role = match message.role {
|
||||||
|
Role::User => anthropic::Role::User,
|
||||||
|
Role::Assistant => anthropic::Role::Assistant,
|
||||||
|
Role::System => unreachable!("System role should never occur here"),
|
||||||
|
};
|
||||||
|
if let Some(last_message) = new_messages.last_mut() {
|
||||||
|
if last_message.role == anthropic_role {
|
||||||
|
last_message.content.extend(anthropic_message_content);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_messages.push(anthropic::Message {
|
||||||
|
role: anthropic_role,
|
||||||
|
content: anthropic_message_content,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::System => {
|
||||||
|
if !system_message.is_empty() {
|
||||||
|
system_message.push_str("\n\n");
|
||||||
|
}
|
||||||
|
system_message.push_str(&message.string_contents());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropic::Request {
|
||||||
|
model,
|
||||||
|
messages: new_messages,
|
||||||
|
max_tokens: max_output_tokens,
|
||||||
|
system: Some(system_message),
|
||||||
|
tools: request
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| anthropic::Tool {
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
input_schema: tool.input_schema,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
tool_choice: None,
|
||||||
|
metadata: None,
|
||||||
|
stop_sequences: Vec::new(),
|
||||||
|
temperature: request.temperature.or(Some(default_temperature)),
|
||||||
|
top_k: None,
|
||||||
|
top_p: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn map_to_language_model_completion_events(
|
pub fn map_to_language_model_completion_events(
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use super::open_ai::count_open_ai_tokens;
|
|
||||||
use anthropic::AnthropicError;
|
use anthropic::AnthropicError;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use client::{
|
use client::{
|
||||||
|
@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ui::{prelude::*, TintColor};
|
use ui::{prelude::*, TintColor};
|
||||||
|
|
||||||
use crate::provider::anthropic::map_to_language_model_completion_events;
|
use crate::provider::anthropic::{
|
||||||
|
count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
|
||||||
|
};
|
||||||
|
use crate::provider::google::into_google;
|
||||||
|
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
|
|
||||||
use super::anthropic::count_anthropic_tokens;
|
|
||||||
|
|
||||||
pub const PROVIDER_NAME: &str = "Zed";
|
pub const PROVIDER_NAME: &str = "Zed";
|
||||||
|
|
||||||
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
|
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
|
||||||
|
@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_google(model.id().into());
|
let request = into_google(request, model.id().into());
|
||||||
let request = google_ai::CountTokensRequest {
|
let request = google_ai::CountTokensRequest {
|
||||||
contents: request.contents,
|
contents: request.contents,
|
||||||
};
|
};
|
||||||
|
@ -638,7 +639,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let request = request.into_anthropic(
|
let request = into_anthropic(
|
||||||
|
request,
|
||||||
model.id().into(),
|
model.id().into(),
|
||||||
model.default_temperature(),
|
model.default_temperature(),
|
||||||
model.max_output_tokens(),
|
model.max_output_tokens(),
|
||||||
|
@ -666,7 +668,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
|
let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -693,7 +695,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_google(model.id().into());
|
let request = into_google(request, model.id().into());
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -736,7 +738,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
|
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let mut request = request.into_anthropic(
|
let mut request = into_anthropic(
|
||||||
|
request,
|
||||||
model.tool_model_id().into(),
|
model.tool_model_id().into(),
|
||||||
model.default_temperature(),
|
model.default_temperature(),
|
||||||
model.max_output_tokens(),
|
model.max_output_tokens(),
|
||||||
|
@ -776,7 +779,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let mut request =
|
let mut request =
|
||||||
request.into_open_ai(model.id().into(), model.max_output_tokens());
|
into_open_ai(request, model.id().into(), model.max_output_tokens());
|
||||||
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||||
open_ai::ToolDefinition::Function {
|
open_ai::ToolDefinition::Function {
|
||||||
function: open_ai::FunctionDefinition {
|
function: open_ai::FunctionDefinition {
|
||||||
|
|
|
@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||||
let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
|
let request = into_deepseek(
|
||||||
|
request,
|
||||||
|
self.model.id().to_string(),
|
||||||
|
self.max_output_tokens(),
|
||||||
|
);
|
||||||
let stream = self.stream_completion(request, cx);
|
let stream = self.stream_completion(request, cx);
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
|
@ -357,8 +361,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let mut deepseek_request =
|
let mut deepseek_request = into_deepseek(
|
||||||
request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
|
request,
|
||||||
|
self.model.id().to_string(),
|
||||||
|
self.max_output_tokens(),
|
||||||
|
);
|
||||||
|
|
||||||
deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
|
deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
|
||||||
function: deepseek::FunctionDefinition {
|
function: deepseek::FunctionDefinition {
|
||||||
|
@ -402,6 +409,93 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_deepseek(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
model: String,
|
||||||
|
max_output_tokens: Option<u32>,
|
||||||
|
) -> deepseek::Request {
|
||||||
|
let is_reasoner = model == "deepseek-reasoner";
|
||||||
|
|
||||||
|
let len = request.messages.len();
|
||||||
|
let merged_messages =
|
||||||
|
request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||||
|
let role = msg.role;
|
||||||
|
let content = msg.string_contents();
|
||||||
|
|
||||||
|
if is_reasoner {
|
||||||
|
if let Some(last_msg) = acc.last_mut() {
|
||||||
|
match (last_msg, role) {
|
||||||
|
(deepseek::RequestMessage::User { content: last }, Role::User) => {
|
||||||
|
last.push(' ');
|
||||||
|
last.push_str(&content);
|
||||||
|
return acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
deepseek::RequestMessage::Assistant {
|
||||||
|
content: last_content,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
Role::Assistant,
|
||||||
|
) => {
|
||||||
|
*last_content = last_content
|
||||||
|
.take()
|
||||||
|
.map(|c| {
|
||||||
|
let mut s =
|
||||||
|
String::with_capacity(c.len() + content.len() + 1);
|
||||||
|
s.push_str(&c);
|
||||||
|
s.push(' ');
|
||||||
|
s.push_str(&content);
|
||||||
|
s
|
||||||
|
})
|
||||||
|
.or(Some(content));
|
||||||
|
|
||||||
|
return acc;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acc.push(match role {
|
||||||
|
Role::User => deepseek::RequestMessage::User { content },
|
||||||
|
Role::Assistant => deepseek::RequestMessage::Assistant {
|
||||||
|
content: Some(content),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
},
|
||||||
|
Role::System => deepseek::RequestMessage::System { content },
|
||||||
|
});
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
|
||||||
|
deepseek::Request {
|
||||||
|
model,
|
||||||
|
messages: merged_messages,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: max_output_tokens,
|
||||||
|
temperature: if is_reasoner {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
request.temperature
|
||||||
|
},
|
||||||
|
response_format: None,
|
||||||
|
tools: request
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| deepseek::ToolDefinition::Function {
|
||||||
|
function: deepseek::FunctionDefinition {
|
||||||
|
name: tool.name,
|
||||||
|
description: Some(tool.description),
|
||||||
|
parameters: Some(tool.input_schema),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
api_key_editor: Entity<Editor>,
|
api_key_editor: Entity<Editor>,
|
||||||
state: Entity<State>,
|
state: Entity<State>,
|
||||||
|
|
|
@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
let request = request.into_google(self.model.id().to_string());
|
let request = into_google(request, self.model.id().to_string());
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_key = self.state.read(cx).api_key.clone();
|
let api_key = self.state.read(cx).api_key.clone();
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||||
> {
|
> {
|
||||||
let request = request.into_google(self.model.id().to_string());
|
let request = into_google(request, self.model.id().to_string());
|
||||||
|
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||||
|
@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_google(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
model: String,
|
||||||
|
) -> google_ai::GenerateContentRequest {
|
||||||
|
google_ai::GenerateContentRequest {
|
||||||
|
model,
|
||||||
|
contents: request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|msg| google_ai::Content {
|
||||||
|
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||||
|
text: msg.string_contents(),
|
||||||
|
})],
|
||||||
|
role: match msg.role {
|
||||||
|
Role::User => google_ai::Role::User,
|
||||||
|
Role::Assistant => google_ai::Role::Model,
|
||||||
|
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
generation_config: Some(google_ai::GenerationConfig {
|
||||||
|
candidate_count: Some(1),
|
||||||
|
stop_sequences: Some(request.stop),
|
||||||
|
max_output_tokens: None,
|
||||||
|
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||||
|
top_p: None,
|
||||||
|
top_k: None,
|
||||||
|
}),
|
||||||
|
safety_settings: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn count_google_tokens(
|
pub fn count_google_tokens(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
|
|
|
@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||||
let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
|
let request = into_mistral(
|
||||||
|
request,
|
||||||
|
self.model.id().to_string(),
|
||||||
|
self.max_output_tokens(),
|
||||||
|
);
|
||||||
let stream = self.stream_completion(request, cx);
|
let stream = self.stream_completion(request, cx);
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
|
@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
|
let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
|
||||||
request.tools = vec![mistral::ToolDefinition::Function {
|
request.tools = vec![mistral::ToolDefinition::Function {
|
||||||
function: mistral::FunctionDefinition {
|
function: mistral::FunctionDefinition {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
|
@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_mistral(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
model: String,
|
||||||
|
max_output_tokens: Option<u32>,
|
||||||
|
) -> mistral::Request {
|
||||||
|
let len = request.messages.len();
|
||||||
|
let merged_messages =
|
||||||
|
request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||||
|
let role = msg.role;
|
||||||
|
let content = msg.string_contents();
|
||||||
|
|
||||||
|
acc.push(match role {
|
||||||
|
Role::User => mistral::RequestMessage::User { content },
|
||||||
|
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||||
|
content: Some(content),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
},
|
||||||
|
Role::System => mistral::RequestMessage::System { content },
|
||||||
|
});
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
|
||||||
|
mistral::Request {
|
||||||
|
model,
|
||||||
|
messages: merged_messages,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: max_output_tokens,
|
||||||
|
temperature: request.temperature,
|
||||||
|
response_format: None,
|
||||||
|
tools: request
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| mistral::ToolDefinition::Function {
|
||||||
|
function: mistral::FunctionDefinition {
|
||||||
|
name: tool.name,
|
||||||
|
description: Some(tool.description),
|
||||||
|
parameters: Some(tool.input_schema),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
api_key_editor: Entity<Editor>,
|
api_key_editor: Entity<Editor>,
|
||||||
state: gpui::Entity<State>,
|
state: gpui::Entity<State>,
|
||||||
|
|
|
@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
'static,
|
'static,
|
||||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||||
> {
|
> {
|
||||||
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
|
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||||
let completions = self.stream_completion(request, cx);
|
let completions = self.stream_completion(request, cx);
|
||||||
async move {
|
async move {
|
||||||
Ok(open_ai::extract_text_from_events(completions.await?)
|
Ok(open_ai::extract_text_from_events(completions.await?)
|
||||||
|
@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
|
let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||||
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
|
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
|
||||||
function: FunctionDefinition {
|
function: FunctionDefinition {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
|
@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_open_ai(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
model: String,
|
||||||
|
max_output_tokens: Option<u32>,
|
||||||
|
) -> open_ai::Request {
|
||||||
|
let stream = !model.starts_with("o1-");
|
||||||
|
open_ai::Request {
|
||||||
|
model,
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|msg| match msg.role {
|
||||||
|
Role::User => open_ai::RequestMessage::User {
|
||||||
|
content: msg.string_contents(),
|
||||||
|
},
|
||||||
|
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||||
|
content: Some(msg.string_contents()),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
},
|
||||||
|
Role::System => open_ai::RequestMessage::System {
|
||||||
|
content: msg.string_contents(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
stream,
|
||||||
|
stop: request.stop,
|
||||||
|
temperature: request.temperature.unwrap_or(1.0),
|
||||||
|
max_tokens: max_output_tokens,
|
||||||
|
tools: Vec::new(),
|
||||||
|
tool_choice: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn count_open_ai_tokens(
|
pub fn count_open_ai_tokens(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
model: open_ai::Model,
|
model: open_ai::Model,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue