agent2: Port Zed AI features (#36172)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f8b0105258
commit
6f3cd42411
17 changed files with 994 additions and 358 deletions
|
@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
|||
use acp_thread::{MentionUri, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, AgentSettings};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
use collections::IndexMap;
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
|
@ -14,10 +14,10 @@ use futures::{
|
|||
};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
|
@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
|
|||
pub enum Message {
|
||||
User(UserMessage),
|
||||
Agent(AgentMessage),
|
||||
Resume,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
|
@ -47,6 +48,7 @@ impl Message {
|
|||
match self {
|
||||
Message::User(message) => message.to_markdown(),
|
||||
Message::Agent(message) => message.to_markdown(),
|
||||
Message::Resume => "[resumed after tool use limit was reached]".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -320,7 +322,11 @@ impl AgentMessage {
|
|||
}
|
||||
|
||||
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
let mut content = Vec::with_capacity(self.content.len());
|
||||
let mut assistant_message = LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::with_capacity(self.content.len()),
|
||||
cache: false,
|
||||
};
|
||||
for chunk in &self.content {
|
||||
let chunk = match chunk {
|
||||
AgentMessageContent::Text(text) => {
|
||||
|
@ -342,29 +348,30 @@ impl AgentMessage {
|
|||
language_model::MessageContent::Image(value.clone())
|
||||
}
|
||||
};
|
||||
content.push(chunk);
|
||||
assistant_message.content.push(chunk);
|
||||
}
|
||||
|
||||
let mut messages = vec![LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
let mut user_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
}];
|
||||
};
|
||||
|
||||
if !self.tool_results.is_empty() {
|
||||
let mut tool_results = Vec::with_capacity(self.tool_results.len());
|
||||
for tool_result in self.tool_results.values() {
|
||||
tool_results.push(language_model::MessageContent::ToolResult(
|
||||
for tool_result in self.tool_results.values() {
|
||||
user_message
|
||||
.content
|
||||
.push(language_model::MessageContent::ToolResult(
|
||||
tool_result.clone(),
|
||||
));
|
||||
}
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: tool_results,
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if !assistant_message.content.is_empty() {
|
||||
messages.push(assistant_message);
|
||||
}
|
||||
if !user_message.content.is_empty() {
|
||||
messages.push(user_message);
|
||||
}
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
@ -413,11 +420,12 @@ pub struct Thread {
|
|||
running_turn: Option<Task<()>>,
|
||||
pending_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
tool_use_limit_reached: bool,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
@ -429,7 +437,7 @@ impl Thread {
|
|||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
|
@ -439,11 +447,12 @@ impl Thread {
|
|||
running_turn: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
|
@ -457,7 +466,19 @@ impl Thread {
|
|||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
pub fn model(&self) -> &Arc<dyn LanguageModel> {
|
||||
&self.model
|
||||
}
|
||||
|
||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
|
@ -499,36 +520,59 @@ impl Thread {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn resume(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
||||
anyhow::ensure!(
|
||||
self.tool_use_limit_reached,
|
||||
"can only resume after tool use limit is reached"
|
||||
);
|
||||
|
||||
self.messages.push(Message::Resume);
|
||||
cx.notify();
|
||||
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
Ok(self.run_turn(cx))
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send<T>(
|
||||
&mut self,
|
||||
message_id: UserMessageId,
|
||||
id: UserMessageId,
|
||||
content: impl IntoIterator<Item = T>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
|
||||
where
|
||||
T: Into<UserMessageContent>,
|
||||
{
|
||||
let model = self.selected_model.clone();
|
||||
log::info!("Thread::send called with model: {:?}", self.model.name());
|
||||
|
||||
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
self.messages
|
||||
.push(Message::User(UserMessage { id, content }));
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
|
||||
self.messages.push(Message::User(UserMessage {
|
||||
id: message_id.clone(),
|
||||
content,
|
||||
}));
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.run_turn(cx)
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||
let model = self.model.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
let message_ix = self.messages.len().saturating_sub(1);
|
||||
self.tool_use_limit_reached = false;
|
||||
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result = async {
|
||||
let turn_result: Result<()> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
|
@ -543,13 +587,22 @@ impl Thread {
|
|||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| this.truncate(message_id))??;
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
@ -567,12 +620,7 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
}
|
||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
|
@ -596,8 +644,17 @@ impl Thread {
|
|||
.ok();
|
||||
}
|
||||
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
if tool_use_limit_reached {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if used_tools {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
@ -678,10 +735,10 @@ impl Thread {
|
|||
fn handle_text_event(
|
||||
&mut self,
|
||||
new_text: String,
|
||||
events_stream: &AgentResponseEventStream,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_stream.send_text(&new_text);
|
||||
event_stream.send_text(&new_text);
|
||||
|
||||
let last_message = self.pending_message();
|
||||
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
|
@ -798,8 +855,9 @@ impl Thread {
|
|||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let supports_images = self.selected_model.supports_images();
|
||||
let supports_images = self.model.supports_images();
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
log::info!("Running tool {}", tool_use.name);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
let tool_result = tool_result.await.and_then(|output| {
|
||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||
|
@ -902,7 +960,7 @@ impl Thread {
|
|||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.input_schema(self.model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
|
@ -917,7 +975,7 @@ impl Thread {
|
|||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
mode: Some(self.completion_mode.into()),
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
|
@ -935,7 +993,7 @@ impl Thread {
|
|||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
let provider_id = self.selected_model.provider_id();
|
||||
let provider_id = self.model.provider_id();
|
||||
|
||||
Ok(self
|
||||
.tools
|
||||
|
@ -971,6 +1029,11 @@ impl Thread {
|
|||
match message {
|
||||
Message::User(message) => messages.push(message.to_request()),
|
||||
Message::Agent(message) => messages.extend(message.to_request()),
|
||||
Message::Resume => messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
cache: false,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1123,9 +1186,7 @@ where
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AgentResponseEventStream(
|
||||
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
|
||||
|
||||
impl AgentResponseEventStream {
|
||||
fn send_text(&self, text: &str) {
|
||||
|
@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn send_error(&self, error: LanguageModelCompletionError) {
|
||||
self.0.unbounded_send(Err(error)).ok();
|
||||
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
||||
self.0.unbounded_send(Err(error.into())).ok();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
|
|||
impl ToolCallEventStream {
|
||||
#[cfg(test)]
|
||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
|
@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct ToolCallEventStreamReceiver(
|
||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
|
||||
|
||||
#[cfg(test)]
|
||||
impl ToolCallEventStreamReceiver {
|
||||
|
@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
|
|||
|
||||
#[cfg(test)]
|
||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue