agent2: Port Zed AI features (#36172)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-15 13:17:17 +02:00 committed by GitHub
parent f8b0105258
commit 6f3cd42411
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 994 additions and 358 deletions

View file

@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
@ -14,10 +14,10 @@ use futures::{
};
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
pub enum Message {
User(UserMessage),
Agent(AgentMessage),
Resume,
}
impl Message {
@ -47,6 +48,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
}
@ -320,7 +322,11 @@ impl AgentMessage {
}
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
let mut content = Vec::with_capacity(self.content.len());
let mut assistant_message = LanguageModelRequestMessage {
role: Role::Assistant,
content: Vec::with_capacity(self.content.len()),
cache: false,
};
for chunk in &self.content {
let chunk = match chunk {
AgentMessageContent::Text(text) => {
@ -342,29 +348,30 @@ impl AgentMessage {
language_model::MessageContent::Image(value.clone())
}
};
content.push(chunk);
assistant_message.content.push(chunk);
}
let mut messages = vec![LanguageModelRequestMessage {
role: Role::Assistant,
content,
let mut user_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
}];
};
if !self.tool_results.is_empty() {
let mut tool_results = Vec::with_capacity(self.tool_results.len());
for tool_result in self.tool_results.values() {
tool_results.push(language_model::MessageContent::ToolResult(
for tool_result in self.tool_results.values() {
user_message
.content
.push(language_model::MessageContent::ToolResult(
tool_result.clone(),
));
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: tool_results,
cache: false,
});
}
let mut messages = Vec::new();
if !assistant_message.content.is_empty() {
messages.push(assistant_message);
}
if !user_message.content.is_empty() {
messages.push(user_message);
}
messages
}
}
@ -413,11 +420,12 @@ pub struct Thread {
running_turn: Option<Task<()>>,
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@ -429,7 +437,7 @@ impl Thread {
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@ -439,11 +447,12 @@ impl Thread {
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
context_server_registry,
profile_id,
project_context,
templates,
selected_model: default_model,
model,
project,
action_log,
}
@ -457,7 +466,19 @@ impl Thread {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
pub fn model(&self) -> &Arc<dyn LanguageModel> {
&self.model
}
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
self.model = model;
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -499,36 +520,59 @@ impl Thread {
Ok(())
}
pub fn resume(
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
);
self.messages.push(Message::Resume);
cx.notify();
log::info!("Total messages in thread: {}", self.messages.len());
Ok(self.run_turn(cx))
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send<T>(
&mut self,
message_id: UserMessageId,
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
where
T: Into<UserMessageContent>,
{
let model = self.selected_model.clone();
log::info!("Thread::send called with model: {:?}", self.model.name());
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
self.messages
.push(Message::User(UserMessage { id, content }));
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
self.messages.push(Message::User(UserMessage {
id: message_id.clone(),
content,
}));
log::info!("Total messages in thread: {}", self.messages.len());
self.run_turn(cx)
}
fn run_turn(
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.running_turn = Some(cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
let turn_result = async {
let turn_result: Result<()> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
@ -543,13 +587,22 @@ impl Thread {
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
) => {
tool_use_limit_reached = true;
}
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| this.truncate(message_id))??;
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(());
}
}
@ -567,12 +620,7 @@ impl Thread {
}
}
if tool_uses.is_empty() {
log::info!("No tool uses found, completing turn");
return Ok(());
}
log::info!("Found {} tool uses to execute", tool_uses.len());
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
@ -596,8 +644,17 @@ impl Thread {
.ok();
}
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if used_tools {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
}
}
}
.await;
@ -678,10 +735,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_stream: &AgentResponseEventStream,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_stream.send_text(&new_text);
event_stream.send_text(&new_text);
let last_message = self.pending_message();
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
@ -798,8 +855,9 @@ impl Thread {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let supports_images = self.model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
log::info!("Running tool {}", tool_use.name);
Some(cx.foreground_executor().spawn(async move {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
@ -902,7 +960,7 @@ impl Thread {
name: tool_name,
description: tool.description().to_string(),
input_schema: tool
.input_schema(self.selected_model.tool_input_format())
.input_schema(self.model.tool_input_format())
.log_err()?,
})
})
@ -917,7 +975,7 @@ impl Thread {
thread_id: None,
prompt_id: None,
intent: Some(completion_intent),
mode: Some(self.completion_mode),
mode: Some(self.completion_mode.into()),
messages,
tools,
tool_choice: None,
@ -935,7 +993,7 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
let provider_id = self.selected_model.provider_id();
let provider_id = self.model.provider_id();
Ok(self
.tools
@ -971,6 +1029,11 @@ impl Thread {
match message {
Message::User(message) => messages.push(message.to_request()),
Message::Agent(message) => messages.extend(message.to_request()),
Message::Resume => messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec!["Continue where you left off".into()],
cache: false,
}),
}
}
@ -1123,9 +1186,7 @@ where
}
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
fn send_error(&self, error: impl Into<anyhow::Error>) {
self.0.unbounded_send(Err(error.into())).ok();
}
}
@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
}
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
#[cfg(test)]
impl ToolCallEventStreamReceiver {
@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
fn deref(&self) -> &Self::Target {
&self.0