agent2: Port Zed AI features (#36172)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-15 13:17:17 +02:00 committed by GitHub
parent f8b0105258
commit 6f3cd42411
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 994 additions and 358 deletions

View file

@ -1,9 +1,8 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
WebSearchTool,
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
@ -21,6 +21,7 @@ use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@ -426,9 +427,9 @@ impl NativeAgent {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
thread.set_model(model.clone());
}
});
}
@ -439,6 +440,124 @@ impl NativeAgent {
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl NativeAgentConnection {
    /// Returns the [`Thread`] entity backing `session_id`, or `None` when the
    /// agent has no session registered under that id.
    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
        self.0
            .read(cx)
            .sessions
            .get(session_id)
            .map(|session| session.thread.clone())
    }

    /// Runs a single turn for `session_id` and forwards every event the thread
    /// emits to the session's ACP thread.
    ///
    /// `f` receives the session's [`Thread`] and must start the turn, returning
    /// the stream of [`AgentResponseEvent`]s that the turn produces.
    ///
    /// The returned task resolves to a [`acp::PromptResponse`] carrying the stop
    /// reason of the first `Stop` event, or `EndTurn` when the stream finishes
    /// without one.
    ///
    /// # Errors
    ///
    /// The task resolves to an error when the session does not exist, when `f`
    /// fails, when the stream yields an error, or when updating the ACP thread
    /// fails.
    fn run_turn(
        &self,
        session_id: acp::SessionId,
        cx: &mut App,
        f: impl 'static
        + FnOnce(
            Entity<Thread>,
            &mut App,
        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
            agent
                .sessions
                .get_mut(&session_id)
                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
        }) else {
            return Task::ready(Err(anyhow!("Session not found")));
        };
        log::debug!("Found session for: {}", session_id);

        let mut response_stream = match f(thread, cx) {
            Ok(stream) => stream,
            Err(err) => return Task::ready(Err(err)),
        };
        cx.spawn(async move |cx| {
            // Handle response stream and forward to session.acp_thread
            while let Some(result) = response_stream.next().await {
                match result {
                    Ok(event) => {
                        log::trace!("Received completion event: {:?}", event);

                        match event {
                            AgentResponseEvent::Text(text) => {
                                acp_thread.update(cx, |thread, cx| {
                                    thread.push_assistant_content_block(
                                        acp::ContentBlock::Text(acp::TextContent {
                                            text,
                                            annotations: None,
                                        }),
                                        false,
                                        cx,
                                    )
                                })?;
                            }
                            // Thinking is forwarded as a text block too; only the
                            // boolean argument differs from the `Text` case.
                            AgentResponseEvent::Thinking(text) => {
                                acp_thread.update(cx, |thread, cx| {
                                    thread.push_assistant_content_block(
                                        acp::ContentBlock::Text(acp::TextContent {
                                            text,
                                            annotations: None,
                                        }),
                                        true,
                                        cx,
                                    )
                                })?;
                            }
                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
                                tool_call,
                                options,
                                response,
                            }) => {
                                let recv = acp_thread.update(cx, |thread, cx| {
                                    thread.request_tool_call_authorization(tool_call, options, cx)
                                })?;
                                // Relay the decision in the background so the event
                                // loop keeps draining the stream while the
                                // authorization is pending.
                                cx.background_spawn(async move {
                                    if let Some(option) = recv
                                        .await
                                        .context("authorization sender was dropped")
                                        .log_err()
                                    {
                                        // `send` hands the value back on failure;
                                        // replace it with a descriptive error so the
                                        // log explains what went wrong.
                                        response
                                            .send(option)
                                            .map_err(|_| {
                                                anyhow!("authorization receiver was dropped")
                                            })
                                            .log_err();
                                    }
                                })
                                .detach();
                            }
                            AgentResponseEvent::ToolCall(tool_call) => {
                                acp_thread.update(cx, |thread, cx| {
                                    thread.upsert_tool_call(tool_call, cx)
                                })?;
                            }
                            AgentResponseEvent::ToolCallUpdate(update) => {
                                acp_thread.update(cx, |thread, cx| {
                                    thread.update_tool_call(update, cx)
                                })??;
                            }
                            AgentResponseEvent::Stop(stop_reason) => {
                                log::debug!("Assistant message complete: {:?}", stop_reason);
                                return Ok(acp::PromptResponse { stop_reason });
                            }
                        }
                    }
                    Err(e) => {
                        log::error!("Error in model response stream: {:?}", e);
                        return Err(e);
                    }
                }
            }

            log::info!("Response stream completed");
            anyhow::Ok(acp::PromptResponse {
                stop_reason: acp::StopReason::EndTurn,
            })
        })
    }
}
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
thread.set_model(model.clone());
});
update_settings_file::<AgentSettings>(
@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let model = thread.read(cx).model().clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
cx.spawn(async move |cx| {
// Get session
let (thread, acp_thread) = agent
.update(cx, |agent, _| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
})?
.ok_or_else(|| {
log::error!("Session not found: {}", session_id);
anyhow::anyhow!("Session not found")
})?;
log::debug!("Found session for: {}", session_id);
self.run_turn(session_id, cx, |thread, cx| {
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})?;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
// TODO: Consider sending an error message to the UI
break;
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
Ok(thread.update(cx, |thread, cx| {
log::info!(
"Sending message to thread with model: {:?}",
thread.model().name()
);
thread.send(id, content, cx)
}))
})
}
/// Builds a handle that can resume the turn of `session_id`.
///
/// Always returns `Some`: the session is not looked up here, so an unknown
/// session id only surfaces as an error once the handle is run.
fn resume(
    &self,
    session_id: &acp::SessionId,
    _cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
    let handle: Rc<dyn acp_thread::AgentSessionResume> = Rc::new(NativeAgentSessionResume {
        connection: self.clone(),
        session_id: session_id.clone(),
    });
    Some(handle)
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
/// Converts this connection into an `Rc<dyn Any>`.
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct NativeAgentSessionEditor(Entity<Thread>);
@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
}
}
/// Handle that resumes the turn of one native agent session.
struct NativeAgentSessionResume {
/// Connection through which the resumed turn is run.
connection: NativeAgentConnection,
/// Id of the session whose thread is resumed.
session_id: acp::SessionId,
}
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
    /// Resumes the session's thread and runs the resulting turn through the
    /// connection, returning the task that resolves to the prompt response.
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
        let session_id = self.session_id.clone();
        self.connection.run_turn(session_id, cx, |thread, cx| {
            thread.update(cx, |thread, cx| thread.resume(cx))
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -957,7 +1007,7 @@ mod tests {
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
assert_eq!(thread.model().id().0, "fake");
});
});

View file

@ -12,9 +12,9 @@ use gpui::{
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
fake_provider::FakeLanguageModel,
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
use project::Project;
use prompt_store::ProjectContext;
@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
#[gpui::test]
/// Checks that a turn interrupted by the tool use limit ends with a
/// `ToolUseLimitReachedError`, that `Thread::resume` re-sends the history
/// followed by a "Continue where you left off" user message, and that
/// `resume` is rejected when the limit was not reached.
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Start a turn with the echo tool registered on the thread.
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
// The fake model requests the echo tool and then ends its stream.
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// The follow-up request must carry the tool use and its result.
// Index 0 is skipped (presumably the system prompt; confirm against `setup`).
let completion = fake_model.pending_completions().pop().unwrap();
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use.clone())],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result.clone())],
cache: false
},
]
);
// Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
// The last event of the turn must be the tool-use-limit error.
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
// Resuming must replay the history plus the "Continue where you left off" message.
let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false
}
]
);
// Let the resumed turn finish with a plain text reply.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
fake_model.end_last_completion_stream();
events.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.last_message().unwrap().to_markdown(),
indoc! {"
## Assistant
Done
"}
)
});
// Ensure we error if calling resume when tool use limit was *not* reached.
let error = thread
.update(cx, |thread, cx| thread.resume(cx))
.unwrap_err();
assert_eq!(
error.to_string(),
"can only resume after tool use limit is reached"
)
}
#[gpui::test]
/// Checks that sending a new user message after the tool use limit was
/// reached keeps the earlier tool use and tool result in the request and
/// appends the new message, without any "Continue where you left off" entry.
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Start a turn with the echo tool registered on the thread.
let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
let tool_use = LanguageModelToolUse {
id: "tool_id_1".into(),
name: EchoTool.name().into(),
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
};
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
tool_name: EchoTool.name().into(),
is_error: false,
content: "def".into(),
output: Some("def".into()),
};
// The tool use and the limit notification arrive in the same completion stream.
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream();
// The last event of the turn must be the tool-use-limit error.
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
last_event
.unwrap_err()
.is::<language_model::ToolUseLimitReachedError>()
);
// Send a fresh message instead of resuming; the returned event stream is not needed here.
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), vec!["ghi"], cx)
});
cx.run_until_parked();
// Index 0 is skipped (presumably the system prompt; confirm against `setup`).
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![
LanguageModelRequestMessage {
role: Role::User,
content: vec!["abc".into()],
cache: false
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result)],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec!["ghi".into()],
cache: false
}
]
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events
.next()
@ -411,7 +597,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> Vec<acp::StopReason> {
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {

View file

@ -7,7 +7,7 @@ use std::future;
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
/// The text to echo.
text: String,
pub text: String,
}
pub struct EchoTool;

View file

@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
@ -14,10 +14,10 @@ use futures::{
};
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
pub enum Message {
/// A message holding user-provided content.
User(UserMessage),
/// A message holding agent output.
Agent(AgentMessage),
/// Marker pushed by `Thread::resume`. It renders in markdown as
/// "[resumed after tool use limit was reached]" and is sent to the model as
/// a user message reading "Continue where you left off".
Resume,
}
impl Message {
@ -47,6 +48,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
}
@ -320,7 +322,11 @@ impl AgentMessage {
}
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
let mut content = Vec::with_capacity(self.content.len());
let mut assistant_message = LanguageModelRequestMessage {
role: Role::Assistant,
content: Vec::with_capacity(self.content.len()),
cache: false,
};
for chunk in &self.content {
let chunk = match chunk {
AgentMessageContent::Text(text) => {
@ -342,29 +348,30 @@ impl AgentMessage {
language_model::MessageContent::Image(value.clone())
}
};
content.push(chunk);
assistant_message.content.push(chunk);
}
let mut messages = vec![LanguageModelRequestMessage {
role: Role::Assistant,
content,
let mut user_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
}];
};
if !self.tool_results.is_empty() {
let mut tool_results = Vec::with_capacity(self.tool_results.len());
for tool_result in self.tool_results.values() {
tool_results.push(language_model::MessageContent::ToolResult(
for tool_result in self.tool_results.values() {
user_message
.content
.push(language_model::MessageContent::ToolResult(
tool_result.clone(),
));
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: tool_results,
cache: false,
});
}
let mut messages = Vec::new();
if !assistant_message.content.is_empty() {
messages.push(assistant_message);
}
if !user_message.content.is_empty() {
messages.push(user_message);
}
messages
}
}
@ -413,11 +420,12 @@ pub struct Thread {
running_turn: Option<Task<()>>,
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@ -429,7 +437,7 @@ impl Thread {
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@ -439,11 +447,12 @@ impl Thread {
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
context_server_registry,
profile_id,
project_context,
templates,
selected_model: default_model,
model,
project,
action_log,
}
@ -457,7 +466,19 @@ impl Thread {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
/// Returns the language model this thread is currently configured to use.
pub fn model(&self) -> &Arc<dyn LanguageModel> {
&self.model
}
/// Replaces the language model stored on this thread.
///
/// NOTE(review): `run_turn` clones the model when a turn starts, so a turn
/// already in progress presumably keeps its previous model; confirm.
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
self.model = model;
}
/// Returns the completion mode stored on this thread.
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
/// Stores `mode` as this thread's completion mode.
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -499,36 +520,59 @@ impl Thread {
Ok(())
}
pub fn resume(
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
);
self.messages.push(Message::Resume);
cx.notify();
log::info!("Total messages in thread: {}", self.messages.len());
Ok(self.run_turn(cx))
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send<T>(
&mut self,
message_id: UserMessageId,
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
where
T: Into<UserMessageContent>,
{
let model = self.selected_model.clone();
log::info!("Thread::send called with model: {:?}", self.model.name());
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
self.messages
.push(Message::User(UserMessage { id, content }));
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
self.messages.push(Message::User(UserMessage {
id: message_id.clone(),
content,
}));
log::info!("Total messages in thread: {}", self.messages.len());
self.run_turn(cx)
}
fn run_turn(
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.running_turn = Some(cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
let turn_result = async {
let turn_result: Result<()> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
@ -543,13 +587,22 @@ impl Thread {
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
) => {
tool_use_limit_reached = true;
}
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| this.truncate(message_id))??;
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(());
}
}
@ -567,12 +620,7 @@ impl Thread {
}
}
if tool_uses.is_empty() {
log::info!("No tool uses found, completing turn");
return Ok(());
}
log::info!("Found {} tool uses to execute", tool_uses.len());
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
@ -596,8 +644,17 @@ impl Thread {
.ok();
}
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if used_tools {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
}
}
}
.await;
@ -678,10 +735,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_stream: &AgentResponseEventStream,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_stream.send_text(&new_text);
event_stream.send_text(&new_text);
let last_message = self.pending_message();
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
@ -798,8 +855,9 @@ impl Thread {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let supports_images = self.model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
log::info!("Running tool {}", tool_use.name);
Some(cx.foreground_executor().spawn(async move {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
@ -902,7 +960,7 @@ impl Thread {
name: tool_name,
description: tool.description().to_string(),
input_schema: tool
.input_schema(self.selected_model.tool_input_format())
.input_schema(self.model.tool_input_format())
.log_err()?,
})
})
@ -917,7 +975,7 @@ impl Thread {
thread_id: None,
prompt_id: None,
intent: Some(completion_intent),
mode: Some(self.completion_mode),
mode: Some(self.completion_mode.into()),
messages,
tools,
tool_choice: None,
@ -935,7 +993,7 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
let provider_id = self.selected_model.provider_id();
let provider_id = self.model.provider_id();
Ok(self
.tools
@ -971,6 +1029,11 @@ impl Thread {
match message {
Message::User(message) => messages.push(message.to_request()),
Message::Agent(message) => messages.extend(message.to_request()),
Message::Resume => messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false,
}),
}
}
@ -1123,9 +1186,7 @@ where
}
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
fn send_error(&self, error: impl Into<anyhow::Error>) {
self.0.unbounded_send(Err(error.into())).ok();
}
}
@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
}
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
#[cfg(test)]
impl ToolCallEventStreamReceiver {
@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
fn deref(&self) -> &Self::Target {
&self.0

View file

@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
});
let thread = self.thread.read(cx);
let model = thread.selected_model.clone();
let model = thread.model().clone();
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);