Stream deserialized thread to AcpThread
This commit is contained in:
parent
5259c8692d
commit
3a0e55d9b6
2 changed files with 63 additions and 21 deletions
|
@ -5,7 +5,7 @@ use crate::{
|
||||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||||
UserMessageContent, WebSearchTool, templates::Templates,
|
UserMessageContent, WebSearchTool, templates::Templates,
|
||||||
};
|
};
|
||||||
use crate::{DbThread, ThreadsDatabase};
|
use crate::{DbThread, ThreadId, ThreadsDatabase};
|
||||||
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::AgentSettings;
|
use agent_settings::AgentSettings;
|
||||||
|
@ -473,10 +473,18 @@ impl NativeAgentConnection {
|
||||||
};
|
};
|
||||||
log::debug!("Found session for: {}", session_id);
|
log::debug!("Found session for: {}", session_id);
|
||||||
|
|
||||||
let mut response_stream = match f(thread, cx) {
|
let response_stream = match f(thread, cx) {
|
||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
Err(err) => return Task::ready(Err(err)),
|
Err(err) => return Task::ready(Err(err)),
|
||||||
};
|
};
|
||||||
|
Self::handle_thread_events(response_stream, acp_thread, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_thread_events(
|
||||||
|
mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
|
acp_thread: WeakEntity<AcpThread>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
// Handle response stream and forward to session.acp_thread
|
// Handle response stream and forward to session.acp_thread
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
|
@ -486,7 +494,15 @@ impl NativeAgentConnection {
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
ThreadEvent::UserMessage(message) => {
|
ThreadEvent::UserMessage(message) => {
|
||||||
todo!()
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
for content in message.content {
|
||||||
|
thread.push_user_content_block(
|
||||||
|
Some(message.id.clone()),
|
||||||
|
content.into(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
ThreadEvent::AgentText(text) => {
|
ThreadEvent::AgentText(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
@ -806,19 +822,19 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||||
let thread_id = session_id.clone().into();
|
let thread_id = ThreadId::from(session_id.clone());
|
||||||
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let database = database.await.map_err(|e| anyhow!(e))?;
|
let database = database.await.map_err(|e| anyhow!(e))?;
|
||||||
let db_thread = database
|
let db_thread = database
|
||||||
.load_thread(thread_id)
|
.load_thread(thread_id.clone())
|
||||||
.await?
|
.await?
|
||||||
.context("no such thread found")?;
|
.context("no such thread found")?;
|
||||||
|
|
||||||
let acp_thread = cx.update(|cx| {
|
let acp_thread = cx.update(|cx| {
|
||||||
cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
acp_thread::AcpThread::new(
|
acp_thread::AcpThread::new(
|
||||||
db_thread.title,
|
db_thread.title.clone(),
|
||||||
self.clone(),
|
self.clone(),
|
||||||
project.clone(),
|
project.clone(),
|
||||||
session_id.clone(),
|
session_id.clone(),
|
||||||
|
@ -835,6 +851,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.update(cx, |registry, cx| {
|
.update(cx, |registry, cx| {
|
||||||
db_thread
|
db_thread
|
||||||
.model
|
.model
|
||||||
|
.as_ref()
|
||||||
.and_then(|model| {
|
.and_then(|model| {
|
||||||
let model = SelectedModel {
|
let model = SelectedModel {
|
||||||
provider: model.provider.clone().into(),
|
provider: model.provider.clone().into(),
|
||||||
|
@ -852,7 +869,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.context("no model by id")?;
|
.context("no model by id")?;
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(
|
let mut thread = Thread::from_db(
|
||||||
|
thread_id,
|
||||||
|
db_thread,
|
||||||
project.clone(),
|
project.clone(),
|
||||||
agent.project_context.clone(),
|
agent.project_context.clone(),
|
||||||
agent.context_server_registry.clone(),
|
agent.context_server_registry.clone(),
|
||||||
|
@ -873,7 +892,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
agent.sessions.insert(
|
agent.sessions.insert(
|
||||||
session_id,
|
session_id,
|
||||||
Session {
|
Session {
|
||||||
thread,
|
thread: thread.clone(),
|
||||||
acp_thread: acp_thread.downgrade(),
|
acp_thread: acp_thread.downgrade(),
|
||||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||||
this.sessions.remove(acp_thread.session_id());
|
this.sessions.remove(acp_thread.session_id());
|
||||||
|
@ -882,8 +901,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
);
|
);
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// we need to actually deserialize the DbThread.
|
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||||
// todo!()
|
cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(acp_thread)
|
Ok(acp_thread)
|
||||||
})
|
})
|
||||||
|
|
|
@ -12,7 +12,7 @@ use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
stream::FuturesUnordered,
|
stream::FuturesUnordered,
|
||||||
};
|
};
|
||||||
use gpui::{App, Context, Entity, SharedString, Task};
|
use gpui::{App, AppContext, Context, Entity, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||||
|
@ -545,7 +545,10 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
|
pub fn replay(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
|
||||||
let (tx, rx) = mpsc::unbounded();
|
let (tx, rx) = mpsc::unbounded();
|
||||||
let stream = ThreadEventStream(tx);
|
let stream = ThreadEventStream(tx);
|
||||||
for message in &self.messages {
|
for message in &self.messages {
|
||||||
|
@ -615,16 +618,15 @@ impl Thread {
|
||||||
);
|
);
|
||||||
tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
|
tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
|
||||||
.log_err();
|
.log_err();
|
||||||
} else {
|
|
||||||
stream.update_tool_call_fields(
|
|
||||||
&tool_use.id,
|
|
||||||
acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
|
|
||||||
status: Some(acp::ToolCallStatus::Failed),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stream.update_tool_call_fields(
|
||||||
|
&tool_use.id,
|
||||||
|
acp::ToolCallUpdateFields {
|
||||||
|
status: Some(acp::ToolCallStatus::Completed),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn project(&self) -> &Entity<Project> {
|
pub fn project(&self) -> &Entity<Project> {
|
||||||
|
@ -1744,6 +1746,26 @@ impl From<acp::ContentBlock> for UserMessageContent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<UserMessageContent> for acp::ContentBlock {
|
||||||
|
fn from(content: UserMessageContent) -> Self {
|
||||||
|
match content {
|
||||||
|
UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
text,
|
||||||
|
annotations: None,
|
||||||
|
}),
|
||||||
|
UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
|
||||||
|
data: image.source.to_string(),
|
||||||
|
mime_type: "image/png".to_string(),
|
||||||
|
annotations: None,
|
||||||
|
uri: None,
|
||||||
|
}),
|
||||||
|
UserMessageContent::Mention { uri, content } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
|
fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
|
||||||
LanguageModelImage {
|
LanguageModelImage {
|
||||||
source: image_content.data.into(),
|
source: image_content.data.into(),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue