agent2: Port Zed AI features (#36172)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f8b0105258
commit
6f3cd42411
17 changed files with 994 additions and 358 deletions
|
@ -1,9 +1,8 @@
|
|||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||
WebSearchTool,
|
||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
|
@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
|
@ -21,6 +21,7 @@ use prompt_store::{
|
|||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
@ -426,9 +427,9 @@ impl NativeAgent {
|
|||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
||||
let model_id = LanguageModels::model_id(&thread.model());
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.selected_model = model.clone();
|
||||
thread.set_model(model.clone());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -439,6 +440,124 @@ impl NativeAgent {
|
|||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl NativeAgentConnection {
|
||||
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
|
||||
self.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
cx: &mut App,
|
||||
f: impl 'static
|
||||
+ FnOnce(
|
||||
Entity<Thread>,
|
||||
&mut App,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||
}) else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
let mut response_stream = match f(thread, cx) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
cx.spawn(async move |cx| {
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let recv = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
{
|
||||
response
|
||||
.send(option)
|
||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
|
@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
|||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model.clone();
|
||||
thread.set_model(model.clone());
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
|
@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
|||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
let model = thread.read(cx).selected_model.clone();
|
||||
let model = thread.read(cx).model().clone();
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Provider not found")));
|
||||
|
@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let id = id.expect("UserMessageId is required");
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Get session
|
||||
let (thread, acp_thread) = agent
|
||||
.update(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
log::error!("Session not found: {}", session_id);
|
||||
anyhow::anyhow!("Session not found")
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
self.run_turn(session_id, cx, |thread, cx| {
|
||||
let content: Vec<UserMessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
|
@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
log::debug!("Message id: {:?}", id);
|
||||
log::debug!("Message content: {:?}", content);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let recv = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
{
|
||||
response
|
||||
.send(option)
|
||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
// TODO: Consider sending an error message to the UI
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
Ok(thread.update(cx, |thread, cx| {
|
||||
log::info!(
|
||||
"Sending message to thread with model: {:?}",
|
||||
thread.model().name()
|
||||
);
|
||||
thread.send(id, content, cx)
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
|
||||
Some(Rc::new(NativeAgentSessionResume {
|
||||
connection: self.clone(),
|
||||
session_id: session_id.clone(),
|
||||
}) as _)
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
log::info!("Cancelling on session: {}", session_id);
|
||||
self.0.update(cx, |agent, cx| {
|
||||
|
@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||
})
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||
|
@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
|||
}
|
||||
}
|
||||
|
||||
struct NativeAgentSessionResume {
|
||||
connection: NativeAgentConnection,
|
||||
session_id: acp::SessionId,
|
||||
}
|
||||
|
||||
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
||||
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
|
||||
self.connection
|
||||
.run_turn(self.session_id.clone(), cx, |thread, cx| {
|
||||
thread.update(cx, |thread, cx| thread.resume(cx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -957,7 +1007,7 @@ mod tests {
|
|||
agent.read_with(cx, |agent, _| {
|
||||
let session = agent.sessions.get(&session_id).unwrap();
|
||||
session.thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.selected_model.id().0, "fake");
|
||||
assert_eq!(thread.model().id().0, "fake");
|
||||
});
|
||||
});
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue