Stream deserialized thread to AcpThread

This commit is contained in:
Antonio Scandurra 2025-08-18 14:16:19 +02:00
parent 5259c8692d
commit 3a0e55d9b6
2 changed files with 63 additions and 21 deletions

View file

@ -5,7 +5,7 @@ use crate::{
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates, UserMessageContent, WebSearchTool, templates::Templates,
}; };
use crate::{DbThread, ThreadsDatabase}; use crate::{DbThread, ThreadId, ThreadsDatabase};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
@ -473,10 +473,18 @@ impl NativeAgentConnection {
}; };
log::debug!("Found session for: {}", session_id); log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) { let response_stream = match f(thread, cx) {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)), Err(err) => return Task::ready(Err(err)),
}; };
Self::handle_thread_events(response_stream, acp_thread, cx)
}
fn handle_thread_events(
mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
acp_thread: WeakEntity<AcpThread>,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread // Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await { while let Some(result) = response_stream.next().await {
@ -486,7 +494,15 @@ impl NativeAgentConnection {
match event { match event {
ThreadEvent::UserMessage(message) => { ThreadEvent::UserMessage(message) => {
todo!() acp_thread.update(cx, |thread, cx| {
for content in message.content {
thread.push_user_content_block(
Some(message.id.clone()),
content.into(),
cx,
);
}
})?;
} }
ThreadEvent::AgentText(text) => { ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
@ -806,19 +822,19 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> { ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let thread_id = session_id.clone().into(); let thread_id = ThreadId::from(session_id.clone());
let database = self.0.update(cx, |this, _| this.thread_database.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let database = database.await.map_err(|e| anyhow!(e))?; let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database let db_thread = database
.load_thread(thread_id) .load_thread(thread_id.clone())
.await? .await?
.context("no such thread found")?; .context("no such thread found")?;
let acp_thread = cx.update(|cx| { let acp_thread = cx.update(|cx| {
cx.new(|cx| { cx.new(|cx| {
acp_thread::AcpThread::new( acp_thread::AcpThread::new(
db_thread.title, db_thread.title.clone(),
self.clone(), self.clone(),
project.clone(), project.clone(),
session_id.clone(), session_id.clone(),
@ -835,6 +851,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.update(cx, |registry, cx| { .update(cx, |registry, cx| {
db_thread db_thread
.model .model
.as_ref()
.and_then(|model| { .and_then(|model| {
let model = SelectedModel { let model = SelectedModel {
provider: model.provider.clone().into(), provider: model.provider.clone().into(),
@ -852,7 +869,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.context("no model by id")?; .context("no model by id")?;
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::new( let mut thread = Thread::from_db(
thread_id,
db_thread,
project.clone(), project.clone(),
agent.project_context.clone(), agent.project_context.clone(),
agent.context_server_registry.clone(), agent.context_server_registry.clone(),
@ -873,7 +892,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
agent.sessions.insert( agent.sessions.insert(
session_id, session_id,
Session { Session {
thread, thread: thread.clone(),
acp_thread: acp_thread.downgrade(), acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id()); this.sessions.remove(acp_thread.session_id());
@ -882,8 +901,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
); );
})?; })?;
// we need to actually deserialize the DbThread. let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
// todo!() cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
.await?;
Ok(acp_thread) Ok(acp_thread)
}) })

View file

@ -12,7 +12,7 @@ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
stream::FuturesUnordered, stream::FuturesUnordered,
}; };
use gpui::{App, Context, Entity, SharedString, Task}; use gpui::{App, AppContext, Context, Entity, SharedString, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
@ -545,7 +545,10 @@ impl Thread {
} }
} }
pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> { pub fn replay(
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
let stream = ThreadEventStream(tx); let stream = ThreadEventStream(tx);
for message in &self.messages { for message in &self.messages {
@ -615,16 +618,15 @@ impl Thread {
); );
tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
.log_err(); .log_err();
} else {
stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields {
content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
status: Some(acp::ToolCallStatus::Failed),
..Default::default()
},
);
} }
stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..Default::default()
},
);
} }
pub fn project(&self) -> &Entity<Project> { pub fn project(&self) -> &Entity<Project> {
@ -1744,6 +1746,26 @@ impl From<acp::ContentBlock> for UserMessageContent {
} }
} }
impl From<UserMessageContent> for acp::ContentBlock {
fn from(content: UserMessageContent) -> Self {
match content {
UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
data: image.source.to_string(),
mime_type: "image/png".to_string(),
annotations: None,
uri: None,
}),
UserMessageContent::Mention { uri, content } => {
todo!()
}
}
}
}
fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
LanguageModelImage { LanguageModelImage {
source: image_content.data.into(), source: image_content.data.into(),