diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs
index 4c3fbe878e..2a4a00cf23 100644
--- a/crates/agent/src/active_thread.rs
+++ b/crates/agent/src/active_thread.rs
@@ -1,8 +1,8 @@
 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::context_picker::MentionLink;
 use crate::thread::{
-    LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
-    ThreadEvent, ThreadFeedback,
+    LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
+    ThreadFeedback,
 };
 use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1285,7 +1285,7 @@ impl ActiveThread {
         self.thread.update(cx, |thread, cx| {
             thread.advance_prompt_id();
-            thread.send_to_model(model.model, RequestKind::Chat, cx)
+            thread.send_to_model(model.model, cx)
         });
         cx.notify();
     }
diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs
index 1f067af734..03e13d6f68 100644
--- a/crates/agent/src/assistant.rs
+++ b/crates/agent/src/assistant.rs
@@ -40,7 +40,7 @@
 pub use crate::active_thread::ActiveThread;
 use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
 pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
 pub use crate::inline_assistant::InlineAssistant;
-pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread::{Message, Thread, ThreadEvent};
 pub use crate::thread_store::ThreadStore;
 pub use agent_diff::{AgentDiff, AgentDiffToolbar};
diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs
index d64c95b9b9..ac16df4c97 100644
--- a/crates/agent/src/message_editor.rs
+++ b/crates/agent/src/message_editor.rs
@@ -34,7 +34,7 @@ use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::{ContextStore, refresh_context_store_text};
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::profile_selector::ProfileSelector;
-use crate::thread::{RequestKind, Thread, TokenUsageRatio};
+use crate::thread::{Thread, TokenUsageRatio};
 use crate::thread_store::ThreadStore;
 use crate::{
     AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
@@ -234,7 +234,7 @@ impl MessageEditor {
         }

         self.set_editor_is_expanded(false, cx);
-        self.send_to_model(RequestKind::Chat, window, cx);
+        self.send_to_model(window, cx);

         cx.notify();
     }
@@ -249,12 +249,7 @@ impl MessageEditor {
             .is_some()
     }

-    fn send_to_model(
-        &mut self,
-        request_kind: RequestKind,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
+    fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         let model_registry = LanguageModelRegistry::read_global(cx);
         let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
             return;
@@ -331,7 +326,7 @@ impl MessageEditor {
             thread
                 .update(cx, |thread, cx| {
                     thread.advance_prompt_id();
-                    thread.send_to_model(model, request_kind, cx);
+                    thread.send_to_model(model, cx);
                 })
                 .log_err();
         })
@@ -345,7 +340,7 @@ impl MessageEditor {

         if cancelled {
             self.set_editor_is_expanded(false, cx);
-            self.send_to_model(RequestKind::Chat, window, cx);
+            self.send_to_model(window, cx);
         }
     }
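
The three files above are pure call-site updates: `RequestKind` disappears from the public surface, and chat becomes the only mode `send_to_model` and `to_completion_request` serve. A minimal sketch of this refactor's shape — stand-in types, not Zed's real API — in which an enum parameter that forked behavior internally is replaced by two dedicated builders:

```rust
// Stand-in types to illustrate the refactor's shape; Zed's real Thread and
// LanguageModelRequest carry far more state.
struct Request {
    messages: Vec<String>,
}

struct Thread {
    messages: Vec<String>,
}

impl Thread {
    // Before: to_completion_request(kind: RequestKind, ..) branched on the
    // kind internally. After: one unconditional chat builder...
    fn to_completion_request(&self) -> Request {
        Request {
            messages: self.messages.clone(),
        }
    }

    // ...plus a dedicated summarize builder that owns its own rules.
    fn to_summarize_request(&self, added_user_message: String) -> Request {
        let mut messages = self.messages.clone();
        messages.push(added_user_message);
        Request { messages }
    }
}

fn main() {
    let thread = Thread {
        messages: vec!["user: hello".into()],
    };
    assert_eq!(thread.to_completion_request().messages.len(), 1);
    assert_eq!(thread.to_summarize_request("Summarize.".into()).messages.len(), 2);
}
```
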
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index e8ca584fa0..17f7e20387 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -40,13 +40,6 @@ use crate::thread_store::{
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

-#[derive(Debug, Clone, Copy)]
-pub enum RequestKind {
-    Chat,
-    /// Used when summarizing a thread.
-    Summarize,
-}
-
 #[derive(
     Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
 )]
@@ -949,13 +942,8 @@ impl Thread {
         })
     }

-    pub fn send_to_model(
-        &mut self,
-        model: Arc<dyn LanguageModel>,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) {
-        let mut request = self.to_completion_request(request_kind, cx);
+    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+        let mut request = self.to_completion_request(cx);
         if model.supports_tools() {
             request.tools = {
                 let mut tools = Vec::new();
@@ -994,11 +982,7 @@ impl Thread {
         false
     }

-    pub fn to_completion_request(
-        &self,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) -> LanguageModelRequest {
+    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: Some(self.id.to_string()),
             prompt_id: Some(self.last_prompt_id.to_string()),
@@ -1045,18 +1029,8 @@ impl Thread {
                 cache: false,
             };

-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_results(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                    if self.tool_use.message_has_tool_results(message.id) {
-                        continue;
-                    }
-                }
-            }
+            self.tool_use
+                .attach_tool_results(message.id, &mut request_message);

             if !message.context.is_empty() {
                 request_message
@@ -1089,15 +1063,8 @@ impl Thread {
                 };
             }

-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_uses(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                }
-            };
+            self.tool_use
+                .attach_tool_uses(message.id, &mut request_message);

             request.messages.push(request_message);
         }
@@ -1112,6 +1079,54 @@ impl Thread {
         request
     }

+    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+        let mut request = LanguageModelRequest {
+            thread_id: None,
+            prompt_id: None,
+            messages: vec![],
+            tools: Vec::new(),
+            stop: Vec::new(),
+            temperature: None,
+        };
+
+        for message in &self.messages {
+            let mut request_message = LanguageModelRequestMessage {
+                role: message.role,
+                content: Vec::new(),
+                cache: false,
+            };
+
+            // Skip tool results during summarization.
+            if self.tool_use.message_has_tool_results(message.id) {
+                continue;
+            }
+
+            for segment in &message.segments {
+                match segment {
+                    MessageSegment::Text(text) => request_message
+                        .content
+                        .push(MessageContent::Text(text.clone())),
+                    MessageSegment::Thinking { .. } => {}
+                    MessageSegment::RedactedThinking(_) => {}
+                }
+            }
+
+            if request_message.content.is_empty() {
+                continue;
+            }
+
+            request.messages.push(request_message);
+        }
+
+        request.messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![MessageContent::Text(added_user_message)],
+            cache: false,
+        });
+
+        request
+    }
+
     fn attached_tracked_files_state(
         &self,
         messages: &mut Vec<LanguageModelRequestMessage>,
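
The new `to_summarize_request` above owns all the summarize-specific filtering that `to_completion_request` previously handled through `RequestKind::Summarize`. A condensed sketch of those rules, using hypothetical simplified types (`Segment` and `Message` here are illustrations, not Zed's types): messages carrying tool results are skipped outright, thinking segments are dropped, messages left empty are elided, and the instruction is appended as a final user message.

```rust
// Hypothetical, simplified message model to show the filtering rules.
enum Segment {
    Text(String),
    Thinking(String),
}

struct Message {
    segments: Vec<Segment>,
    has_tool_results: bool,
}

// Keep only textual content; skip tool-result messages entirely, then
// append the summarization instruction as a trailing user message.
fn summarize_payload(history: &[Message], instruction: &str) -> Vec<String> {
    let mut out: Vec<String> = history
        .iter()
        .filter(|m| !m.has_tool_results)
        .filter_map(|m| {
            let text: String = m
                .segments
                .iter()
                .filter_map(|s| match s {
                    Segment::Text(t) => Some(t.as_str()),
                    Segment::Thinking(_) => None, // thinking is dropped
                })
                .collect();
            (!text.is_empty()).then_some(text)
        })
        .collect();
    out.push(instruction.to_string());
    out
}

fn main() {
    let history = vec![
        Message {
            segments: vec![Segment::Text("user: refactor this".into())],
            has_tool_results: false,
        },
        Message {
            segments: vec![Segment::Text("tool output...".into())],
            has_tool_results: true, // dropped entirely
        },
        Message {
            segments: vec![Segment::Thinking("...".into())],
            has_tool_results: false, // becomes empty, so elided
        },
    ];
    let payload = summarize_payload(&history, "Generate a concise title.");
    assert_eq!(payload.len(), 2);
}
```
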
@@ -1293,7 +1308,12 @@ impl Thread {
                     .pending_completions
                     .retain(|completion| completion.id != pending_completion_id);

-                if thread.summary.is_none() && thread.messages.len() >= 2 {
+                // If the model responded without using tools, summarize right away.
+                // Otherwise, wait until a couple of tool-use turns have accumulated.
+                if thread.summary.is_none()
+                    && thread.messages.len() >= 2
+                    && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
+                {
                     thread.summarize(cx);
                 }
             })?;
@@ -1403,18 +1423,12 @@ impl Thread {
             return;
         }

-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
-                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
-                 If the conversation is about a specific subject, include it in the title. \
-                 Be descriptive. DO NOT speak in the first person."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
+            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
+            If the conversation is about a specific subject, include it in the title. \
+            Be descriptive. DO NOT speak in the first person.";
+
+        let request = self.to_summarize_request(added_user_message.into());

         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
@@ -1476,21 +1490,14 @@ impl Thread {
             return None;
         }

-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
+        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
+            1. A brief overview of what was discussed\n\
+            2. Key facts or information discovered\n\
+            3. Outcomes or conclusions reached\n\
+            4. Any action items or next steps if any\n\
+            Format it in Markdown with headings and bullet points.";

-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a detailed summary of this conversation. Include:\n\
-                 1. A brief overview of what was discussed\n\
-                 2. Key facts or information discovered\n\
-                 3. Outcomes or conclusions reached\n\
-                 4. Any action items or next steps if any\n\
-                 Format it in Markdown with headings and bullet points."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let request = self.to_summarize_request(added_user_message.into());

         let task = cx.spawn(async move |thread, cx| {
             let stream = model.stream_completion_text(request, &cx);
@@ -1538,7 +1545,7 @@ impl Thread {
     pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
         self.auto_capture_telemetry(cx);
-        let request = self.to_completion_request(RequestKind::Chat, cx);
+        let request = self.to_completion_request(cx);
         let messages = Arc::new(request.messages);
         let pending_tool_uses = self
             .tool_use
@@ -1650,7 +1657,7 @@ impl Thread {
         if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
             self.attach_tool_results(cx);
             if !canceled {
-                self.send_to_model(model, RequestKind::Chat, cx);
+                self.send_to_model(model, cx);
             }
         }
     }
@@ -2275,9 +2282,7 @@ fn main() {{
         assert_eq!(message.context, expected_context);

         // Check message in request
-        let request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         assert_eq!(request.messages.len(), 2);
         let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2367,9 +2372,7 @@ fn main() {{
         assert!(message3.context.contains("file3.rs"));

         // Check entire request to make sure all contexts are properly included
-        let request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         // The request should contain all 3 messages
         assert_eq!(request.messages.len(), 4);
@@ -2419,9 +2422,7 @@ fn main() {{
         assert_eq!(message.context, "");

         // Check message in request
-        let request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         assert_eq!(request.messages.len(), 2);
         assert_eq!(
@@ -2439,9 +2440,7 @@ fn main() {{
         assert_eq!(message2.context, "");

         // Check that both messages appear in the request
-        let request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         assert_eq!(request.messages.len(), 3);
         assert_eq!(
@@ -2481,9 +2480,7 @@ fn main() {{
         });

         // Create a request and check that it doesn't have a stale buffer warning yet
-        let initial_request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         // Make sure we don't have a stale file warning yet
         let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2511,9 +2508,7 @@ fn main() {{
         });

         // Create a new request and check for the stale buffer warning
-        let new_request = thread.update(cx, |thread, cx| {
-            thread.to_completion_request(RequestKind::Chat, cx)
-        });
+        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

         // We should have a stale file warning as the last message
         let last_message = new_request
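
The summarize trigger added in the first hunk above reads well as a standalone predicate. This sketch restates it with the thread state reduced to the inputs the condition actually consumes; the six-message threshold is the diff's proxy for roughly two tool-use round trips:

```rust
// The trigger from the hunk above, reduced to the values it actually reads.
fn should_summarize(
    summary_exists: bool,
    message_count: usize,
    has_pending_tool_uses: bool,
) -> bool {
    // Summarize after the first exchange when the model answered without
    // tools; tool-heavy threads wait until ~two tool round trips (6 messages).
    !summary_exists && message_count >= 2 && (!has_pending_tool_uses || message_count >= 6)
}

fn main() {
    assert!(should_summarize(false, 2, false)); // plain answer: title now
    assert!(!should_summarize(false, 2, true)); // tools in flight: wait
    assert!(should_summarize(false, 6, true)); // enough turns: title anyway
    assert!(!should_summarize(true, 10, false)); // already summarized
}
```
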
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 684feaca3b..e7edb0e086 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -74,6 +74,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         if id.starts_with("claude-3-5-sonnet") {
             Ok(Self::Claude3_5Sonnet)
diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs
index 052e5c2ca1..8ead77f9c4 100644
--- a/crates/bedrock/src/models.rs
+++ b/crates/bedrock/src/models.rs
@@ -84,6 +84,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> anyhow::Result<Self> {
         if id.starts_with("claude-3-5-sonnet-v2") {
             Ok(Self::Claude3_5SonnetV2)
diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 255c39cb84..2bcb82c1ee 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -61,6 +61,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_7Sonnet
+    }
+
     pub fn uses_streaming(&self) -> bool {
         match self {
             Self::Gpt4o
diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs
index 07f6a959e1..9c19f1ae2f 100644
--- a/crates/deepseek/src/deepseek.rs
+++ b/crates/deepseek/src/deepseek.rs
@@ -64,6 +64,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Model::Chat
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "deepseek-chat" => Ok(Self::Chat),
diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs
index 982daeaed7..78b7eb7af9 100644
--- a/crates/eval/src/example.rs
+++ b/crates/eval/src/example.rs
@@ -1,4 +1,4 @@
-use agent::{RequestKind, ThreadEvent, ThreadStore};
+use agent::{ThreadEvent, ThreadStore};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
@@ -472,7 +472,7 @@ impl Example {
         thread.update(cx, |thread, cx| {
             let context = vec![];
             thread.insert_user_message(this.prompt.clone(), context, None, cx);
-            thread.send_to_model(model, RequestKind::Chat, cx);
+            thread.send_to_model(model, cx);
         })?;

         event_handler_task.await?;
diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index 09278d6ed2..e26750936d 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -412,6 +412,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Model {
+        Model::Gemini15Flash
+    }
+
     pub fn id(&self) -> &str {
         match self {
             Model::Gemini15Pro => "gemini-1.5-pro",
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index 56df184d36..25f2a496e7 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -49,6 +49,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         Some(Arc::new(FakeLanguageModel::default()))
     }

+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(Arc::new(FakeLanguageModel::default()))
+    }
+
     fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
         vec![Arc::new(FakeLanguageModel::default())]
     }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 206958e82f..71d8551bd5 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -370,6 +370,7 @@ pub trait LanguageModelProvider: 'static {
         IconName::ZedAssistant
     }
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         Vec::new()
     }
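
`default_fast_model` is added as a required method on `LanguageModelProvider`, so every provider must answer it — even if only by delegating to `default_model`, as the LM Studio and Ollama providers do further below. A minimal conforming implementation under assumed, simplified trait shapes (the real trait also carries icons, auth state, `provided_models`, and more):

```rust
// Minimal shape of the new requirement, with stand-in types.
use std::sync::Arc;

trait LanguageModel {}

struct App; // stand-in for gpui::App

trait LanguageModelProvider {
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
}

struct LocalProvider;
struct LocalModel;
impl LanguageModel for LocalModel {}

impl LanguageModelProvider for LocalProvider {
    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(Arc::new(LocalModel))
    }

    // Local providers have no separate cheap tier, so the fast model is
    // just the default model — the pattern lmstudio.rs and ollama.rs use.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }
}

fn main() {
    let provider = LocalProvider;
    assert!(provider.default_fast_model(&App).is_some());
}
```
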
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index 45be22457f..62f216094b 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -5,6 +5,7 @@ use crate::{
 };
 use collections::BTreeMap;
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
 use std::sync::Arc;
+use util::maybe;

 pub fn init(cx: &mut App) {
     let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -18,6 +19,7 @@ impl Global for GlobalLanguageModelRegistry {}
 #[derive(Default)]
 pub struct LanguageModelRegistry {
     default_model: Option<ConfiguredModel>,
+    default_fast_model: Option<ConfiguredModel>,
     inline_assistant_model: Option<ConfiguredModel>,
     commit_message_model: Option<ConfiguredModel>,
     thread_summary_model: Option<ConfiguredModel>,
@@ -202,6 +204,14 @@ impl LanguageModelRegistry {
             (None, None) => {}
             _ => cx.emit(Event::DefaultModelChanged),
         }

+        self.default_fast_model = maybe!({
+            let provider = &model.as_ref()?.provider;
+            let fast_model = provider.default_fast_model(cx)?;
+            Some(ConfiguredModel {
+                provider: provider.clone(),
+                model: fast_model,
+            })
+        });
         self.default_model = model;
     }
@@ -254,21 +264,37 @@ impl LanguageModelRegistry {
     }

     pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.inline_assistant_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }

     pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.commit_message_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }

     pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.thread_summary_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_fast_model.clone())
+            .or_else(|| self.default_model.clone())
     }

     /// The models to use for inline assists. Returns the union of the active
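
The registry now caches a fast-tier `ConfiguredModel` whenever the default model changes, and `thread_summary_model` resolves through an explicit-then-fast-then-default fallback chain. A sketch of that chain with stand-in fields (`ConfiguredModel` here is a label, not the real provider/model pair):

```rust
// Stand-ins for the registry fields; the real ConfiguredModel pairs a
// provider with a model.
#[derive(Clone)]
struct ConfiguredModel(&'static str);

struct Registry {
    thread_summary_model: Option<ConfiguredModel>,
    default_fast_model: Option<ConfiguredModel>,
    default_model: Option<ConfiguredModel>,
}

impl Registry {
    // Explicit choice wins, then the cheap default, then the main default.
    fn thread_summary_model(&self) -> Option<ConfiguredModel> {
        self.thread_summary_model
            .clone()
            .or_else(|| self.default_fast_model.clone())
            .or_else(|| self.default_model.clone())
    }
}

fn main() {
    let registry = Registry {
        thread_summary_model: None,
        default_fast_model: Some(ConfiguredModel("claude-3-5-haiku")),
        default_model: Some(ConfiguredModel("claude-3-7-sonnet")),
    };
    // Summaries route to the fast tier when no explicit model is set.
    assert_eq!(registry.thread_summary_model().unwrap().0, "claude-3-5-haiku");
}
```
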
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 6a29976504..f998969bfe 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -201,7 +201,7 @@ impl AnthropicLanguageModelProvider {
             state: self.state.clone(),
             http_client: self.http_client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -227,14 +227,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = anthropic::Model::default();
-        Some(Arc::new(AnthropicModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(anthropic::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(anthropic::Model::default_fast()))
     }

     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index c4ef48404f..a2748b45be 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -286,6 +286,18 @@ impl BedrockLanguageModelProvider {
             state,
         }
     }
+
+    fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(BedrockModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            http_client: self.http_client.clone(),
+            handler: self.handler.clone(),
+            state: self.state.clone(),
+            client: OnceCell::new(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProvider for BedrockLanguageModelProvider {
@@ -302,16 +314,11 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = bedrock::Model::default();
-        Some(Arc::new(BedrockModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            http_client: self.http_client.clone(),
-            handler: self.handler.clone(),
-            state: self.state.clone(),
-            client: OnceCell::new(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(bedrock::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(bedrock::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -343,17 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {

         models
             .into_values()
-            .map(|model| {
-                Arc::new(BedrockModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    http_client: self.http_client.clone(),
-                    handler: self.handler.clone(),
-                    state: self.state.clone(),
-                    client: OnceCell::new(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 80c8d0dcc3..b9911d5d46 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -242,7 +242,7 @@ impl CloudLanguageModelProvider {
             llm_api_token: llm_api_token.clone(),
             client: self.client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -270,13 +270,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
         let llm_api_token = self.state.read(cx).llm_api_token.clone();
         let model = CloudModel::Anthropic(anthropic::Model::default());
-        Some(Arc::new(CloudLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            llm_api_token: llm_api_token.clone(),
-            client: self.client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(model, llm_api_token))
+    }
+
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
+        let model = CloudModel::Anthropic(anthropic::Model::default_fast());
+        Some(self.create_language_model(model, llm_api_token))
     }

     fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
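
Across the provider crates above and below, the repeated struct-literal construction collapses into a `create_language_model` factory, so `default_model`, `default_fast_model`, and `provided_models` share one construction path. A generic sketch of the pattern with hypothetical stand-in types (`MyModel`, `MyProvider` are illustrations, not Zed's types):

```rust
use std::sync::Arc;

trait LanguageModel {
    fn id(&self) -> &str;
}

// Stand-in for a provider's concrete model wrapper (AnthropicModel,
// BedrockModel, etc. in the real diff).
struct MyModel {
    id: String,
}

impl LanguageModel for MyModel {
    fn id(&self) -> &str {
        &self.id
    }
}

struct MyProvider;

impl MyProvider {
    // One factory replaces three copies of the same struct literal, so the
    // rate limiter, clients, and state are configured in exactly one place.
    fn create_language_model(&self, id: &str) -> Arc<dyn LanguageModel> {
        Arc::new(MyModel { id: id.to_string() })
    }

    fn default_model(&self) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model("default"))
    }

    fn default_fast_model(&self) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model("fast"))
    }

    fn provided_models(&self, ids: &[&str]) -> Vec<Arc<dyn LanguageModel>> {
        ids.iter().map(|id| self.create_language_model(id)).collect()
    }
}

fn main() {
    let provider = MyProvider;
    assert_eq!(provider.default_fast_model().unwrap().id(), "fast");
    assert_eq!(provider.provided_models(&["a", "b"]).len(), 2);
}
```
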
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 3d4924b890..255de2d536 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -70,6 +70,13 @@ impl CopilotChatLanguageModelProvider {
         Self { state }
     }
+
+    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
+        Arc::new(CopilotChatLanguageModel {
+            model,
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
@@ -94,21 +101,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = CopilotChatModel::default();
-        Some(Arc::new(CopilotChatLanguageModel {
-            model,
-            request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>)
+        Some(self.create_language_model(CopilotChatModel::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(CopilotChatModel::default_fast()))
     }

     fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| {
-                Arc::new(CopilotChatLanguageModel {
-                    model,
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index e4f1cd830a..9989e4c6b1 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -140,6 +140,16 @@ impl DeepSeekLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(DeepSeekLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>
+    }
 }

 impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
@@ -164,14 +174,11 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = deepseek::Model::Chat;
-        Some(Arc::new(DeepSeekLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(deepseek::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(deepseek::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -198,15 +205,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {

         models
             .into_values()
-            .map(|model| {
-                Arc::new(DeepSeekLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index c754e63bbd..bbe6c58353 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -150,6 +150,16 @@ impl GoogleLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(GoogleLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -174,14 +184,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = google_ai::Model::default();
-        Some(Arc::new(GoogleLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(google_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(google_ai::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index bd8b6303f8..2f5ae9ebb6 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -157,6 +157,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }

+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index b9017398e6..a5009c76a6 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -144,6 +144,16 @@ impl MistralLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(MistralLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -168,14 +178,11 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = mistral::Model::default();
-        Some(Arc::new(MistralLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(mistral::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(mistral::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 465dc0d659..17c50c8eaf 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -162,6 +162,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }

+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 020c642520..188a219e2d 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -148,6 +148,16 @@ impl OpenAiLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(OpenAiLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }

 impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -172,14 +182,11 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
     }

     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = open_ai::Model::default();
-        Some(Arc::new(OpenAiLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(open_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(open_ai::Model::default_fast()))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -211,15 +218,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {

         models
             .into_values()
-            .map(|model| {
-                Arc::new(OpenAiLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index de2457e0bf..27de5fccc2 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -69,6 +69,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Model::MistralSmallLatest
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "codestral-latest" => Ok(Self::CodestralLatest),
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 9284b4a9b2..f9e0b7d4e3 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -102,6 +102,10 @@ pub enum Model {
 }

 impl Model {
+    pub fn default_fast() -> Self {
+        Self::FourPointOneMini
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),