agent2: Port Zed AI features (#36172)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-15 13:17:17 +02:00 committed by GitHub
parent f8b0105258
commit 6f3cd42411
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 994 additions and 358 deletions

View file

@ -1,9 +1,8 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
WebSearchTool,
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
@ -21,6 +21,7 @@ use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@ -426,9 +427,9 @@ impl NativeAgent {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
thread.set_model(model.clone());
}
});
}
@ -439,6 +440,124 @@ impl NativeAgent {
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl NativeAgentConnection {
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
self.0
.read(cx)
.sessions
.get(session_id)
.map(|session| session.thread.clone())
}
fn run_turn(
&self,
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
+ FnOnce(
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
}) else {
return Task::ready(Err(anyhow!("Session not found")));
};
log::debug!("Found session for: {}", session_id);
let mut response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})?;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
return Err(e);
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
}
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
thread.set_model(model.clone());
});
update_settings_file::<AgentSettings>(
@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let model = thread.read(cx).model().clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
cx.spawn(async move |cx| {
// Get session
let (thread, acp_thread) = agent
.update(cx, |agent, _| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
})?
.ok_or_else(|| {
log::error!("Session not found: {}", session_id);
anyhow::anyhow!("Session not found")
})?;
log::debug!("Found session for: {}", session_id);
self.run_turn(session_id, cx, |thread, cx| {
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})?;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
// TODO: Consider sending an error message to the UI
break;
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
Ok(thread.update(cx, |thread, cx| {
log::info!(
"Sending message to thread with model: {:?}",
thread.model().name()
);
thread.send(id, content, cx)
}))
})
}
fn resume(
&self,
session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
Some(Rc::new(NativeAgentSessionResume {
connection: self.clone(),
session_id: session_id.clone(),
}) as _)
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
struct NativeAgentSessionEditor(Entity<Thread>);
@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
}
}
struct NativeAgentSessionResume {
connection: NativeAgentConnection,
session_id: acp::SessionId,
}
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
self.connection
.run_turn(self.session_id.clone(), cx, |thread, cx| {
thread.update(cx, |thread, cx| thread.resume(cx))
})
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -957,7 +1007,7 @@ mod tests {
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
assert_eq!(thread.model().id().0, "fake");
});
});