diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4005f27a0c..4995ddb9df 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -33,13 +33,23 @@ pub struct UserMessage { pub id: Option, pub content: ContentBlock, pub chunks: Vec, - pub checkpoint: Option, + pub checkpoint: Option, +} + +#[derive(Debug)] +pub struct Checkpoint { + git_checkpoint: GitStoreCheckpoint, + pub show: bool, } impl UserMessage { fn to_markdown(&self, cx: &App) -> String { let mut markdown = String::new(); - if let Some(_) = self.checkpoint { + if self + .checkpoint + .as_ref() + .map_or(false, |checkpoint| checkpoint.show) + { writeln!(markdown, "## User (checkpoint)").unwrap(); } else { writeln!(markdown, "## User").unwrap(); @@ -1145,9 +1155,12 @@ impl AcpThread { self.project.read(cx).languages().clone(), cx, ); + let request = acp::PromptRequest { + prompt: message.clone(), + session_id: self.session_id.clone(), + }; let git_store = self.project.read(cx).git_store().clone(); - let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); let message_id = if self .connection .session_editor(&self.session_id, cx) @@ -1161,68 +1174,63 @@ impl AcpThread { AgentThreadEntry::UserMessage(UserMessage { id: message_id.clone(), content: block, - chunks: message.clone(), + chunks: message, checkpoint: None, }), cx, ); + + self.run_turn(cx, async move |this, cx| { + let old_checkpoint = git_store + .update(cx, |git, cx| git.checkpoint(cx))? + .await + .context("failed to get old checkpoint") + .log_err(); + this.update(cx, |this, cx| { + if let Some((_ix, message)) = this.last_user_message() { + message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint { + git_checkpoint, + show: false, + }); + } + this.connection.prompt(message_id, request, cx) + })? + .await + }) + } + + pub fn resume(&mut self, cx: &mut Context) -> BoxFuture<'static, Result<()>> { + self.run_turn(cx, async move |this, cx| { + this.update(cx, |this, cx| { + this.connection + .resume(&this.session_id, cx) + .map(|resume| resume.run(cx)) + })? + .context("resuming a session is not supported")? + .await + }) + } + + fn run_turn( + &mut self, + cx: &mut Context, + f: impl 'static + AsyncFnOnce(WeakEntity, &mut AsyncApp) -> Result, + ) -> BoxFuture<'static, Result<()>> { self.clear_completed_plan_entries(cx); - let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel(); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); - let request = acp::PromptRequest { - prompt: message, - session_id: self.session_id.clone(), - }; - self.send_task = Some(cx.spawn({ - let message_id = message_id.clone(); - async move |this, cx| { - cancel_task.await; - - old_checkpoint_tx.send(old_checkpoint.await).ok(); - if let Ok(result) = this.update(cx, |this, cx| { - this.connection.prompt(message_id, request, cx) - }) { - tx.send(result.await).log_err(); - } - } + self.send_task = Some(cx.spawn(async move |this, cx| { + cancel_task.await; + tx.send(f(this, cx).await).ok(); })); cx.spawn(async move |this, cx| { - let old_checkpoint = old_checkpoint_rx - .await - .map_err(|_| anyhow!("send canceled")) - .flatten() - .context("failed to get old checkpoint") - .log_err(); - let response = rx.await; - if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) { - let new_checkpoint = git_store - .update(cx, |git, cx| git.checkpoint(cx))? - .await - .context("failed to get new checkpoint") - .log_err(); - if let Some(new_checkpoint) = new_checkpoint { - let equal = git_store - .update(cx, |git, cx| { - git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) - })? - .await - .unwrap_or(true); - if !equal { - this.update(cx, |this, cx| { - if let Some((ix, message)) = this.user_message_mut(&message_id) { - message.checkpoint = Some(old_checkpoint); - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - } - })?; - } - } - } + this.update(cx, |this, cx| this.update_last_checkpoint(cx))? + .await?; this.update(cx, |this, cx| { match response { @@ -1294,7 +1302,10 @@ impl AcpThread { return Task::ready(Err(anyhow!("message not found"))); }; - let checkpoint = message.checkpoint.clone(); + let checkpoint = message + .checkpoint + .as_ref() + .map(|c| c.git_checkpoint.clone()); let git_store = self.project.read(cx).git_store().clone(); cx.spawn(async move |this, cx| { @@ -1316,6 +1327,59 @@ impl AcpThread { }) } + fn update_last_checkpoint(&mut self, cx: &mut Context) -> Task> { + let git_store = self.project.read(cx).git_store().clone(); + + let old_checkpoint = if let Some((_, message)) = self.last_user_message() { + if let Some(checkpoint) = message.checkpoint.as_ref() { + checkpoint.git_checkpoint.clone() + } else { + return Task::ready(Ok(())); + } + } else { + return Task::ready(Ok(())); + }; + + let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); + cx.spawn(async move |this, cx| { + let new_checkpoint = new_checkpoint + .await + .context("failed to get new checkpoint") + .log_err(); + if let Some(new_checkpoint) = new_checkpoint { + let equal = git_store + .update(cx, |git, cx| { + git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) + })? + .await + .unwrap_or(true); + this.update(cx, |this, cx| { + let (ix, message) = this.last_user_message().context("no user message")?; + let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?; + checkpoint.show = !equal; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + anyhow::Ok(()) + })??; + } + + Ok(()) + }) + } + + fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> { + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + Some((ix, message)) + } else { + None + } + }) + } + fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { self.entries.iter().find_map(|entry| { if let AgentThreadEntry::UserMessage(message) = entry { @@ -1552,6 +1616,7 @@ mod tests { use settings::SettingsStore; use smol::stream::StreamExt as _; use std::{ + any::Any, cell::RefCell, path::Path, rc::Rc, @@ -2284,6 +2349,10 @@ mod tests { _session_id: session_id.clone(), })) } + + fn into_any(self: Rc) -> Rc { + self + } } struct FakeAgentSessionEditor { diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 0f531acbde..b2116020fb 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -4,7 +4,7 @@ use anyhow::Result; use collections::IndexMap; use gpui::{Entity, SharedString, Task}; use project::Project; -use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; @@ -36,6 +36,14 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>; + fn resume( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + None + } + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn session_editor( @@ -53,12 +61,24 @@ pub trait AgentConnection { fn model_selector(&self) -> Option> { None } + + fn into_any(self: Rc) -> Rc; +} + +impl dyn AgentConnection { + pub fn downcast(self: Rc) -> Option> { + self.into_any().downcast().ok() + } } pub trait AgentSessionEditor { fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; } +pub trait AgentSessionResume { + fn run(&self, cx: &mut App) -> Task>; +} + #[derive(Debug)] pub struct AuthRequired; @@ -299,6 +319,10 @@ mod test_support { ) -> Option> { Some(Rc::new(StubAgentSessionEditor)) } + + fn into_any(self: Rc) -> Rc { + self + } } struct StubAgentSessionEditor; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 9ac3c2d0e5..358365d11f 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,9 +1,8 @@ -use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, - EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, - OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, - WebSearchTool, + AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, + DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, + MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, + ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; use acp_thread::AgentModelSelector; use agent_client_protocol as acp; @@ -11,6 +10,7 @@ use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; +use futures::channel::mpsc; use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, @@ -21,6 +21,7 @@ use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; use settings::update_settings_file; +use std::any::Any; use std::cell::RefCell; use std::collections::HashMap; use std::path::Path; @@ -426,9 +427,9 @@ impl NativeAgent { self.models.refresh_list(cx); for session in self.sessions.values_mut() { session.thread.update(cx, |thread, _| { - let model_id = LanguageModels::model_id(&thread.selected_model); + let model_id = LanguageModels::model_id(&thread.model()); if let Some(model) = self.models.model_from_id(&model_id) { - thread.selected_model = model.clone(); + thread.set_model(model.clone()); } }); } @@ -439,6 +440,124 @@ impl NativeAgent { #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); +impl NativeAgentConnection { + pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { + self.0 + .read(cx) + .sessions + .get(session_id) + .map(|session| session.thread.clone()) + } + + fn run_turn( + &self, + session_id: acp::SessionId, + cx: &mut App, + f: impl 'static + + FnOnce( + Entity, + &mut App, + ) -> Result>>, + ) -> Task> { + let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + }) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + log::debug!("Found session for: {}", session_id); + + let mut response_stream = match f(thread, cx) { + Ok(stream) => stream, + Err(err) => return Task::ready(Err(err)), + }; + cx.spawn(async move |cx| { + // Handle response stream and forward to session.acp_thread + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + AgentResponseEvent::Text(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + false, + cx, + ) + })?; + } + AgentResponseEvent::Thinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + true, + cx, + ) + })?; + } + AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let recv = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, options, cx) + })?; + cx.background_spawn(async move { + if let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() + { + response + .send(option) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } + AgentResponseEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.upsert_tool_call(tool_call, cx) + })?; + } + AgentResponseEvent::ToolCallUpdate(update) => { + acp_thread.update(cx, |thread, cx| { + thread.update_tool_call(update, cx) + })??; + } + AgentResponseEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { stop_reason }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + return Err(e); + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } +} + impl AgentModelSelector for NativeAgentConnection { fn list_models(&self, cx: &mut App) -> Task> { log::debug!("NativeAgentConnection::list_models called"); @@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection { }; thread.update(cx, |thread, _cx| { - thread.selected_model = model.clone(); + thread.set_model(model.clone()); }); update_settings_file::( @@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection { else { return Task::ready(Err(anyhow!("Session not found"))); }; - let model = thread.read(cx).selected_model.clone(); + let model = thread.read(cx).model().clone(); let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) else { return Task::ready(Err(anyhow!("Provider not found"))); @@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) -> Task> { let id = id.expect("UserMessageId is required"); let session_id = params.session_id.clone(); - let agent = self.0.clone(); log::info!("Received prompt request for session: {}", session_id); log::debug!("Prompt blocks count: {}", params.prompt.len()); - cx.spawn(async move |cx| { - // Get session - let (thread, acp_thread) = agent - .update(cx, |agent, _| { - agent - .sessions - .get_mut(&session_id) - .map(|s| (s.thread.clone(), s.acp_thread.clone())) - })? - .ok_or_else(|| { - log::error!("Session not found: {}", session_id); - anyhow::anyhow!("Session not found") - })?; - log::debug!("Found session for: {}", session_id); - + self.run_turn(session_id, cx, |thread, cx| { let content: Vec = params .prompt .into_iter() @@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Message id: {:?}", id); log::debug!("Message content: {:?}", content); - // Get model using the ModelSelector capability (always available for agent2) - // Get the selected model from the thread directly - let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - - // Send to thread - log::info!("Sending message to thread with model: {:?}", model.name()); - let mut response_stream = - thread.update(cx, |thread, cx| thread.send(id, content, cx))?; - - // Handle response stream and forward to session.acp_thread - while let Some(result) = response_stream.next().await { - match result { - Ok(event) => { - log::trace!("Received completion event: {:?}", event); - - match event { - AgentResponseEvent::Text(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - }), - false, - cx, - ) - })?; - } - AgentResponseEvent::Thinking(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - }), - true, - cx, - ) - })?; - } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { - tool_call, - options, - response, - }) => { - let recv = acp_thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, options, cx) - })?; - cx.background_spawn(async move { - if let Some(option) = recv - .await - .context("authorization sender was dropped") - .log_err() - { - response - .send(option) - .map(|_| anyhow!("authorization receiver was dropped")) - .log_err(); - } - }) - .detach(); - } - AgentResponseEvent::ToolCall(tool_call) => { - acp_thread.update(cx, |thread, cx| { - thread.upsert_tool_call(tool_call, cx) - })?; - } - AgentResponseEvent::ToolCallUpdate(update) => { - acp_thread.update(cx, |thread, cx| { - thread.update_tool_call(update, cx) - })??; - } - AgentResponseEvent::Stop(stop_reason) => { - log::debug!("Assistant message complete: {:?}", stop_reason); - return Ok(acp::PromptResponse { stop_reason }); - } - } - } - Err(e) => { - log::error!("Error in model response stream: {:?}", e); - // TODO: Consider sending an error message to the UI - break; - } - } - } - - log::info!("Response stream completed"); - anyhow::Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) + Ok(thread.update(cx, |thread, cx| { + log::info!( + "Sending message to thread with model: {:?}", + thread.model().name() + ); + thread.send(id, content, cx) + })) }) } + fn resume( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionResume { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { @@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) }) } + + fn into_any(self: Rc) -> Rc { + self + } } struct NativeAgentSessionEditor(Entity); @@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { } } +struct NativeAgentSessionResume { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionResume for NativeAgentSessionResume { + fn run(&self, cx: &mut App) -> Task> { + self.connection + .run_turn(self.session_id.clone(), cx, |thread, cx| { + thread.update(cx, |thread, cx| thread.resume(cx)) + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -957,7 +1007,7 @@ mod tests { agent.read_with(cx, |agent, _| { let session = agent.sessions.get(&session_id).unwrap(); session.thread.read_with(cx, |thread, _| { - assert_eq!(thread.selected_model.id().0, "fake"); + assert_eq!(thread.model().id().0, "fake"); }); }); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 1df664c029..cf90c8f650 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -12,9 +12,9 @@ use gpui::{ }; use indoc::indoc; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason, - fake_provider::FakeLanguageModel, + LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, + LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent, + Role, StopReason, fake_provider::FakeLanguageModel, }; use project::Project; use prompt_store::ProjectContext; @@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); } +#[gpui::test] +async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }); + cx.run_until_parked(); + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool.name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.end_last_completion_stream(); + + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool.name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use.clone())], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result.clone())], + cache: false + }, + ] + ); + + // Simulate reaching tool use limit. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false + } + ] + ); + + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into())); + fake_model.end_last_completion_stream(); + events.collect::>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Done + "} + ) + }); + + // Ensure we error if calling resume when tool use limit was *not* reached. + let error = thread + .update(cx, |thread, cx| thread.resume(cx)) + .unwrap_err(); + assert_eq!( + error.to_string(), + "can only resume after tool use limit is reached" + ) +} + +#[gpui::test] +async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool.name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool.name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), vec!["ghi"], cx) + }); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["ghi".into()], + cache: false + } + ] + ); +} + async fn expect_tool_call( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCall { let event = events .next() @@ -411,7 +597,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events( - result_events: Vec>, -) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index 7c7b81f52f..cbff44cedf 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -7,7 +7,7 @@ use std::future; #[derive(JsonSchema, Serialize, Deserialize)] pub struct EchoToolInput { /// The text to echo. - text: String, + pub text: String, } pub struct EchoTool; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 260aaaf550..231ee92dda 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; -use cloud_llm_client::{CompletionIntent, CompletionMode}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; use collections::IndexMap; use fs::Fs; use futures::{ @@ -14,10 +14,10 @@ use futures::{ }; use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, - LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, - LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; use project::Project; use prompt_store::ProjectContext; @@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock}; pub enum Message { User(UserMessage), Agent(AgentMessage), + Resume, } impl Message { @@ -47,6 +48,7 @@ impl Message { match self { Message::User(message) => message.to_markdown(), Message::Agent(message) => message.to_markdown(), + Message::Resume => "[resumed after tool use limit was reached]".into(), } } } @@ -320,7 +322,11 @@ impl AgentMessage { } pub fn to_request(&self) -> Vec { - let mut content = Vec::with_capacity(self.content.len()); + let mut assistant_message = LanguageModelRequestMessage { + role: Role::Assistant, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; for chunk in &self.content { let chunk = match chunk { AgentMessageContent::Text(text) => { @@ -342,29 +348,30 @@ impl AgentMessage { language_model::MessageContent::Image(value.clone()) } }; - content.push(chunk); + assistant_message.content.push(chunk); } - let mut messages = vec![LanguageModelRequestMessage { - role: Role::Assistant, - content, + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), cache: false, - }]; + }; - if !self.tool_results.is_empty() { - let mut tool_results = Vec::with_capacity(self.tool_results.len()); - for tool_result in self.tool_results.values() { - tool_results.push(language_model::MessageContent::ToolResult( + for tool_result in self.tool_results.values() { + user_message + .content + .push(language_model::MessageContent::ToolResult( tool_result.clone(), )); - } - messages.push(LanguageModelRequestMessage { - role: Role::User, - content: tool_results, - cache: false, - }); } + let mut messages = Vec::new(); + if !assistant_message.content.is_empty() { + messages.push(assistant_message); + } + if !user_message.content.is_empty() { + messages.push(user_message); + } messages } } @@ -413,11 +420,12 @@ pub struct Thread { running_turn: Option>, pending_message: Option, tools: BTreeMap>, + tool_use_limit_reached: bool, context_server_registry: Entity, profile_id: AgentProfileId, project_context: Rc>, templates: Arc, - pub selected_model: Arc, + model: Arc, project: Entity, action_log: Entity, } @@ -429,7 +437,7 @@ impl Thread { context_server_registry: Entity, action_log: Entity, templates: Arc, - default_model: Arc, + model: Arc, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); @@ -439,11 +447,12 @@ impl Thread { running_turn: None, pending_message: None, tools: BTreeMap::default(), + tool_use_limit_reached: false, context_server_registry, profile_id, project_context, templates, - selected_model: default_model, + model, project, action_log, } @@ -457,7 +466,19 @@ impl Thread { &self.action_log } - pub fn set_mode(&mut self, mode: CompletionMode) { + pub fn model(&self) -> &Arc { + &self.model + } + + pub fn set_model(&mut self, model: Arc) { + self.model = model; + } + + pub fn completion_mode(&self) -> CompletionMode { + self.completion_mode + } + + pub fn set_completion_mode(&mut self, mode: CompletionMode) { self.completion_mode = mode; } @@ -499,36 +520,59 @@ impl Thread { Ok(()) } + pub fn resume( + &mut self, + cx: &mut Context, + ) -> Result>> { + anyhow::ensure!( + self.tool_use_limit_reached, + "can only resume after tool use limit is reached" + ); + + self.messages.push(Message::Resume); + cx.notify(); + + log::info!("Total messages in thread: {}", self.messages.len()); + Ok(self.run_turn(cx)) + } + /// Sending a message results in the model streaming a response, which could include tool calls. /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. pub fn send( &mut self, - message_id: UserMessageId, + id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> mpsc::UnboundedReceiver> where T: Into, { - let model = self.selected_model.clone(); + log::info!("Thread::send called with model: {:?}", self.model.name()); + let content = content.into_iter().map(Into::into).collect::>(); - log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); + self.messages + .push(Message::User(UserMessage { id, content })); cx.notify(); - let (events_tx, events_rx) = - mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); - self.messages.push(Message::User(UserMessage { - id: message_id.clone(), - content, - })); log::info!("Total messages in thread: {}", self.messages.len()); + self.run_turn(cx) + } + + fn run_turn( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let model = self.model.clone(); + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = AgentResponseEventStream(events_tx); + let message_ix = self.messages.len().saturating_sub(1); + self.tool_use_limit_reached = false; self.running_turn = Some(cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); - let turn_result = async { + let turn_result: Result<()> = async { let mut completion_intent = CompletionIntent::UserPrompt; loop { log::debug!( @@ -543,13 +587,22 @@ impl Thread { let mut events = model.stream_completion(request, cx).await?; log::debug!("Stream completion started successfully"); + let mut tool_use_limit_reached = false; let mut tool_uses = FuturesUnordered::new(); while let Some(event) = events.next().await { match event? { + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::ToolUseLimitReached, + ) => { + tool_use_limit_reached = true; + } LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - this.update(cx, |this, _cx| this.truncate(message_id))??; + this.update(cx, |this, _cx| { + this.flush_pending_message(); + this.messages.truncate(message_ix); + })?; return Ok(()); } } @@ -567,12 +620,7 @@ impl Thread { } } - if tool_uses.is_empty() { - log::info!("No tool uses found, completing turn"); - return Ok(()); - } - log::info!("Found {} tool uses to execute", tool_uses.len()); - + let used_tools = tool_uses.is_empty(); while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); @@ -596,8 +644,17 @@ impl Thread { .ok(); } - this.update(cx, |this, _| this.flush_pending_message())?; - completion_intent = CompletionIntent::ToolResults; + if tool_use_limit_reached { + log::info!("Tool use limit reached, completing turn"); + this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; + return Err(language_model::ToolUseLimitReachedError.into()); + } else if used_tools { + log::info!("No tool uses found, completing turn"); + return Ok(()); + } else { + this.update(cx, |this, _| this.flush_pending_message())?; + completion_intent = CompletionIntent::ToolResults; + } } } .await; @@ -678,10 +735,10 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - events_stream: &AgentResponseEventStream, + event_stream: &AgentResponseEventStream, cx: &mut Context, ) { - events_stream.send_text(&new_text); + event_stream.send_text(&new_text); let last_message = self.pending_message(); if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { @@ -798,8 +855,9 @@ impl Thread { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() }); - let supports_images = self.selected_model.supports_images(); + let supports_images = self.model.supports_images(); let tool_result = tool.run(tool_use.input, tool_event_stream, cx); + log::info!("Running tool {}", tool_use.name); Some(cx.foreground_executor().spawn(async move { let tool_result = tool_result.await.and_then(|output| { if let LanguageModelToolResultContent::Image(_) = &output.llm_output { @@ -902,7 +960,7 @@ impl Thread { name: tool_name, description: tool.description().to_string(), input_schema: tool - .input_schema(self.selected_model.tool_input_format()) + .input_schema(self.model.tool_input_format()) .log_err()?, }) }) @@ -917,7 +975,7 @@ impl Thread { thread_id: None, prompt_id: None, intent: Some(completion_intent), - mode: Some(self.completion_mode), + mode: Some(self.completion_mode.into()), messages, tools, tool_choice: None, @@ -935,7 +993,7 @@ impl Thread { .profiles .get(&self.profile_id) .context("profile not found")?; - let provider_id = self.selected_model.provider_id(); + let provider_id = self.model.provider_id(); Ok(self .tools @@ -971,6 +1029,11 @@ impl Thread { match message { Message::User(message) => messages.push(message.to_request()), Message::Agent(message) => messages.extend(message.to_request()), + Message::Resume => messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }), } } @@ -1123,9 +1186,7 @@ where } #[derive(Clone)] -struct AgentResponseEventStream( - mpsc::UnboundedSender>, -); +struct AgentResponseEventStream(mpsc::UnboundedSender>); impl AgentResponseEventStream { fn send_text(&self, text: &str) { @@ -1212,8 +1273,8 @@ impl AgentResponseEventStream { } } - fn send_error(&self, error: LanguageModelCompletionError) { - self.0.unbounded_send(Err(error)).ok(); + fn send_error(&self, error: impl Into) { + self.0.unbounded_send(Err(error.into())).ok(); } } @@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream { impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = - mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); let stream = ToolCallEventStream::new( &LanguageModelToolUse { @@ -1351,9 +1411,7 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver( - mpsc::UnboundedReceiver>, -); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { @@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 405afb585f..c77b9f6a69 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -241,7 +241,7 @@ impl AgentTool for EditFileTool { thread.build_completion_request(CompletionIntent::ToolResults, cx) }); let thread = self.thread.read(cx); - let model = thread.selected_model.clone(); + let model = thread.model().clone(); let action_log = thread.action_log().clone(); let authorize = self.authorize(&input, &event_stream, cx); diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 15f8635cde..e936c87643 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, path::Path, rc::Rc}; +use std::{any::Any, cell::RefCell, path::Path, rc::Rc}; use ui::App; use util::ResultExt as _; @@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection { }) .detach_and_log_err(cx) } + + fn into_any(self: Rc) -> Rc { + self + } } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index d93e3d023e..36511e4644 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -3,9 +3,9 @@ use anyhow::anyhow; use collections::HashMap; use futures::channel::oneshot; use project::Project; -use std::cell::RefCell; use std::path::Path; use std::rc::Rc; +use std::{any::Any, cell::RefCell}; use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; @@ -191,6 +191,10 @@ impl AgentConnection for AcpConnection { .spawn(async move { conn.cancel(params).await }) .detach(); } + + fn into_any(self: Rc) -> Rc { + self + } } struct ClientDelegate { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index dbcda00e48..e1cc709289 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,6 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; +use std::any::Any; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; @@ -289,6 +290,10 @@ impl AgentConnection for ClaudeAgentConnection { }) .log_err(); } + + fn into_any(self: Rc) -> Rc { + self + } } #[derive(Clone, Copy)] diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index ee016b7503..87af75f046 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -7,20 +7,21 @@ use action_log::ActionLog; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::{self as acp}; use agent_servers::AgentServer; -use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use agent_settings::{AgentSettings, CompletionMode, NotifyWhenAgentWaiting}; use anyhow::bail; use audio::{Audio, Sound}; use buffer_diff::BufferDiff; +use client::zed_urls; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use gpui::{ - Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, - SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, - Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, - linear_gradient, list, percentage, point, prelude::*, pulsating_between, + Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, + Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, + PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, + TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, + linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; use language::Buffer; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; @@ -32,8 +33,8 @@ use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration}; use text::Anchor; use theme::ThemeSettings; use ui::{ - Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, - Tooltip, prelude::*, + Callout, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle, + Scrollbar, ScrollbarState, Tooltip, prelude::*, }; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; @@ -44,16 +45,39 @@ use super::entry_view_state::EntryViewState; use crate::acp::AcpModelSelectorPopover; use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; -use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip}; use crate::{ - AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, + AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow, + KeepAll, OpenAgentDiff, RejectAll, ToggleBurnMode, }; const RESPONSE_PADDING_X: Pixels = px(19.); - pub const MIN_EDITOR_LINES: usize = 4; pub const MAX_EDITOR_LINES: usize = 8; +enum ThreadError { + PaymentRequired, + ModelRequestLimitReached(cloud_llm_client::Plan), + ToolUseLimitReached, + Other(SharedString), +} + +impl ThreadError { + fn from_err(error: anyhow::Error) -> Self { + if error.is::() { + Self::PaymentRequired + } else if error.is::() { + Self::ToolUseLimitReached + } else if let Some(error) = + error.downcast_ref::() + { + Self::ModelRequestLimitReached(error.plan) + } else { + Self::Other(error.to_string().into()) + } + } +} + pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, @@ -66,7 +90,7 @@ pub struct AcpThreadView { model_selector: Option>, notifications: Vec>, notification_subscriptions: HashMap, Vec>, - last_error: Option>, + thread_error: Option, list_state: ListState, scrollbar_state: ScrollbarState, auth_task: Option>, @@ -151,7 +175,7 @@ impl AcpThreadView { entry_view_state: EntryViewState::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), - last_error: None, + thread_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), expanded_thinking_blocks: HashSet::default(), @@ -316,7 +340,7 @@ impl AcpThreadView { } pub fn cancel_generation(&mut self, cx: &mut Context) { - self.last_error.take(); + self.thread_error.take(); if let Some(thread) = self.thread() { self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); @@ -371,6 +395,25 @@ impl AcpThreadView { } } + fn resume_chat(&mut self, cx: &mut Context) { + self.thread_error.take(); + let Some(thread) = self.thread() else { + return; + }; + + let task = thread.update(cx, |thread, cx| thread.resume(cx)); + cx.spawn(async move |this, cx| { + let result = task.await; + + this.update(cx, |this, cx| { + if let Err(err) = result { + this.handle_thread_error(err, cx); + } + }) + }) + .detach(); + } + fn send(&mut self, window: &mut Window, cx: &mut Context) { let contents = self .message_editor @@ -384,7 +427,7 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - self.last_error.take(); + self.thread_error.take(); self.editing_message.take(); let Some(thread) = self.thread().cloned() else { @@ -409,11 +452,9 @@ impl AcpThreadView { }); cx.spawn(async move |this, cx| { - if let Err(e) = task.await { + if let Err(err) = task.await { this.update(cx, |this, cx| { - this.last_error = - Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx))); - cx.notify() + this.handle_thread_error(err, cx); }) .ok(); } @@ -476,6 +517,16 @@ impl AcpThreadView { }) } + fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context) { + self.thread_error = Some(ThreadError::from_err(error)); + cx.notify(); + } + + fn clear_thread_error(&mut self, cx: &mut Context) { + self.thread_error = None; + cx.notify(); + } + fn handle_thread_event( &mut self, thread: &Entity, @@ -551,7 +602,7 @@ impl AcpThreadView { return; }; - self.last_error.take(); + self.thread_error.take(); let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); @@ -561,9 +612,7 @@ impl AcpThreadView { this.update_in(cx, |this, window, cx| { if let Err(err) = result { - this.last_error = Some(cx.new(|cx| { - Markdown::new(format!("Error: {err}").into(), None, None, cx) - })) + this.handle_thread_error(err, cx); } else { this.thread_state = Self::initial_state( agent, @@ -620,9 +669,7 @@ impl AcpThreadView { .py_4() .px_2() .children(message.id.clone().and_then(|message_id| { - message.checkpoint.as_ref()?; - - Some( + message.checkpoint.as_ref()?.show.then(|| { Button::new("restore-checkpoint", "Restore Checkpoint") .icon(IconName::Undo) .icon_size(IconSize::XSmall) @@ -630,8 +677,8 @@ impl AcpThreadView { .label_size(LabelSize::XSmall) .on_click(cx.listener(move |this, _, _window, cx| { this.rewind(&message_id, cx); - })), - ) + })) + }) })) .child( v_flex() @@ -2322,7 +2369,12 @@ impl AcpThreadView { h_flex() .flex_none() .justify_between() - .child(self.render_follow_toggle(cx)) + .child( + h_flex() + .gap_1() + .child(self.render_follow_toggle(cx)) + .children(self.render_burn_mode_toggle(cx)), + ) .child( h_flex() .gap_1() @@ -2333,6 +2385,68 @@ impl AcpThreadView { .into_any() } + fn as_native_connection(&self, cx: &App) -> Option> { + let acp_thread = self.thread()?.read(cx); + acp_thread.connection().clone().downcast() + } + + fn as_native_thread(&self, cx: &App) -> Option> { + let acp_thread = self.thread()?.read(cx); + self.as_native_connection(cx)? + .thread(acp_thread.session_id(), cx) + } + + fn toggle_burn_mode( + &mut self, + _: &ToggleBurnMode, + _window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.as_native_thread(cx) else { + return; + }; + + thread.update(cx, |thread, _cx| { + let current_mode = thread.completion_mode(); + thread.set_completion_mode(match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }); + }); + } + + fn render_burn_mode_toggle(&self, cx: &mut Context) -> Option { + let thread = self.as_native_thread(cx)?.read(cx); + + if !thread.model().supports_burn_mode() { + return None; + } + + let active_completion_mode = thread.completion_mode(); + let burn_mode_enabled = active_completion_mode == CompletionMode::Burn; + let icon = if burn_mode_enabled { + IconName::ZedBurnModeOn + } else { + IconName::ZedBurnMode + }; + + Some( + IconButton::new("burn-mode", icon) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(burn_mode_enabled) + .selected_icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + })) + .tooltip(move |_window, cx| { + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) + .into() + }) + .into_any_element(), + ) + } + fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context) { let Some(thread) = self.thread() else { return; @@ -3002,6 +3116,187 @@ impl AcpThreadView { } } +impl AcpThreadView { + fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option
{ + let content = match self.thread_error.as_ref()? { + ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), + ThreadError::PaymentRequired => self.render_payment_required_error(cx), + ThreadError::ModelRequestLimitReached(plan) => { + self.render_model_request_limit_reached_error(*plan, cx) + } + ThreadError::ToolUseLimitReached => { + self.render_tool_use_limit_reached_error(window, cx)? + } + }; + + Some( + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child(content), + ) + } + + fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout { + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Error") + .description(error.clone()) + .secondary_action(self.create_copy_button(error.to_string())) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_payment_required_error(&self, cx: &mut Context) -> Callout { + const ERROR_MESSAGE: &str = + "You reached your free usage limit. Upgrade to Zed Pro for more prompts."; + + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Free Usage Exceeded") + .description(ERROR_MESSAGE) + .tertiary_action(self.upgrade_button(cx)) + .secondary_action(self.create_copy_button(ERROR_MESSAGE)) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_model_request_limit_reached_error( + &self, + plan: cloud_llm_client::Plan, + cx: &mut Context, + ) -> Callout { + let error_message = match plan { + cloud_llm_client::Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", + cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => { + "Upgrade to Zed Pro for more prompts." + } + }; + + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Model Prompt Limit Reached") + .description(error_message) + .tertiary_action(self.upgrade_button(cx)) + .secondary_action(self.create_copy_button(error_message)) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_tool_use_limit_reached_error( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Option { + let thread = self.as_native_thread(cx)?; + let supports_burn_mode = thread.read(cx).model().supports_burn_mode(); + + let focus_handle = self.focus_handle(cx); + + let icon = Icon::new(IconName::Info) + .size(IconSize::Small) + .color(Color::Info); + + Some( + Callout::new() + .icon(icon) + .title("Consecutive tool use limit reached.") + .when(supports_burn_mode, |this| { + this.secondary_action( + Button::new("continue-burn-mode", "Continue with Burn Mode") + .style(ButtonStyle::Filled) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &ContinueWithBurnMode, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) + .on_click({ + cx.listener(move |this, _, _window, cx| { + thread.update(cx, |thread, _cx| { + thread.set_completion_mode(CompletionMode::Burn); + }); + this.resume_chat(cx); + }) + }), + ) + }) + .primary_action( + Button::new("continue-conversation", "Continue") + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in(&ContinueThread, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click(cx.listener(|this, _, _window, cx| { + this.resume_chat(cx); + })), + ), + ) + } + + fn create_copy_button(&self, message: impl Into) -> impl IntoElement { + let message = message.into(); + + IconButton::new("copy", IconName::Copy) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Copy Error Message")) + .on_click(move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new_string(message.clone())) + }) + } + + fn dismiss_error_button(&self, cx: &mut Context) -> impl IntoElement { + IconButton::new("dismiss", IconName::Close) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Dismiss Error")) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.notify(); + } + })) + } + + fn upgrade_button(&self, cx: &mut Context) -> impl IntoElement { + Button::new("upgrade", "Upgrade") + .label_size(LabelSize::Small) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)); + } + })) + } + + fn error_callout_bg(&self, cx: &Context) -> Hsla { + cx.theme().status().error.opacity(0.08) + } +} + impl Focusable for AcpThreadView { fn focus_handle(&self, cx: &App) -> FocusHandle { self.message_editor.focus_handle(cx) @@ -3016,6 +3311,7 @@ impl Render for AcpThreadView { .size_full() .key_context("AcpThread") .on_action(cx.listener(Self::open_agent_diff)) + .on_action(cx.listener(Self::toggle_burn_mode)) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { ThreadState::Unauthenticated { connection } => v_flex() @@ -3100,19 +3396,7 @@ impl Render for AcpThreadView { } _ => this, }) - .when_some(self.last_error.clone(), |el, error| { - el.child( - div() - .p_2() - .text_xs() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().status().error_background) - .child( - self.render_markdown(error, default_markdown_style(false, window, cx)), - ), - ) - }) + .children(self.render_thread_error(window, cx)) .child(self.render_message_editor(window, cx)) } } @@ -3299,8 +3583,6 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { #[cfg(test)] pub(crate) mod tests { - use std::path::Path; - use acp_thread::StubAgentConnection; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; @@ -3310,6 +3592,8 @@ pub(crate) mod tests { use project::Project; use serde_json::json; use settings::SettingsStore; + use std::any::Any; + use std::path::Path; use super::*; @@ -3547,6 +3831,10 @@ pub(crate) mod tests { fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { unimplemented!() } + + fn into_any(self: Rc) -> Rc { + self + } } pub(crate) fn init_test(cx: &mut TestAppContext) { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 231b9cfb38..4f5f022593 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -5,7 +5,6 @@ mod agent_diff; mod agent_model_selector; mod agent_panel; mod buffer_codegen; -mod burn_mode_tooltip; mod context_picker; mod context_server_configuration; mod context_strip; diff --git a/crates/agent_ui/src/burn_mode_tooltip.rs b/crates/agent_ui/src/burn_mode_tooltip.rs deleted file mode 100644 index 6354c07760..0000000000 --- a/crates/agent_ui/src/burn_mode_tooltip.rs +++ /dev/null @@ -1,61 +0,0 @@ -use gpui::{Context, FontWeight, IntoElement, Render, Window}; -use ui::{prelude::*, tooltip_container}; - -pub struct BurnModeTooltip { - selected: bool, -} - -impl BurnModeTooltip { - pub fn new() -> Self { - Self { selected: false } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl Render for BurnModeTooltip { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let (icon, color) = if self.selected { - (IconName::ZedBurnModeOn, Color::Error) - } else { - (IconName::ZedBurnMode, Color::Default) - }; - - let turned_on = h_flex() - .h_4() - .px_1() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().text_accent.opacity(0.1)) - .rounded_sm() - .child( - Label::new("ON") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let title = h_flex() - .gap_1p5() - .child(Icon::new(icon).size(IconSize::Small).color(color)) - .child(Label::new("Burn Mode")) - .when(self.selected, |title| title.child(turned_on)); - - tooltip_container(window, cx, |this, _, _| { - this - .child(title) - .child( - div() - .max_w_64() - .child( - Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.") - .size(LabelSize::Small) - .color(Color::Muted) - ) - ) - }) - } -} diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 4b6d51c4c1..5d094811f1 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::ui::{ - MaxModeTooltip, + BurnModeTooltip, preview::{AgentPreview, UsageCallout}, }; use agent::history_store::HistoryStore; @@ -605,7 +605,7 @@ impl MessageEditor { this.toggle_burn_mode(&ToggleBurnMode, window, cx); })) .tooltip(move |_window, cx| { - cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled)) + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) .into() }) .into_any_element(), diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 49a37002f7..2e3b4ed890 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1,6 +1,6 @@ use crate::{ - burn_mode_tooltip::BurnModeTooltip, language_model_selector::{LanguageModelSelector, language_model_selector}, + ui::BurnModeTooltip, }; use agent_settings::{AgentSettings, CompletionMode}; use anyhow::Result; diff --git a/crates/agent_ui/src/ui/burn_mode_tooltip.rs b/crates/agent_ui/src/ui/burn_mode_tooltip.rs index 97f7853a61..72faaa614d 100644 --- a/crates/agent_ui/src/ui/burn_mode_tooltip.rs +++ b/crates/agent_ui/src/ui/burn_mode_tooltip.rs @@ -2,11 +2,11 @@ use crate::ToggleBurnMode; use gpui::{Context, FontWeight, IntoElement, Render, Window}; use ui::{KeyBinding, prelude::*, tooltip_container}; -pub struct MaxModeTooltip { +pub struct BurnModeTooltip { selected: bool, } -impl MaxModeTooltip { +impl BurnModeTooltip { pub fn new() -> Self { Self { selected: false } } @@ -17,7 +17,7 @@ impl MaxModeTooltip { } } -impl Render for MaxModeTooltip { +impl Render for BurnModeTooltip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let (icon, color) = if self.selected { (IconName::ZedBurnModeOn, Color::Error) diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 3b4c1fa269..0e10050dae 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -42,6 +42,18 @@ impl fmt::Display for ModelRequestLimitReachedError { } } +#[derive(Error, Debug)] +pub struct ToolUseLimitReachedError; + +impl fmt::Display for ToolUseLimitReachedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use." + ) + } +} + #[derive(Clone, Default)] pub struct LlmApiToken(Arc>>);