diff --git a/Cargo.lock b/Cargo.lock
index 9cade7daf8..0e0c09c213 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7015,16 +7015,12 @@ dependencies = [
  "anyhow",
  "base64 0.22.1",
  "collections",
- "deepseek",
  "futures 0.3.31",
  "google_ai",
  "gpui",
  "http_client",
  "image",
- "lmstudio",
  "log",
- "mistral",
- "ollama",
  "open_ai",
  "parking_lot",
  "proto",
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 51f205dced..091ed524c8 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 base64.workspace = true
 collections.workspace = true
-deepseek = { workspace = true, features = ["schemars"] }
 futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 image.workspace = true
-lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
-mistral = { workspace = true, features = ["schemars"] }
-ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
 proto.workspace = true
diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs
index 12aaed3ab2..db4c55daa7 100644
--- a/crates/language_model/src/model/mod.rs
+++ b/crates/language_model/src/model/mod.rs
@@ -1,7 +1,3 @@
 pub mod cloud_model;
 
-pub use anthropic::Model as AnthropicModel;
 pub use cloud_model::*;
-pub use lmstudio::Model as LmStudioModel;
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 507e8b4207..5f11ddffd6 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
     pub temperature: Option<f32>,
 }
 
-impl LanguageModelRequest {
-    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
-        let stream = !model.starts_with("o1-");
-        open_ai::Request {
-            model,
-            messages: self
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => open_ai::RequestMessage::User {
-                        content: msg.string_contents(),
-                    },
-                    Role::Assistant => open_ai::RequestMessage::Assistant {
-                        content: Some(msg.string_contents()),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => open_ai::RequestMessage::System {
-                        content: msg.string_contents(),
-                    },
-                })
-                .collect(),
-            stream,
-            stop: self.stop,
-            temperature: self.temperature.unwrap_or(1.0),
-            max_tokens: max_output_tokens,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
-    }
-
-    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
-        let len = self.messages.len();
-        let merged_messages =
-            self.messages
-                .into_iter()
-                .fold(Vec::with_capacity(len), |mut acc, msg| {
-                    let role = msg.role;
-                    let content = msg.string_contents();
-
-                    acc.push(match role {
-                        Role::User => mistral::RequestMessage::User { content },
-                        Role::Assistant => mistral::RequestMessage::Assistant {
-                            content: Some(content),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => mistral::RequestMessage::System { content },
-                    });
-                    acc
-                });
-
-        mistral::Request {
-            model,
-            messages: merged_messages,
-            stream: true,
-            max_tokens: max_output_tokens,
-            temperature: self.temperature,
-            response_format: None,
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| mistral::ToolDefinition::Function {
-                    function: mistral::FunctionDefinition {
-                        name: tool.name,
-                        description: Some(tool.description),
-                        parameters: Some(tool.input_schema),
-                    },
-                })
-                .collect(),
-        }
-    }
-
-    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
-        google_ai::GenerateContentRequest {
-            model,
-            contents: self
-                .messages
-                .into_iter()
-                .map(|msg| google_ai::Content {
-                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
-                        text: msg.string_contents(),
-                    })],
-                    role: match msg.role {
-                        Role::User => google_ai::Role::User,
-                        Role::Assistant => google_ai::Role::Model,
-                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
-                    },
-                })
-                .collect(),
-            generation_config: Some(google_ai::GenerationConfig {
-                candidate_count: Some(1),
-                stop_sequences: Some(self.stop),
-                max_output_tokens: None,
-                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
-                top_p: None,
-                top_k: None,
-            }),
-            safety_settings: None,
-        }
-    }
-
-    pub fn into_anthropic(
-        self,
-        model: String,
-        default_temperature: f32,
-        max_output_tokens: u32,
-    ) -> anthropic::Request {
-        let mut new_messages: Vec<anthropic::Message> = Vec::new();
-        let mut system_message = String::new();
-
-        for message in self.messages {
-            if message.contents_empty() {
-                continue;
-            }
-
-            match message.role {
-                Role::User | Role::Assistant => {
-                    let cache_control = if message.cache {
-                        Some(anthropic::CacheControl {
-                            cache_type: anthropic::CacheControlType::Ephemeral,
-                        })
-                    } else {
-                        None
-                    };
-                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
-                        .content
-                        .into_iter()
-                        .filter_map(|content| match content {
-                            MessageContent::Text(text) => {
-                                if !text.is_empty() {
-                                    Some(anthropic::RequestContent::Text {
-                                        text,
-                                        cache_control,
-                                    })
-                                } else {
-                                    None
-                                }
-                            }
-                            MessageContent::Image(image) => {
-                                Some(anthropic::RequestContent::Image {
-                                    source: anthropic::ImageSource {
-                                        source_type: "base64".to_string(),
-                                        media_type: "image/png".to_string(),
-                                        data: image.source.to_string(),
-                                    },
-                                    cache_control,
-                                })
-                            }
-                            MessageContent::ToolUse(tool_use) => {
-                                Some(anthropic::RequestContent::ToolUse {
-                                    id: tool_use.id.to_string(),
-                                    name: tool_use.name,
-                                    input: tool_use.input,
-                                    cache_control,
-                                })
-                            }
-                            MessageContent::ToolResult(tool_result) => {
-                                Some(anthropic::RequestContent::ToolResult {
-                                    tool_use_id: tool_result.tool_use_id,
-                                    is_error: tool_result.is_error,
-                                    content: tool_result.content,
-                                    cache_control,
-                                })
-                            }
-                        })
-                        .collect();
-                    let anthropic_role = match message.role {
-                        Role::User => anthropic::Role::User,
-                        Role::Assistant => anthropic::Role::Assistant,
-                        Role::System => unreachable!("System role should never occur here"),
-                    };
-                    if let Some(last_message) = new_messages.last_mut() {
-                        if last_message.role == anthropic_role {
-                            last_message.content.extend(anthropic_message_content);
-                            continue;
-                        }
-                    }
-                    new_messages.push(anthropic::Message {
-                        role: anthropic_role,
-                        content: anthropic_message_content,
-                    });
-                }
-                Role::System => {
-                    if !system_message.is_empty() {
-                        system_message.push_str("\n\n");
-                    }
-                    system_message.push_str(&message.string_contents());
-                }
-            }
-        }
-
-        anthropic::Request {
-            model,
-            messages: new_messages,
-            max_tokens: max_output_tokens,
-            system: Some(system_message),
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| anthropic::Tool {
-                    name: tool.name,
-                    description: tool.description,
-                    input_schema: tool.input_schema,
-                })
-                .collect(),
-            tool_choice: None,
-            metadata: None,
-            stop_sequences: Vec::new(),
-            temperature: self.temperature.or(Some(default_temperature)),
-            top_k: None,
-            top_p: None,
-        }
-    }
-
-    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
-        let is_reasoner = model == "deepseek-reasoner";
-
-        let len = self.messages.len();
-        let merged_messages =
-            self.messages
-                .into_iter()
-                .fold(Vec::with_capacity(len), |mut acc, msg| {
-                    let role = msg.role;
-                    let content = msg.string_contents();
-
-                    if is_reasoner {
-                        if let Some(last_msg) = acc.last_mut() {
-                            match (last_msg, role) {
-                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
-                                    last.push(' ');
-                                    last.push_str(&content);
-                                    return acc;
-                                }
-
-                                (
-                                    deepseek::RequestMessage::Assistant {
-                                        content: last_content,
-                                        ..
-                                    },
-                                    Role::Assistant,
-                                ) => {
-                                    *last_content = last_content
-                                        .take()
-                                        .map(|c| {
-                                            let mut s =
-                                                String::with_capacity(c.len() + content.len() + 1);
-                                            s.push_str(&c);
-                                            s.push(' ');
-                                            s.push_str(&content);
-                                            s
-                                        })
-                                        .or(Some(content));
-
-                                    return acc;
-                                }
-                                _ => {}
-                            }
-                        }
-                    }
-
-                    acc.push(match role {
-                        Role::User => deepseek::RequestMessage::User { content },
-                        Role::Assistant => deepseek::RequestMessage::Assistant {
-                            content: Some(content),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => deepseek::RequestMessage::System { content },
-                    });
-                    acc
-                });
-
-        deepseek::Request {
-            model,
-            messages: merged_messages,
-            stream: true,
-            max_tokens: max_output_tokens,
-            temperature: if is_reasoner { None } else { self.temperature },
-            response_format: None,
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| deepseek::ToolDefinition::Function {
-                    function: deepseek::FunctionDefinition {
-                        name: tool.name,
-                        description: Some(tool.description),
-                        parameters: Some(tool.input_schema),
-                    },
-                })
-                .collect(),
-        }
-    }
-}
-
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct LanguageModelResponseMessage {
     pub role: Option<Role>,
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
index fa56a2a88b..953dfa6fdf 100644
--- a/crates/language_model/src/role.rs
+++ b/crates/language_model/src/role.rs
@@ -45,43 +45,3 @@ impl Display for Role {
         }
     }
 }
-
-impl From<Role> for ollama::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => ollama::Role::User,
-            Role::Assistant => ollama::Role::Assistant,
-            Role::System => ollama::Role::System,
-        }
-    }
-}
-
-impl From<Role> for open_ai::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => open_ai::Role::User,
-            Role::Assistant => open_ai::Role::Assistant,
-            Role::System => open_ai::Role::System,
-        }
-    }
-}
-
-impl From<Role> for deepseek::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => deepseek::Role::User,
-            Role::Assistant => deepseek::Role::Assistant,
-            Role::System => deepseek::Role::System,
-        }
-    }
-}
-
-impl From<Role> for lmstudio::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => lmstudio::Role::User,
-            Role::Assistant => lmstudio::Role::Assistant,
-            Role::System => lmstudio::Role::System,
-        }
-    }
-}
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 9908929457..c3ec14f46e 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -13,7 +13,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_anthropic(
+        let request = into_anthropic(
+            request,
             self.model.id().into(),
             self.model.default_temperature(),
             self.model.max_output_tokens(),
@@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
         input_schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_anthropic(
+        let mut request = into_anthropic(
+            request,
             self.model.tool_model_id().into(),
             self.model.default_temperature(),
             self.model.max_output_tokens(),
@@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
     }
 }
 
+pub fn into_anthropic(
+    request: LanguageModelRequest,
+    model: String,
+    default_temperature: f32,
+    max_output_tokens: u32,
+) -> anthropic::Request {
+    let mut new_messages: Vec<anthropic::Message> = Vec::new();
+    let mut system_message = String::new();
+
+    for message in request.messages {
+        if message.contents_empty() {
+            continue;
+        }
+
+        match message.role {
+            Role::User | Role::Assistant => {
+                let cache_control = if message.cache {
+                    Some(anthropic::CacheControl {
+                        cache_type: anthropic::CacheControlType::Ephemeral,
+                    })
+                } else {
+                    None
+                };
+                let anthropic_message_content: Vec<anthropic::RequestContent> = message
+                    .content
+                    .into_iter()
+                    .filter_map(|content| match content {
+                        MessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                Some(anthropic::RequestContent::Text {
+                                    text,
+                                    cache_control,
+                                })
+                            } else {
+                                None
+                            }
+                        }
+                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+                            source: anthropic::ImageSource {
+                                source_type: "base64".to_string(),
+                                media_type: "image/png".to_string(),
+                                data: image.source.to_string(),
+                            },
+                            cache_control,
+                        }),
+                        MessageContent::ToolUse(tool_use) => {
+                            Some(anthropic::RequestContent::ToolUse {
+                                id: tool_use.id.to_string(),
+                                name: tool_use.name,
+                                input: tool_use.input,
+                                cache_control,
+                            })
+                        }
+                        MessageContent::ToolResult(tool_result) => {
+                            Some(anthropic::RequestContent::ToolResult {
+                                tool_use_id: tool_result.tool_use_id,
+                                is_error: tool_result.is_error,
+                                content: tool_result.content,
+                                cache_control,
+                            })
+                        }
+                    })
+                    .collect();
+                let anthropic_role = match message.role {
+                    Role::User => anthropic::Role::User,
+                    Role::Assistant => anthropic::Role::Assistant,
+                    Role::System => unreachable!("System role should never occur here"),
+                };
+                if let Some(last_message) = new_messages.last_mut() {
+                    if last_message.role == anthropic_role {
+                        last_message.content.extend(anthropic_message_content);
+                        continue;
+                    }
+                }
+                new_messages.push(anthropic::Message {
+                    role: anthropic_role,
+                    content: anthropic_message_content,
+                });
+            }
+            Role::System => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.string_contents());
+            }
+        }
+    }
+
+    anthropic::Request {
+        model,
+        messages: new_messages,
+        max_tokens: max_output_tokens,
+        system: Some(system_message),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| anthropic::Tool {
+                name: tool.name,
+                description: tool.description,
+                input_schema: tool.input_schema,
+            })
+            .collect(),
+        tool_choice: None,
+        metadata: None,
+        stop_sequences: Vec::new(),
+        temperature: request.temperature.or(Some(default_temperature)),
+        top_k: None,
+        top_p: None,
+    }
+}
+
 pub fn map_to_language_model_completion_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, AnthropicError>> {
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 236b78527b..cbdf1785e0 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -1,4 +1,3 @@
-use super::open_ai::count_open_ai_tokens;
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
 use client::{
@@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{prelude::*, TintColor};
 
-use crate::provider::anthropic::map_to_language_model_completion_events;
+use crate::provider::anthropic::{
+    count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
+};
+use crate::provider::google::into_google;
+use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
 use crate::AllLanguageModelSettings;
 
-use super::anthropic::count_anthropic_tokens;
-
 pub const PROVIDER_NAME: &str = "Zed";
 
 const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
@@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
             CloudModel::Google(model) => {
                 let client = self.client.clone();
-                let request = request.into_google(model.id().into());
+                let request = into_google(request, model.id().into());
                 let request = google_ai::CountTokensRequest {
                     contents: request.contents,
                 };
@@ -638,7 +639,8 @@
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         match &self.model {
            CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(
+                let request = into_anthropic(
+                    request,
                     model.id().into(),
                     model.default_temperature(),
                     model.max_output_tokens(),
@@ -666,7 +668,7 @@
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
+                let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -693,7 +695,7 @@
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
-                let request = request.into_google(model.id().into());
+                let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -736,7 +738,8 @@
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request = request.into_anthropic(
+                let mut request = into_anthropic(
+                    request,
                     model.tool_model_id().into(),
                     model.default_temperature(),
                     model.max_output_tokens(),
@@ -776,7 +779,7 @@
             }
             CloudModel::OpenAi(model) => {
                 let mut request =
-                    request.into_open_ai(model.id().into(), model.max_output_tokens());
+                    into_open_ai(request, model.id().into(), model.max_output_tokens());
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index 830e94ecb5..84d34307cb 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+        let request = into_deepseek(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
         let stream = self.stream_completion(request, cx);
 
         async move {
@@ -357,8 +361,11 @@
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut deepseek_request =
-            request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+        let mut deepseek_request = into_deepseek(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
 
         deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
             function: deepseek::FunctionDefinition {
@@ -402,6 +409,93 @@
     }
 }
 
+pub fn into_deepseek(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> deepseek::Request {
+    let is_reasoner = model == "deepseek-reasoner";
+
+    let len = request.messages.len();
+    let merged_messages =
+        request
+            .messages
+            .into_iter()
+            .fold(Vec::with_capacity(len), |mut acc, msg| {
+                let role = msg.role;
+                let content = msg.string_contents();
+
+                if is_reasoner {
+                    if let Some(last_msg) = acc.last_mut() {
+                        match (last_msg, role) {
+                            (deepseek::RequestMessage::User { content: last }, Role::User) => {
+                                last.push(' ');
+                                last.push_str(&content);
+                                return acc;
+                            }
+
+                            (
+                                deepseek::RequestMessage::Assistant {
+                                    content: last_content,
+                                    ..
+                                },
+                                Role::Assistant,
+                            ) => {
+                                *last_content = last_content
+                                    .take()
+                                    .map(|c| {
+                                        let mut s =
+                                            String::with_capacity(c.len() + content.len() + 1);
+                                        s.push_str(&c);
+                                        s.push(' ');
+                                        s.push_str(&content);
+                                        s
+                                    })
+                                    .or(Some(content));
+
+                                return acc;
+                            }
+                            _ => {}
+                        }
+                    }
+                }
+
+                acc.push(match role {
+                    Role::User => deepseek::RequestMessage::User { content },
+                    Role::Assistant => deepseek::RequestMessage::Assistant {
+                        content: Some(content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => deepseek::RequestMessage::System { content },
+                });
+                acc
+            });
+
+    deepseek::Request {
+        model,
+        messages: merged_messages,
+        stream: true,
+        max_tokens: max_output_tokens,
+        temperature: if is_reasoner {
+            None
+        } else {
+            request.temperature
+        },
+        response_format: None,
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| deepseek::ToolDefinition::Function {
+                function: deepseek::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: Entity<State>,
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 0bf5001f79..934a06af55 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
-        let request = request.into_google(self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string());
         let http_client = self.http_client.clone();
         let api_key = self.state.read(cx).api_key.clone();
 
@@ -303,7 +303,7 @@
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = request.into_google(self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string());
         let http_client = self.http_client.clone();
 
         let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
@@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
     }
 }
 
+pub fn into_google(
+    request: LanguageModelRequest,
+    model: String,
+) -> google_ai::GenerateContentRequest {
+    google_ai::GenerateContentRequest {
+        model,
+        contents: request
+            .messages
+            .into_iter()
+            .map(|msg| google_ai::Content {
+                parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+                    text: msg.string_contents(),
+                })],
+                role: match msg.role {
+                    Role::User => google_ai::Role::User,
+                    Role::Assistant => google_ai::Role::Model,
+                    Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+                },
+            })
+            .collect(),
+        generation_config: Some(google_ai::GenerationConfig {
+            candidate_count: Some(1),
+            stop_sequences: Some(request.stop),
+            max_output_tokens: None,
+            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+            top_p: None,
+            top_k: None,
+        }),
+        safety_settings: None,
+    }
+}
+
 pub fn count_google_tokens(
     request: LanguageModelRequest,
     cx: &App,
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 80a5988cff..55a6413ef6 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
+        let request = into_mistral(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
         let stream = self.stream_completion(request, cx);
 
         async move {
@@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
+        let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
         request.tools = vec![mistral::ToolDefinition::Function {
             function: mistral::FunctionDefinition {
                 name: tool_name.clone(),
@@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
     }
 }
 
+pub fn into_mistral(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> mistral::Request {
+    let len = request.messages.len();
+    let merged_messages =
+        request
+            .messages
+            .into_iter()
+            .fold(Vec::with_capacity(len), |mut acc, msg| {
+                let role = msg.role;
+                let content = msg.string_contents();
+
+                acc.push(match role {
+                    Role::User => mistral::RequestMessage::User { content },
+                    Role::Assistant => mistral::RequestMessage::Assistant {
+                        content: Some(content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => mistral::RequestMessage::System { content },
+                });
+                acc
+            });
+
+    mistral::Request {
+        model,
+        messages: merged_messages,
+        stream: true,
+        max_tokens: max_output_tokens,
+        temperature: request.temperature,
+        response_format: None,
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| mistral::ToolDefinition::Function {
+                function: mistral::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 3e46983ebb..c249af0bb7 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+        let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move {
             Ok(open_ai::extract_text_from_events(completions.await?)
@@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+        let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
             function: FunctionDefinition {
                 name: tool_name.clone(),
@@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 }
 
+pub fn into_open_ai(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> open_ai::Request {
+    let stream = !model.starts_with("o1-");
+    open_ai::Request {
+        model,
+        messages: request
+            .messages
+            .into_iter()
+            .map(|msg| match msg.role {
+                Role::User => open_ai::RequestMessage::User {
+                    content: msg.string_contents(),
+                },
+                Role::Assistant => open_ai::RequestMessage::Assistant {
+                    content: Some(msg.string_contents()),
+                    tool_calls: Vec::new(),
+                },
+                Role::System => open_ai::RequestMessage::System {
+                    content: msg.string_contents(),
+                },
+            })
+            .collect(),
+        stream,
+        stop: request.stop,
+        temperature: request.temperature.unwrap_or(1.0),
+        max_tokens: max_output_tokens,
+        tools: Vec::new(),
+        tool_choice: None,
+    }
+}
+
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,