diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index ced8c5e401..6ebcece2b5 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, - FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, - ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, + EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, WebSearchTool, }; use acp_thread::AgentModelSelector; @@ -583,22 +583,22 @@ impl acp_thread::AgentConnection for NativeAgentConnection { default_model, cx, ); - thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(CreateDirectoryTool::new(project.clone())); + thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(MovePathTool::new(project.clone())); - thread.add_tool(ListDirectoryTool::new(project.clone())); - thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ThinkingTool); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); - thread.add_tool(GrepTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(GrepTool::new(project.clone())); + thread.add_tool(ListDirectoryTool::new(project.clone())); + thread.add_tool(MovePathTool::new(project.clone())); thread.add_tool(NowTool); + thread.add_tool(OpenTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(TerminalTool::new(project.clone(), cx)); - // TODO: Needs to be conditional based on zed model or not - thread.add_tool(WebSearchTool); + thread.add_tool(ThinkingTool); + thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. thread }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index b48f9001ac..4156ec44d2 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -411,7 +411,7 @@ pub struct Thread { /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. running_turn: Option>, - pending_agent_message: Option, + pending_message: Option, tools: BTreeMap>, context_server_registry: Entity, profile_id: AgentProfileId, @@ -437,7 +437,7 @@ impl Thread { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, - pending_agent_message: None, + pending_message: None, tools: BTreeMap::default(), context_server_registry, profile_id, @@ -463,7 +463,7 @@ impl Thread { #[cfg(any(test, feature = "test-support"))] pub fn last_message(&self) -> Option { - if let Some(message) = self.pending_agent_message.clone() { + if let Some(message) = self.pending_message.clone() { Some(Message::Agent(message)) } else { self.messages.last().cloned() @@ -485,7 +485,7 @@ impl Thread { pub fn cancel(&mut self) { // TODO: do we need to emit a stop::cancel for ACP? self.running_turn.take(); - self.flush_pending_agent_message(); + self.flush_pending_message(); } pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { @@ -521,74 +521,58 @@ impl Thread { mpsc::unbounded::>(); let event_stream = AgentResponseEventStream(events_tx); - let user_message_ix = self.messages.len(); self.messages.push(Message::User(UserMessage { - id: message_id, + id: message_id.clone(), content, })); log::info!("Total messages in thread: {}", self.messages.len()); - self.running_turn = Some(cx.spawn(async move |thread, cx| { + self.running_turn = Some(cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); let turn_result = async { - // Perform one request, then keep looping if the model makes tool calls. let mut completion_intent = CompletionIntent::UserPrompt; - 'outer: loop { + loop { log::debug!( "Building completion request with intent: {:?}", completion_intent ); - let request = thread.update(cx, |thread, cx| { - thread.build_completion_request(completion_intent, cx) + let request = this.update(cx, |this, cx| { + this.build_completion_request(completion_intent, cx) })?; - // Stream events, appending to messages and collecting up tool uses. log::info!("Calling model.stream_completion"); let mut events = model.stream_completion(request, cx).await?; log::debug!("Stream completion started successfully"); let mut tool_uses = FuturesUnordered::new(); while let Some(event) = events.next().await { - match event { - Ok(LanguageModelCompletionEvent::Stop(reason)) => { + match event? { + LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - thread.update(cx, |thread, _cx| { - thread.pending_agent_message = None; - thread.messages.truncate(user_message_ix); - })?; - break 'outer; + this.update(cx, |this, _cx| this.truncate(message_id))??; + return Ok(()); } } - Ok(event) => { + event => { log::trace!("Received completion event: {:?}", event); - thread - .update(cx, |thread, cx| { - tool_uses.extend(thread.handle_streamed_completion_event( - event, - &event_stream, - cx, - )); - }) - .ok(); - } - Err(error) => { - log::error!("Error in completion stream: {:?}", error); - event_stream.send_error(error); - break; + this.update(cx, |this, cx| { + tool_uses.extend(this.handle_streamed_completion_event( + event, + &event_stream, + cx, + )); + }) + .ok(); } } } - // If there are no tool uses, the turn is done. if tool_uses.is_empty() { log::info!("No tool uses found, completing turn"); - break; + return Ok(()); } log::info!("Found {} tool uses to execute", tool_uses.len()); - // As tool results trickle in, insert them in the last user - // message so that they can be sent on the next tick of the - // agentic loop. while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); @@ -604,29 +588,21 @@ impl Thread { ..Default::default() }, ); - thread - .update(cx, |thread, _cx| { - thread - .pending_agent_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - }) - .ok(); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + }) + .ok(); } - thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?; - + this.update(cx, |this, _| this.flush_pending_message())?; completion_intent = CompletionIntent::ToolResults; } - - Ok(()) } .await; - thread - .update(cx, |thread, _cx| thread.flush_pending_agent_message()) - .ok(); - + this.update(cx, |this, _| this.flush_pending_message()).ok(); if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error); event_stream.send_error(error); @@ -668,7 +644,8 @@ impl Thread { match event { StartMessage { .. } => { - self.messages.push(Message::Agent(AgentMessage::default())); + self.flush_pending_message(); + self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Thinking { text, signature } => { @@ -706,7 +683,7 @@ impl Thread { ) { events_stream.send_text(&new_text); - let last_message = self.pending_agent_message(); + let last_message = self.pending_message(); if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { text.push_str(&new_text); } else { @@ -727,7 +704,7 @@ impl Thread { ) { event_stream.send_thinking(&new_text); - let last_message = self.pending_agent_message(); + let last_message = self.pending_message(); if let Some(AgentMessageContent::Thinking { text, signature }) = last_message.content.last_mut() { @@ -744,7 +721,7 @@ impl Thread { } fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { - let last_message = self.pending_agent_message(); + let last_message = self.pending_message(); last_message .content .push(AgentMessageContent::RedactedThinking(data)); @@ -768,7 +745,7 @@ impl Thread { } // Ensure the last message ends in the current tool use - let last_message = self.pending_agent_message(); + let last_message = self.pending_message(); let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { @@ -871,12 +848,12 @@ impl Thread { } } - fn pending_agent_message(&mut self) -> &mut AgentMessage { - self.pending_agent_message.get_or_insert_default() + fn pending_message(&mut self) -> &mut AgentMessage { + self.pending_message.get_or_insert_default() } - fn flush_pending_agent_message(&mut self) { - let Some(mut message) = self.pending_agent_message.take() else { + fn flush_pending_message(&mut self) { + let Some(mut message) = self.pending_message.take() else { return; }; @@ -997,7 +974,7 @@ impl Thread { } } - if let Some(message) = self.pending_agent_message.as_ref() { + if let Some(message) = self.pending_message.as_ref() { messages.extend(message.to_request()); } @@ -1013,7 +990,7 @@ impl Thread { markdown.push_str(&message.to_markdown()); } - if let Some(message) = self.pending_agent_message.as_ref() { + if let Some(message) = self.pending_message.as_ref() { markdown.push('\n'); markdown.push_str(&message.to_markdown()); }