WIP
This commit is contained in:
parent
501e72e8f0
commit
5259c8692d
6 changed files with 250 additions and 86 deletions
|
@ -1,9 +1,9 @@
|
||||||
use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
|
use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
|
||||||
use crate::{
|
use crate::{
|
||||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
UserMessageContent, WebSearchTool, templates::Templates,
|
||||||
};
|
};
|
||||||
use crate::{DbThread, ThreadsDatabase};
|
use crate::{DbThread, ThreadsDatabase};
|
||||||
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
|
||||||
|
@ -461,10 +461,7 @@ impl NativeAgentConnection {
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
f: impl 'static
|
f: impl 'static
|
||||||
+ FnOnce(
|
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
|
||||||
Entity<Thread>,
|
|
||||||
&mut App,
|
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||||
agent
|
agent
|
||||||
|
@ -488,7 +485,10 @@ impl NativeAgentConnection {
|
||||||
log::trace!("Received completion event: {:?}", event);
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::Text(text) => {
|
ThreadEvent::UserMessage(message) => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
ThreadEvent::AgentText(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -500,7 +500,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Thinking(text) => {
|
ThreadEvent::AgentThinking(text) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.push_assistant_content_block(
|
thread.push_assistant_content_block(
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
@ -512,7 +512,7 @@ impl NativeAgentConnection {
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||||
tool_call,
|
tool_call,
|
||||||
options,
|
options,
|
||||||
response,
|
response,
|
||||||
|
@ -535,17 +535,17 @@ impl NativeAgentConnection {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.upsert_tool_call(tool_call, cx)
|
thread.upsert_tool_call(tool_call, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
ThreadEvent::ToolCallUpdate(update) => {
|
||||||
acp_thread.update(cx, |thread, cx| {
|
acp_thread.update(cx, |thread, cx| {
|
||||||
thread.update_tool_call(update, cx)
|
thread.update_tool_call(update, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
ThreadEvent::Stop(stop_reason) => {
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
}
|
}
|
||||||
|
@ -786,7 +786,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|thread| AcpThreadMetadata {
|
.map(|thread| AcpThreadMetadata {
|
||||||
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
agent: NATIVE_AGENT_SERVER_NAME.clone(),
|
||||||
id: thread.id,
|
id: thread.id.into(),
|
||||||
title: thread.title,
|
title: thread.title,
|
||||||
updated_at: thread.updated_at,
|
updated_at: thread.updated_at,
|
||||||
})
|
})
|
||||||
|
@ -806,11 +806,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||||
|
let thread_id = session_id.clone().into();
|
||||||
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
let database = self.0.update(cx, |this, _| this.thread_database.clone());
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let database = database.await.map_err(|e| anyhow!(e))?;
|
let database = database.await.map_err(|e| anyhow!(e))?;
|
||||||
let db_thread = database
|
let db_thread = database
|
||||||
.load_thread(session_id.clone())
|
.load_thread(thread_id)
|
||||||
.await?
|
.await?
|
||||||
.context("no such thread found")?;
|
.context("no such thread found")?;
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
|
||||||
use agent::thread_store;
|
use agent::thread_store;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{AgentProfileId, CompletionMode};
|
use agent_settings::{AgentProfileId, CompletionMode};
|
||||||
|
@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct DbThreadMetadata {
|
pub struct DbThreadMetadata {
|
||||||
pub id: acp::SessionId,
|
pub id: ThreadId,
|
||||||
#[serde(alias = "summary")]
|
#[serde(alias = "summary")]
|
||||||
pub title: SharedString,
|
pub title: SharedString,
|
||||||
pub updated_at: DateTime<Utc>,
|
pub updated_at: DateTime<Utc>,
|
||||||
|
@ -323,7 +323,7 @@ impl ThreadsDatabase {
|
||||||
|
|
||||||
for (id, summary, updated_at) in rows {
|
for (id, summary, updated_at) in rows {
|
||||||
threads.push(DbThreadMetadata {
|
threads.push(DbThreadMetadata {
|
||||||
id: acp::SessionId(id),
|
id: ThreadId(id),
|
||||||
title: summary.into(),
|
title: summary.into(),
|
||||||
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||||
});
|
});
|
||||||
|
@ -333,7 +333,7 @@ impl ThreadsDatabase {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
|
||||||
let connection = self.connection.clone();
|
let connection = self.connection.clone();
|
||||||
|
|
||||||
self.executor.spawn(async move {
|
self.executor.spawn(async move {
|
||||||
|
|
|
@ -329,7 +329,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut saw_partial_tool_use = false;
|
let mut saw_partial_tool_use = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
// Look for a tool use in the thread's last message
|
// Look for a tool use in the thread's last message
|
||||||
let message = thread.last_message().unwrap();
|
let message = thread.last_message().unwrap();
|
||||||
|
@ -710,7 +710,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call(
|
async fn expect_tool_call(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> acp::ToolCall {
|
) -> acp::ToolCall {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -718,7 +718,7 @@ async fn expect_tool_call(
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
ThreadEvent::ToolCall(tool_call) => return tool_call,
|
||||||
event => {
|
event => {
|
||||||
panic!("Unexpected event {event:?}");
|
panic!("Unexpected event {event:?}");
|
||||||
}
|
}
|
||||||
|
@ -726,7 +726,7 @@ async fn expect_tool_call(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call_update_fields(
|
async fn expect_tool_call_update_fields(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> acp::ToolCallUpdate {
|
) -> acp::ToolCallUpdate {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -734,7 +734,7 @@ async fn expect_tool_call_update_fields(
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
match event {
|
match event {
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
|
||||||
return update;
|
return update;
|
||||||
}
|
}
|
||||||
event => {
|
event => {
|
||||||
|
@ -744,7 +744,7 @@ async fn expect_tool_call_update_fields(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
loop {
|
loop {
|
||||||
let event = events
|
let event = events
|
||||||
|
@ -752,7 +752,7 @@ async fn next_tool_call_authorization(
|
||||||
.await
|
.await
|
||||||
.expect("no tool call authorization event received")
|
.expect("no tool call authorization event received")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||||
let permission_kinds = tool_call_authorization
|
let permission_kinds = tool_call_authorization
|
||||||
.options
|
.options
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -912,13 +912,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
let mut echo_completed = false;
|
let mut echo_completed = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event.unwrap() {
|
match event.unwrap() {
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
ThreadEvent::ToolCall(tool_call) => {
|
||||||
assert_eq!(tool_call.title, expected_tools.remove(0));
|
assert_eq!(tool_call.title, expected_tools.remove(0));
|
||||||
if tool_call.title == "Echo" {
|
if tool_call.title == "Echo" {
|
||||||
echo_id = Some(tool_call.id);
|
echo_id = Some(tool_call.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id,
|
id,
|
||||||
fields:
|
fields:
|
||||||
|
@ -946,7 +946,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
assert!(
|
assert!(
|
||||||
matches!(
|
matches!(
|
||||||
last_event,
|
last_event,
|
||||||
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
|
||||||
),
|
),
|
||||||
"unexpected event {last_event:?}"
|
"unexpected event {last_event:?}"
|
||||||
);
|
);
|
||||||
|
@ -1386,11 +1386,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
|
||||||
result_events
|
result_events
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|event| match event.unwrap() {
|
.filter_map(|event| match event.unwrap() {
|
||||||
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
|
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
|
||||||
use acp_thread::{MentionUri, UserMessageId};
|
use acp_thread::{MentionUri, UserMessageId};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -30,10 +30,12 @@ use std::{fmt::Write, ops::Range};
|
||||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
|
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
|
||||||
)]
|
)]
|
||||||
pub struct ThreadId(Arc<str>);
|
pub struct ThreadId(pub(crate) Arc<str>);
|
||||||
|
|
||||||
impl ThreadId {
|
impl ThreadId {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
@ -53,6 +55,18 @@ impl From<&str> for ThreadId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<acp::SessionId> for ThreadId {
|
||||||
|
fn from(value: acp::SessionId) -> Self {
|
||||||
|
Self(value.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ThreadId> for acp::SessionId {
|
||||||
|
fn from(value: ThreadId) -> Self {
|
||||||
|
Self(value.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The ID of the user prompt that initiated a request.
|
/// The ID of the user prompt that initiated a request.
|
||||||
///
|
///
|
||||||
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
|
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
|
||||||
|
@ -313,9 +327,6 @@ impl AgentMessage {
|
||||||
AgentMessageContent::RedactedThinking(_) => {
|
AgentMessageContent::RedactedThinking(_) => {
|
||||||
markdown.push_str("<redacted_thinking />\n")
|
markdown.push_str("<redacted_thinking />\n")
|
||||||
}
|
}
|
||||||
AgentMessageContent::Image(_) => {
|
|
||||||
markdown.push_str("<image />\n");
|
|
||||||
}
|
|
||||||
AgentMessageContent::ToolUse(tool_use) => {
|
AgentMessageContent::ToolUse(tool_use) => {
|
||||||
markdown.push_str(&format!(
|
markdown.push_str(&format!(
|
||||||
"**Tool Use**: {} (ID: {})\n",
|
"**Tool Use**: {} (ID: {})\n",
|
||||||
|
@ -386,9 +397,6 @@ impl AgentMessage {
|
||||||
AgentMessageContent::ToolUse(value) => {
|
AgentMessageContent::ToolUse(value) => {
|
||||||
language_model::MessageContent::ToolUse(value.clone())
|
language_model::MessageContent::ToolUse(value.clone())
|
||||||
}
|
}
|
||||||
AgentMessageContent::Image(value) => {
|
|
||||||
language_model::MessageContent::Image(value.clone())
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
assistant_message.content.push(chunk);
|
assistant_message.content.push(chunk);
|
||||||
}
|
}
|
||||||
|
@ -432,14 +440,14 @@ pub enum AgentMessageContent {
|
||||||
signature: Option<String>,
|
signature: Option<String>,
|
||||||
},
|
},
|
||||||
RedactedThinking(String),
|
RedactedThinking(String),
|
||||||
Image(LanguageModelImage),
|
|
||||||
ToolUse(LanguageModelToolUse),
|
ToolUse(LanguageModelToolUse),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AgentResponseEvent {
|
pub enum ThreadEvent {
|
||||||
Text(String),
|
UserMessage(UserMessage),
|
||||||
Thinking(String),
|
AgentText(String),
|
||||||
|
AgentThinking(String),
|
||||||
ToolCall(acp::ToolCall),
|
ToolCall(acp::ToolCall),
|
||||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||||
ToolCallAuthorization(ToolCallAuthorization),
|
ToolCallAuthorization(ToolCallAuthorization),
|
||||||
|
@ -504,6 +512,121 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_db(
|
||||||
|
id: ThreadId,
|
||||||
|
db_thread: DbThread,
|
||||||
|
project: Entity<Project>,
|
||||||
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
|
action_log: Entity<ActionLog>,
|
||||||
|
templates: Arc<Templates>,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
|
let profile_id = db_thread
|
||||||
|
.profile
|
||||||
|
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
prompt_id: PromptId::new(),
|
||||||
|
messages: db_thread.messages,
|
||||||
|
completion_mode: CompletionMode::Normal,
|
||||||
|
running_turn: None,
|
||||||
|
pending_message: None,
|
||||||
|
tools: BTreeMap::default(),
|
||||||
|
tool_use_limit_reached: false,
|
||||||
|
context_server_registry,
|
||||||
|
profile_id,
|
||||||
|
project_context,
|
||||||
|
templates,
|
||||||
|
model,
|
||||||
|
project,
|
||||||
|
action_log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
|
||||||
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
let stream = ThreadEventStream(tx);
|
||||||
|
for message in &self.messages {
|
||||||
|
match message {
|
||||||
|
Message::User(user_message) => stream.send_user_message(&user_message),
|
||||||
|
Message::Agent(assistant_message) => {
|
||||||
|
for content in &assistant_message.content {
|
||||||
|
match content {
|
||||||
|
AgentMessageContent::Text(text) => stream.send_text(text),
|
||||||
|
AgentMessageContent::Thinking { text, .. } => {
|
||||||
|
stream.send_thinking(text)
|
||||||
|
}
|
||||||
|
AgentMessageContent::RedactedThinking(_) => {}
|
||||||
|
AgentMessageContent::ToolUse(tool_use) => {
|
||||||
|
self.replay_tool_call(
|
||||||
|
tool_use,
|
||||||
|
assistant_message.tool_results.get(&tool_use.id),
|
||||||
|
&stream,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Resume => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replay_tool_call(
|
||||||
|
&self,
|
||||||
|
tool_use: &LanguageModelToolUse,
|
||||||
|
tool_result: Option<&LanguageModelToolResult>,
|
||||||
|
stream: &ThreadEventStream,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
|
||||||
|
stream
|
||||||
|
.0
|
||||||
|
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
|
||||||
|
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||||
|
title: tool_use.name.to_string(),
|
||||||
|
kind: acp::ToolKind::Other,
|
||||||
|
status: acp::ToolCallStatus::Failed,
|
||||||
|
content: Vec::new(),
|
||||||
|
locations: Vec::new(),
|
||||||
|
raw_input: Some(tool_use.input.clone()),
|
||||||
|
raw_output: None,
|
||||||
|
})))
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let title = tool.initial_title(tool_use.input.clone());
|
||||||
|
let kind = tool.kind();
|
||||||
|
stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
||||||
|
|
||||||
|
if let Some(output) = tool_result
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|result| result.output.clone())
|
||||||
|
{
|
||||||
|
let tool_event_stream = ToolCallEventStream::new(
|
||||||
|
tool_use.id.clone(),
|
||||||
|
stream.clone(),
|
||||||
|
Some(self.project.read(cx).fs().clone()),
|
||||||
|
);
|
||||||
|
tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
|
||||||
|
.log_err();
|
||||||
|
} else {
|
||||||
|
stream.update_tool_call_fields(
|
||||||
|
&tool_use.id,
|
||||||
|
acp::ToolCallUpdateFields {
|
||||||
|
content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
|
||||||
|
status: Some(acp::ToolCallStatus::Failed),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn project(&self) -> &Entity<Project> {
|
pub fn project(&self) -> &Entity<Project> {
|
||||||
&self.project
|
&self.project
|
||||||
}
|
}
|
||||||
|
@ -574,7 +697,7 @@ impl Thread {
|
||||||
pub fn resume(
|
pub fn resume(
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
self.tool_use_limit_reached,
|
self.tool_use_limit_reached,
|
||||||
"can only resume after tool use limit is reached"
|
"can only resume after tool use limit is reached"
|
||||||
|
@ -595,7 +718,7 @@ impl Thread {
|
||||||
id: UserMessageId,
|
id: UserMessageId,
|
||||||
content: impl IntoIterator<Item = T>,
|
content: impl IntoIterator<Item = T>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
|
) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
|
||||||
where
|
where
|
||||||
T: Into<UserMessageContent>,
|
T: Into<UserMessageContent>,
|
||||||
{
|
{
|
||||||
|
@ -613,15 +736,12 @@ impl Thread {
|
||||||
self.run_turn(cx)
|
self.run_turn(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_turn(
|
fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
|
||||||
&mut self,
|
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
|
||||||
self.cancel();
|
self.cancel();
|
||||||
|
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||||
let event_stream = AgentResponseEventStream(events_tx);
|
let event_stream = ThreadEventStream(events_tx);
|
||||||
let message_ix = self.messages.len().saturating_sub(1);
|
let message_ix = self.messages.len().saturating_sub(1);
|
||||||
self.tool_use_limit_reached = false;
|
self.tool_use_limit_reached = false;
|
||||||
self.running_turn = Some(RunningTurn {
|
self.running_turn = Some(RunningTurn {
|
||||||
|
@ -755,7 +875,7 @@ impl Thread {
|
||||||
fn handle_streamed_completion_event(
|
fn handle_streamed_completion_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
event: LanguageModelCompletionEvent,
|
event: LanguageModelCompletionEvent,
|
||||||
event_stream: &AgentResponseEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
log::trace!("Handling streamed completion event: {:?}", event);
|
log::trace!("Handling streamed completion event: {:?}", event);
|
||||||
|
@ -797,7 +917,7 @@ impl Thread {
|
||||||
fn handle_text_event(
|
fn handle_text_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
new_text: String,
|
new_text: String,
|
||||||
event_stream: &AgentResponseEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
event_stream.send_text(&new_text);
|
event_stream.send_text(&new_text);
|
||||||
|
@ -818,7 +938,7 @@ impl Thread {
|
||||||
&mut self,
|
&mut self,
|
||||||
new_text: String,
|
new_text: String,
|
||||||
new_signature: Option<String>,
|
new_signature: Option<String>,
|
||||||
event_stream: &AgentResponseEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
event_stream.send_thinking(&new_text);
|
event_stream.send_thinking(&new_text);
|
||||||
|
@ -850,7 +970,7 @@ impl Thread {
|
||||||
fn handle_tool_use_event(
|
fn handle_tool_use_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
tool_use: LanguageModelToolUse,
|
tool_use: LanguageModelToolUse,
|
||||||
event_stream: &AgentResponseEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -989,9 +1109,7 @@ impl Thread {
|
||||||
tool_use_id: tool_use.id.clone(),
|
tool_use_id: tool_use.id.clone(),
|
||||||
tool_name: tool_use.name.clone(),
|
tool_name: tool_use.name.clone(),
|
||||||
is_error: true,
|
is_error: true,
|
||||||
content: LanguageModelToolResultContent::Text(
|
content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
|
||||||
"Tool canceled by user".into(),
|
|
||||||
),
|
|
||||||
output: None,
|
output: None,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
@ -1143,7 +1261,7 @@ struct RunningTurn {
|
||||||
_task: Task<()>,
|
_task: Task<()>,
|
||||||
/// The current event stream for the running turn. Used to report a final
|
/// The current event stream for the running turn. Used to report a final
|
||||||
/// cancellation event if we cancel the turn.
|
/// cancellation event if we cancel the turn.
|
||||||
event_stream: AgentResponseEventStream,
|
event_stream: ThreadEventStream,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RunningTurn {
|
impl RunningTurn {
|
||||||
|
@ -1196,6 +1314,17 @@ where
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>>;
|
) -> Task<Result<Self::Output>>;
|
||||||
|
|
||||||
|
/// Emits events for a previous execution of the tool.
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: Self::Input,
|
||||||
|
_output: Self::Output,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||||
Arc::new(Erased(Arc::new(self)))
|
Arc::new(Erased(Arc::new(self)))
|
||||||
}
|
}
|
||||||
|
@ -1223,6 +1352,13 @@ pub trait AnyAgentTool {
|
||||||
event_stream: ToolCallEventStream,
|
event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<AgentToolOutput>>;
|
) -> Task<Result<AgentToolOutput>>;
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
input: serde_json::Value,
|
||||||
|
output: serde_json::Value,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||||
|
@ -1274,21 +1410,39 @@ where
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
input: serde_json::Value,
|
||||||
|
output: serde_json::Value,
|
||||||
|
event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
let input = serde_json::from_value(input)?;
|
||||||
|
let output = serde_json::from_value(output)?;
|
||||||
|
self.0.replay(input, output, event_stream, cx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
|
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
|
||||||
|
|
||||||
|
impl ThreadEventStream {
|
||||||
|
fn send_user_message(&self, message: &UserMessage) {
|
||||||
|
self.0
|
||||||
|
.unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
impl AgentResponseEventStream {
|
|
||||||
fn send_text(&self, text: &str) {
|
fn send_text(&self, text: &str) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
|
.unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_thinking(&self, text: &str) {
|
fn send_thinking(&self, text: &str) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
|
.unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1300,7 +1454,7 @@ impl AgentResponseEventStream {
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
) {
|
) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
.unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
|
||||||
id,
|
id,
|
||||||
title.to_string(),
|
title.to_string(),
|
||||||
kind,
|
kind,
|
||||||
|
@ -1333,7 +1487,7 @@ impl AgentResponseEventStream {
|
||||||
fields: acp::ToolCallUpdateFields,
|
fields: acp::ToolCallUpdateFields,
|
||||||
) {
|
) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
.unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id: acp::ToolCallId(tool_use_id.to_string().into()),
|
id: acp::ToolCallId(tool_use_id.to_string().into()),
|
||||||
fields,
|
fields,
|
||||||
|
@ -1347,17 +1501,17 @@ impl AgentResponseEventStream {
|
||||||
match reason {
|
match reason {
|
||||||
StopReason::EndTurn => {
|
StopReason::EndTurn => {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
|
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
StopReason::MaxTokens => {
|
StopReason::MaxTokens => {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
|
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
StopReason::Refusal => {
|
StopReason::Refusal => {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
|
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
StopReason::ToolUse => {}
|
StopReason::ToolUse => {}
|
||||||
|
@ -1366,7 +1520,7 @@ impl AgentResponseEventStream {
|
||||||
|
|
||||||
fn send_canceled(&self) {
|
fn send_canceled(&self) {
|
||||||
self.0
|
self.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1378,24 +1532,23 @@ impl AgentResponseEventStream {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ToolCallEventStream {
|
pub struct ToolCallEventStream {
|
||||||
tool_use_id: LanguageModelToolUseId,
|
tool_use_id: LanguageModelToolUseId,
|
||||||
stream: AgentResponseEventStream,
|
stream: ThreadEventStream,
|
||||||
fs: Option<Arc<dyn Fs>>,
|
fs: Option<Arc<dyn Fs>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolCallEventStream {
|
impl ToolCallEventStream {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
|
||||||
|
|
||||||
let stream =
|
let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
|
||||||
ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
|
|
||||||
|
|
||||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(
|
fn new(
|
||||||
tool_use_id: LanguageModelToolUseId,
|
tool_use_id: LanguageModelToolUseId,
|
||||||
stream: AgentResponseEventStream,
|
stream: ThreadEventStream,
|
||||||
fs: Option<Arc<dyn Fs>>,
|
fs: Option<Arc<dyn Fs>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -1413,7 +1566,7 @@ impl ToolCallEventStream {
|
||||||
pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
|
pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
|
||||||
self.stream
|
self.stream
|
||||||
.0
|
.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
.unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
|
||||||
acp_thread::ToolCallUpdateDiff {
|
acp_thread::ToolCallUpdateDiff {
|
||||||
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||||
diff,
|
diff,
|
||||||
|
@ -1426,7 +1579,7 @@ impl ToolCallEventStream {
|
||||||
pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
|
pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
|
||||||
self.stream
|
self.stream
|
||||||
.0
|
.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
.unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
|
||||||
acp_thread::ToolCallUpdateTerminal {
|
acp_thread::ToolCallUpdateTerminal {
|
||||||
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||||
terminal,
|
terminal,
|
||||||
|
@ -1444,7 +1597,7 @@ impl ToolCallEventStream {
|
||||||
let (response_tx, response_rx) = oneshot::channel();
|
let (response_tx, response_rx) = oneshot::channel();
|
||||||
self.stream
|
self.stream
|
||||||
.0
|
.0
|
||||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
.unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
|
||||||
ToolCallAuthorization {
|
ToolCallAuthorization {
|
||||||
tool_call: acp::ToolCallUpdate {
|
tool_call: acp::ToolCallUpdate {
|
||||||
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||||
|
@ -1494,13 +1647,13 @@ impl ToolCallEventStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
|
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
impl ToolCallEventStreamReceiver {
|
impl ToolCallEventStreamReceiver {
|
||||||
pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
|
pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
|
||||||
let event = self.0.next().await;
|
let event = self.0.next().await;
|
||||||
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
|
if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
|
||||||
auth
|
auth
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected ToolCallAuthorization but got: {:?}", event);
|
panic!("Expected ToolCallAuthorization but got: {:?}", event);
|
||||||
|
@ -1509,9 +1662,9 @@ impl ToolCallEventStreamReceiver {
|
||||||
|
|
||||||
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
|
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
|
||||||
let event = self.0.next().await;
|
let event = self.0.next().await;
|
||||||
if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
|
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
|
||||||
acp_thread::ToolCallUpdate::UpdateTerminal(update),
|
update,
|
||||||
))) = event
|
)))) = event
|
||||||
{
|
{
|
||||||
update.terminal
|
update.terminal
|
||||||
} else {
|
} else {
|
||||||
|
@ -1522,7 +1675,7 @@ impl ToolCallEventStreamReceiver {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
|
type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
|
|
|
@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replay(
|
||||||
|
&self,
|
||||||
|
_input: serde_json::Value,
|
||||||
|
_output: serde_json::Value,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -319,7 +319,7 @@ mod tests {
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use util::test::TempTree;
|
use util::test::TempTree;
|
||||||
|
|
||||||
use crate::AgentResponseEvent;
|
use crate::ThreadEvent;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -396,7 +396,7 @@ mod tests {
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let event = stream_rx.try_next();
|
let event = stream_rx.try_next();
|
||||||
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
|
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
|
||||||
auth.response.send(auth.options[0].id.clone()).unwrap();
|
auth.response.send(auth.options[0].id.clone()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue