diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 2a17f23fd0..5b27f0a048 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,7 +5,7 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadsDatabase}; +use crate::{DbThread, ThreadId, ThreadsDatabase}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; @@ -473,10 +473,18 @@ impl NativeAgentConnection { }; log::debug!("Found session for: {}", session_id); - let mut response_stream = match f(thread, cx) { + let response_stream = match f(thread, cx) { Ok(stream) => stream, Err(err) => return Task::ready(Err(err)), }; + Self::handle_thread_events(response_stream, acp_thread, cx) + } + + fn handle_thread_events( + mut response_stream: mpsc::UnboundedReceiver>, + acp_thread: WeakEntity, + cx: &mut App, + ) -> Task> { cx.spawn(async move |cx| { // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { @@ -486,7 +494,15 @@ impl NativeAgentConnection { match event { ThreadEvent::UserMessage(message) => { - todo!() + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; } ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { @@ -806,19 +822,19 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { - let thread_id = session_id.clone().into(); + let thread_id = ThreadId::from(session_id.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(thread_id) + .load_thread(thread_id.clone()) .await? .context("no such thread found")?; let acp_thread = cx.update(|cx| { cx.new(|cx| { acp_thread::AcpThread::new( - db_thread.title, + db_thread.title.clone(), self.clone(), project.clone(), session_id.clone(), @@ -835,6 +851,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .update(cx, |registry, cx| { db_thread .model + .as_ref() .and_then(|model| { let model = SelectedModel { provider: model.provider.clone().into(), @@ -852,7 +869,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .context("no model by id")?; let thread = cx.new(|cx| { - let mut thread = Thread::new( + let mut thread = Thread::from_db( + thread_id, + db_thread, project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), @@ -873,7 +892,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { agent.sessions.insert( session_id, Session { - thread, + thread: thread.clone(), acp_thread: acp_thread.downgrade(), _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); @@ -882,8 +901,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ); })?; - // we need to actually deserialize the DbThread. - // todo!() + let events = thread.update(cx, |thread, cx| thread.replay(cx))?; + cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))? + .await?; Ok(acp_thread) }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 87f4803daf..6d688675ba 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -12,7 +12,7 @@ use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, SharedString, Task}; +use gpui::{App, AppContext, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, @@ -545,7 +545,10 @@ impl Thread { } } - pub fn replay(&self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + pub fn replay( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { let (tx, rx) = mpsc::unbounded(); let stream = ThreadEventStream(tx); for message in &self.messages { @@ -615,16 +618,15 @@ impl Thread { ); tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) .log_err(); - } else { - stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields { - content: Some(vec![TOOL_CANCELED_MESSAGE.into()]), - status: Some(acp::ToolCallStatus::Failed), - ..Default::default() - }, - ); } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + ); } pub fn project(&self) -> &Entity { @@ -1744,6 +1746,26 @@ impl From for UserMessageContent { } } +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + uri: None, + }), + UserMessageContent::Mention { uri, content } => { + todo!() + } + } + } +} + fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { LanguageModelImage { source: image_content.data.into(),