agent2: Port Zed AI features (#36172)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f8b0105258
commit
6f3cd42411
17 changed files with 994 additions and 358 deletions
|
@ -33,13 +33,23 @@ pub struct UserMessage {
|
||||||
pub id: Option<UserMessageId>,
|
pub id: Option<UserMessageId>,
|
||||||
pub content: ContentBlock,
|
pub content: ContentBlock,
|
||||||
pub chunks: Vec<acp::ContentBlock>,
|
pub chunks: Vec<acp::ContentBlock>,
|
||||||
pub checkpoint: Option<GitStoreCheckpoint>,
|
pub checkpoint: Option<Checkpoint>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Checkpoint {
|
||||||
|
git_checkpoint: GitStoreCheckpoint,
|
||||||
|
pub show: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserMessage {
|
impl UserMessage {
|
||||||
fn to_markdown(&self, cx: &App) -> String {
|
fn to_markdown(&self, cx: &App) -> String {
|
||||||
let mut markdown = String::new();
|
let mut markdown = String::new();
|
||||||
if let Some(_) = self.checkpoint {
|
if self
|
||||||
|
.checkpoint
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |checkpoint| checkpoint.show)
|
||||||
|
{
|
||||||
writeln!(markdown, "## User (checkpoint)").unwrap();
|
writeln!(markdown, "## User (checkpoint)").unwrap();
|
||||||
} else {
|
} else {
|
||||||
writeln!(markdown, "## User").unwrap();
|
writeln!(markdown, "## User").unwrap();
|
||||||
|
@ -1145,9 +1155,12 @@ impl AcpThread {
|
||||||
self.project.read(cx).languages().clone(),
|
self.project.read(cx).languages().clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
let request = acp::PromptRequest {
|
||||||
|
prompt: message.clone(),
|
||||||
|
session_id: self.session_id.clone(),
|
||||||
|
};
|
||||||
let git_store = self.project.read(cx).git_store().clone();
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
|
|
||||||
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
|
||||||
let message_id = if self
|
let message_id = if self
|
||||||
.connection
|
.connection
|
||||||
.session_editor(&self.session_id, cx)
|
.session_editor(&self.session_id, cx)
|
||||||
|
@ -1161,68 +1174,63 @@ impl AcpThread {
|
||||||
AgentThreadEntry::UserMessage(UserMessage {
|
AgentThreadEntry::UserMessage(UserMessage {
|
||||||
id: message_id.clone(),
|
id: message_id.clone(),
|
||||||
content: block,
|
content: block,
|
||||||
chunks: message.clone(),
|
chunks: message,
|
||||||
checkpoint: None,
|
checkpoint: None,
|
||||||
}),
|
}),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
self.run_turn(cx, async move |this, cx| {
|
||||||
|
let old_checkpoint = git_store
|
||||||
|
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||||
|
.await
|
||||||
|
.context("failed to get old checkpoint")
|
||||||
|
.log_err();
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
if let Some((_ix, message)) = this.last_user_message() {
|
||||||
|
message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
|
||||||
|
git_checkpoint,
|
||||||
|
show: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
this.connection.prompt(message_id, request, cx)
|
||||||
|
})?
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
|
||||||
|
self.run_turn(cx, async move |this, cx| {
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.connection
|
||||||
|
.resume(&this.session_id, cx)
|
||||||
|
.map(|resume| resume.run(cx))
|
||||||
|
})?
|
||||||
|
.context("resuming a session is not supported")?
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_turn(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
|
||||||
|
) -> BoxFuture<'static, Result<()>> {
|
||||||
self.clear_completed_plan_entries(cx);
|
self.clear_completed_plan_entries(cx);
|
||||||
|
|
||||||
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
|
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
let cancel_task = self.cancel(cx);
|
let cancel_task = self.cancel(cx);
|
||||||
let request = acp::PromptRequest {
|
|
||||||
prompt: message,
|
|
||||||
session_id: self.session_id.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.send_task = Some(cx.spawn({
|
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||||
let message_id = message_id.clone();
|
cancel_task.await;
|
||||||
async move |this, cx| {
|
tx.send(f(this, cx).await).ok();
|
||||||
cancel_task.await;
|
|
||||||
|
|
||||||
old_checkpoint_tx.send(old_checkpoint.await).ok();
|
|
||||||
if let Ok(result) = this.update(cx, |this, cx| {
|
|
||||||
this.connection.prompt(message_id, request, cx)
|
|
||||||
}) {
|
|
||||||
tx.send(result.await).log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
let old_checkpoint = old_checkpoint_rx
|
|
||||||
.await
|
|
||||||
.map_err(|_| anyhow!("send canceled"))
|
|
||||||
.flatten()
|
|
||||||
.context("failed to get old checkpoint")
|
|
||||||
.log_err();
|
|
||||||
|
|
||||||
let response = rx.await;
|
let response = rx.await;
|
||||||
|
|
||||||
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
|
this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
|
||||||
let new_checkpoint = git_store
|
.await?;
|
||||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
|
||||||
.await
|
|
||||||
.context("failed to get new checkpoint")
|
|
||||||
.log_err();
|
|
||||||
if let Some(new_checkpoint) = new_checkpoint {
|
|
||||||
let equal = git_store
|
|
||||||
.update(cx, |git, cx| {
|
|
||||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
|
||||||
})?
|
|
||||||
.await
|
|
||||||
.unwrap_or(true);
|
|
||||||
if !equal {
|
|
||||||
this.update(cx, |this, cx| {
|
|
||||||
if let Some((ix, message)) = this.user_message_mut(&message_id) {
|
|
||||||
message.checkpoint = Some(old_checkpoint);
|
|
||||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
match response {
|
match response {
|
||||||
|
@ -1294,7 +1302,10 @@ impl AcpThread {
|
||||||
return Task::ready(Err(anyhow!("message not found")));
|
return Task::ready(Err(anyhow!("message not found")));
|
||||||
};
|
};
|
||||||
|
|
||||||
let checkpoint = message.checkpoint.clone();
|
let checkpoint = message
|
||||||
|
.checkpoint
|
||||||
|
.as_ref()
|
||||||
|
.map(|c| c.git_checkpoint.clone());
|
||||||
|
|
||||||
let git_store = self.project.read(cx).git_store().clone();
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
|
@ -1316,6 +1327,59 @@ impl AcpThread {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||||
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
|
|
||||||
|
let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
|
||||||
|
if let Some(checkpoint) = message.checkpoint.as_ref() {
|
||||||
|
checkpoint.git_checkpoint.clone()
|
||||||
|
} else {
|
||||||
|
return Task::ready(Ok(()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Task::ready(Ok(()));
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
let new_checkpoint = new_checkpoint
|
||||||
|
.await
|
||||||
|
.context("failed to get new checkpoint")
|
||||||
|
.log_err();
|
||||||
|
if let Some(new_checkpoint) = new_checkpoint {
|
||||||
|
let equal = git_store
|
||||||
|
.update(cx, |git, cx| {
|
||||||
|
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||||
|
})?
|
||||||
|
.await
|
||||||
|
.unwrap_or(true);
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
let (ix, message) = this.last_user_message().context("no user message")?;
|
||||||
|
let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
|
||||||
|
checkpoint.show = !equal;
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||||
|
anyhow::Ok(())
|
||||||
|
})??;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
|
||||||
|
self.entries
|
||||||
|
.iter_mut()
|
||||||
|
.enumerate()
|
||||||
|
.rev()
|
||||||
|
.find_map(|(ix, entry)| {
|
||||||
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
|
Some((ix, message))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||||
self.entries.iter().find_map(|entry| {
|
self.entries.iter().find_map(|entry| {
|
||||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
|
@ -1552,6 +1616,7 @@ mod tests {
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt as _;
|
use smol::stream::StreamExt as _;
|
||||||
use std::{
|
use std::{
|
||||||
|
any::Any,
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
path::Path,
|
path::Path,
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
|
@ -2284,6 +2349,10 @@ mod tests {
|
||||||
_session_id: session_id.clone(),
|
_session_id: session_id.clone(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FakeAgentSessionEditor {
|
struct FakeAgentSessionEditor {
|
||||||
|
|
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
||||||
use collections::IndexMap;
|
use collections::IndexMap;
|
||||||
use gpui::{Entity, SharedString, Task};
|
use gpui::{Entity, SharedString, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -36,6 +36,14 @@ pub trait AgentConnection {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>>;
|
) -> Task<Result<acp::PromptResponse>>;
|
||||||
|
|
||||||
|
fn resume(
|
||||||
|
&self,
|
||||||
|
_session_id: &acp::SessionId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Option<Rc<dyn AgentSessionResume>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||||
|
|
||||||
fn session_editor(
|
fn session_editor(
|
||||||
|
@ -53,12 +61,24 @@ pub trait AgentConnection {
|
||||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl dyn AgentConnection {
|
||||||
|
pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
|
||||||
|
self.into_any().downcast().ok()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait AgentSessionEditor {
|
pub trait AgentSessionEditor {
|
||||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait AgentSessionResume {
|
||||||
|
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AuthRequired;
|
pub struct AuthRequired;
|
||||||
|
|
||||||
|
@ -299,6 +319,10 @@ mod test_support {
|
||||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||||
Some(Rc::new(StubAgentSessionEditor))
|
Some(Rc::new(StubAgentSessionEditor))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct StubAgentSessionEditor;
|
struct StubAgentSessionEditor;
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
||||||
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
||||||
WebSearchTool,
|
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
||||||
};
|
};
|
||||||
use acp_thread::AgentModelSelector;
|
use acp_thread::AgentModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use collections::{HashSet, IndexMap};
|
use collections::{HashSet, IndexMap};
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
use futures::channel::mpsc;
|
||||||
use futures::{StreamExt, future};
|
use futures::{StreamExt, future};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||||
|
@ -21,6 +21,7 @@ use prompt_store::{
|
||||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||||
};
|
};
|
||||||
use settings::update_settings_file;
|
use settings::update_settings_file;
|
||||||
|
use std::any::Any;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
@ -426,9 +427,9 @@ impl NativeAgent {
|
||||||
self.models.refresh_list(cx);
|
self.models.refresh_list(cx);
|
||||||
for session in self.sessions.values_mut() {
|
for session in self.sessions.values_mut() {
|
||||||
session.thread.update(cx, |thread, _| {
|
session.thread.update(cx, |thread, _| {
|
||||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
let model_id = LanguageModels::model_id(&thread.model());
|
||||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||||
thread.selected_model = model.clone();
|
thread.set_model(model.clone());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -439,6 +440,124 @@ impl NativeAgent {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||||
|
|
||||||
|
impl NativeAgentConnection {
|
||||||
|
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
|
||||||
|
self.0
|
||||||
|
.read(cx)
|
||||||
|
.sessions
|
||||||
|
.get(session_id)
|
||||||
|
.map(|session| session.thread.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_turn(
|
||||||
|
&self,
|
||||||
|
session_id: acp::SessionId,
|
||||||
|
cx: &mut App,
|
||||||
|
f: impl 'static
|
||||||
|
+ FnOnce(
|
||||||
|
Entity<Thread>,
|
||||||
|
&mut App,
|
||||||
|
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
||||||
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||||
|
agent
|
||||||
|
.sessions
|
||||||
|
.get_mut(&session_id)
|
||||||
|
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||||
|
}) else {
|
||||||
|
return Task::ready(Err(anyhow!("Session not found")));
|
||||||
|
};
|
||||||
|
log::debug!("Found session for: {}", session_id);
|
||||||
|
|
||||||
|
let mut response_stream = match f(thread, cx) {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(err) => return Task::ready(Err(err)),
|
||||||
|
};
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
// Handle response stream and forward to session.acp_thread
|
||||||
|
while let Some(result) = response_stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(event) => {
|
||||||
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::Text(text) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.push_assistant_content_block(
|
||||||
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
text,
|
||||||
|
annotations: None,
|
||||||
|
}),
|
||||||
|
false,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
AgentResponseEvent::Thinking(text) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.push_assistant_content_block(
|
||||||
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
|
text,
|
||||||
|
annotations: None,
|
||||||
|
}),
|
||||||
|
true,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||||
|
tool_call,
|
||||||
|
options,
|
||||||
|
response,
|
||||||
|
}) => {
|
||||||
|
let recv = acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||||
|
})?;
|
||||||
|
cx.background_spawn(async move {
|
||||||
|
if let Some(option) = recv
|
||||||
|
.await
|
||||||
|
.context("authorization sender was dropped")
|
||||||
|
.log_err()
|
||||||
|
{
|
||||||
|
response
|
||||||
|
.send(option)
|
||||||
|
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
AgentResponseEvent::ToolCall(tool_call) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.upsert_tool_call(tool_call, cx)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.update_tool_call(update, cx)
|
||||||
|
})??;
|
||||||
|
}
|
||||||
|
AgentResponseEvent::Stop(stop_reason) => {
|
||||||
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Error in model response stream: {:?}", e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Response stream completed");
|
||||||
|
anyhow::Ok(acp::PromptResponse {
|
||||||
|
stop_reason: acp::StopReason::EndTurn,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AgentModelSelector for NativeAgentConnection {
|
impl AgentModelSelector for NativeAgentConnection {
|
||||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||||
log::debug!("NativeAgentConnection::list_models called");
|
log::debug!("NativeAgentConnection::list_models called");
|
||||||
|
@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||||
};
|
};
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
thread.selected_model = model.clone();
|
thread.set_model(model.clone());
|
||||||
});
|
});
|
||||||
|
|
||||||
update_settings_file::<AgentSettings>(
|
update_settings_file::<AgentSettings>(
|
||||||
|
@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||||
else {
|
else {
|
||||||
return Task::ready(Err(anyhow!("Session not found")));
|
return Task::ready(Err(anyhow!("Session not found")));
|
||||||
};
|
};
|
||||||
let model = thread.read(cx).selected_model.clone();
|
let model = thread.read(cx).model().clone();
|
||||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||||
else {
|
else {
|
||||||
return Task::ready(Err(anyhow!("Provider not found")));
|
return Task::ready(Err(anyhow!("Provider not found")));
|
||||||
|
@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
let id = id.expect("UserMessageId is required");
|
let id = id.expect("UserMessageId is required");
|
||||||
let session_id = params.session_id.clone();
|
let session_id = params.session_id.clone();
|
||||||
let agent = self.0.clone();
|
|
||||||
log::info!("Received prompt request for session: {}", session_id);
|
log::info!("Received prompt request for session: {}", session_id);
|
||||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
self.run_turn(session_id, cx, |thread, cx| {
|
||||||
// Get session
|
|
||||||
let (thread, acp_thread) = agent
|
|
||||||
.update(cx, |agent, _| {
|
|
||||||
agent
|
|
||||||
.sessions
|
|
||||||
.get_mut(&session_id)
|
|
||||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
|
||||||
})?
|
|
||||||
.ok_or_else(|| {
|
|
||||||
log::error!("Session not found: {}", session_id);
|
|
||||||
anyhow::anyhow!("Session not found")
|
|
||||||
})?;
|
|
||||||
log::debug!("Found session for: {}", session_id);
|
|
||||||
|
|
||||||
let content: Vec<UserMessageContent> = params
|
let content: Vec<UserMessageContent> = params
|
||||||
.prompt
|
.prompt
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
log::debug!("Message id: {:?}", id);
|
log::debug!("Message id: {:?}", id);
|
||||||
log::debug!("Message content: {:?}", content);
|
log::debug!("Message content: {:?}", content);
|
||||||
|
|
||||||
// Get model using the ModelSelector capability (always available for agent2)
|
Ok(thread.update(cx, |thread, cx| {
|
||||||
// Get the selected model from the thread directly
|
log::info!(
|
||||||
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
"Sending message to thread with model: {:?}",
|
||||||
|
thread.model().name()
|
||||||
// Send to thread
|
);
|
||||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
thread.send(id, content, cx)
|
||||||
let mut response_stream =
|
}))
|
||||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
|
||||||
|
|
||||||
// Handle response stream and forward to session.acp_thread
|
|
||||||
while let Some(result) = response_stream.next().await {
|
|
||||||
match result {
|
|
||||||
Ok(event) => {
|
|
||||||
log::trace!("Received completion event: {:?}", event);
|
|
||||||
|
|
||||||
match event {
|
|
||||||
AgentResponseEvent::Text(text) => {
|
|
||||||
acp_thread.update(cx, |thread, cx| {
|
|
||||||
thread.push_assistant_content_block(
|
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
|
||||||
text,
|
|
||||||
annotations: None,
|
|
||||||
}),
|
|
||||||
false,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
AgentResponseEvent::Thinking(text) => {
|
|
||||||
acp_thread.update(cx, |thread, cx| {
|
|
||||||
thread.push_assistant_content_block(
|
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
|
||||||
text,
|
|
||||||
annotations: None,
|
|
||||||
}),
|
|
||||||
true,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
|
||||||
tool_call,
|
|
||||||
options,
|
|
||||||
response,
|
|
||||||
}) => {
|
|
||||||
let recv = acp_thread.update(cx, |thread, cx| {
|
|
||||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
|
||||||
})?;
|
|
||||||
cx.background_spawn(async move {
|
|
||||||
if let Some(option) = recv
|
|
||||||
.await
|
|
||||||
.context("authorization sender was dropped")
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
response
|
|
||||||
.send(option)
|
|
||||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
AgentResponseEvent::ToolCall(tool_call) => {
|
|
||||||
acp_thread.update(cx, |thread, cx| {
|
|
||||||
thread.upsert_tool_call(tool_call, cx)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
|
||||||
acp_thread.update(cx, |thread, cx| {
|
|
||||||
thread.update_tool_call(update, cx)
|
|
||||||
})??;
|
|
||||||
}
|
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Error in model response stream: {:?}", e);
|
|
||||||
// TODO: Consider sending an error message to the UI
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log::info!("Response stream completed");
|
|
||||||
anyhow::Ok(acp::PromptResponse {
|
|
||||||
stop_reason: acp::StopReason::EndTurn,
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resume(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
|
||||||
|
Some(Rc::new(NativeAgentSessionResume {
|
||||||
|
connection: self.clone(),
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
}) as _)
|
||||||
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
log::info!("Cancelling on session: {}", session_id);
|
log::info!("Cancelling on session: {}", session_id);
|
||||||
self.0.update(cx, |agent, cx| {
|
self.0.update(cx, |agent, cx| {
|
||||||
|
@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||||
|
@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct NativeAgentSessionResume {
|
||||||
|
connection: NativeAgentConnection,
|
||||||
|
session_id: acp::SessionId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
||||||
|
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
self.connection
|
||||||
|
.run_turn(self.session_id.clone(), cx, |thread, cx| {
|
||||||
|
thread.update(cx, |thread, cx| thread.resume(cx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -957,7 +1007,7 @@ mod tests {
|
||||||
agent.read_with(cx, |agent, _| {
|
agent.read_with(cx, |agent, _| {
|
||||||
let session = agent.sessions.get(&session_id).unwrap();
|
let session = agent.sessions.get(&session_id).unwrap();
|
||||||
session.thread.read_with(cx, |thread, _| {
|
session.thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(thread.selected_model.id().0, "fake");
|
assert_eq!(thread.model().id().0, "fake");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,9 @@ use gpui::{
|
||||||
};
|
};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
|
||||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
||||||
fake_provider::FakeLanguageModel,
|
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
|
@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.add_tool(EchoTool);
|
||||||
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
let tool_use = LanguageModelToolUse {
|
||||||
|
id: "tool_id_1".into(),
|
||||||
|
name: EchoTool.name().into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||||
|
is_input_complete: true,
|
||||||
|
};
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
let tool_result = LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_id_1".into(),
|
||||||
|
tool_name: EchoTool.name().into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "def".into(),
|
||||||
|
output: Some("def".into()),
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["abc".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec![MessageContent::ToolUse(tool_use.clone())],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![MessageContent::ToolResult(tool_result.clone())],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Simulate reaching tool use limit.
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||||
|
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||||
|
assert!(
|
||||||
|
last_event
|
||||||
|
.unwrap_err()
|
||||||
|
.is::<language_model::ToolUseLimitReachedError>()
|
||||||
|
);
|
||||||
|
|
||||||
|
let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["abc".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec![MessageContent::ToolUse(tool_use)],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![MessageContent::ToolResult(tool_result)],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Continue where you left off".into()],
|
||||||
|
cache: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
events.collect::<Vec<_>>().await;
|
||||||
|
thread.read_with(cx, |thread, _cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.last_message().unwrap().to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Done
|
||||||
|
"}
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Ensure we error if calling resume when tool use limit was *not* reached.
|
||||||
|
let error = thread
|
||||||
|
.update(cx, |thread, cx| thread.resume(cx))
|
||||||
|
.unwrap_err();
|
||||||
|
assert_eq!(
|
||||||
|
error.to_string(),
|
||||||
|
"can only resume after tool use limit is reached"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.add_tool(EchoTool);
|
||||||
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let tool_use = LanguageModelToolUse {
|
||||||
|
id: "tool_id_1".into(),
|
||||||
|
name: EchoTool.name().into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||||
|
is_input_complete: true,
|
||||||
|
};
|
||||||
|
let tool_result = LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_id_1".into(),
|
||||||
|
tool_name: EchoTool.name().into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "def".into(),
|
||||||
|
output: Some("def".into()),
|
||||||
|
};
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||||
|
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||||
|
assert!(
|
||||||
|
last_event
|
||||||
|
.unwrap_err()
|
||||||
|
.is::<language_model::ToolUseLimitReachedError>()
|
||||||
|
);
|
||||||
|
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), vec!["ghi"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
completion.messages[1..],
|
||||||
|
vec![
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["abc".into()],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: vec![MessageContent::ToolUse(tool_use)],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![MessageContent::ToolResult(tool_result)],
|
||||||
|
cache: false
|
||||||
|
},
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["ghi".into()],
|
||||||
|
cache: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
async fn expect_tool_call(
|
async fn expect_tool_call(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||||
) -> acp::ToolCall {
|
) -> acp::ToolCall {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -411,7 +597,7 @@ async fn expect_tool_call(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn expect_tool_call_update_fields(
|
async fn expect_tool_call_update_fields(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||||
) -> acp::ToolCallUpdate {
|
) -> acp::ToolCallUpdate {
|
||||||
let event = events
|
let event = events
|
||||||
.next()
|
.next()
|
||||||
|
@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn next_tool_call_authorization(
|
async fn next_tool_call_authorization(
|
||||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||||
) -> ToolCallAuthorization {
|
) -> ToolCallAuthorization {
|
||||||
loop {
|
loop {
|
||||||
let event = events
|
let event = events
|
||||||
|
@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(
|
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
|
||||||
) -> Vec<acp::StopReason> {
|
|
||||||
result_events
|
result_events
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|event| match event.unwrap() {
|
.filter_map(|event| match event.unwrap() {
|
||||||
|
|
|
@ -7,7 +7,7 @@ use std::future;
|
||||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||||
pub struct EchoToolInput {
|
pub struct EchoToolInput {
|
||||||
/// The text to echo.
|
/// The text to echo.
|
||||||
text: String,
|
pub text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct EchoTool;
|
pub struct EchoTool;
|
||||||
|
|
|
@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||||
use acp_thread::{MentionUri, UserMessageId};
|
use acp_thread::{MentionUri, UserMessageId};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{AgentProfileId, AgentSettings};
|
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||||
use collections::IndexMap;
|
use collections::IndexMap;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
|
@ -14,10 +14,10 @@ use futures::{
|
||||||
};
|
};
|
||||||
use gpui::{App, Context, Entity, SharedString, Task};
|
use gpui::{App, Context, Entity, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
|
@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
pub enum Message {
|
pub enum Message {
|
||||||
User(UserMessage),
|
User(UserMessage),
|
||||||
Agent(AgentMessage),
|
Agent(AgentMessage),
|
||||||
|
Resume,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
@ -47,6 +48,7 @@ impl Message {
|
||||||
match self {
|
match self {
|
||||||
Message::User(message) => message.to_markdown(),
|
Message::User(message) => message.to_markdown(),
|
||||||
Message::Agent(message) => message.to_markdown(),
|
Message::Agent(message) => message.to_markdown(),
|
||||||
|
Message::Resume => "[resumed after tool use limit was reached]".into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -320,7 +322,11 @@ impl AgentMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||||
let mut content = Vec::with_capacity(self.content.len());
|
let mut assistant_message = LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: Vec::with_capacity(self.content.len()),
|
||||||
|
cache: false,
|
||||||
|
};
|
||||||
for chunk in &self.content {
|
for chunk in &self.content {
|
||||||
let chunk = match chunk {
|
let chunk = match chunk {
|
||||||
AgentMessageContent::Text(text) => {
|
AgentMessageContent::Text(text) => {
|
||||||
|
@ -342,29 +348,30 @@ impl AgentMessage {
|
||||||
language_model::MessageContent::Image(value.clone())
|
language_model::MessageContent::Image(value.clone())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
content.push(chunk);
|
assistant_message.content.push(chunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut messages = vec![LanguageModelRequestMessage {
|
let mut user_message = LanguageModelRequestMessage {
|
||||||
role: Role::Assistant,
|
role: Role::User,
|
||||||
content,
|
content: Vec::new(),
|
||||||
cache: false,
|
cache: false,
|
||||||
}];
|
};
|
||||||
|
|
||||||
if !self.tool_results.is_empty() {
|
for tool_result in self.tool_results.values() {
|
||||||
let mut tool_results = Vec::with_capacity(self.tool_results.len());
|
user_message
|
||||||
for tool_result in self.tool_results.values() {
|
.content
|
||||||
tool_results.push(language_model::MessageContent::ToolResult(
|
.push(language_model::MessageContent::ToolResult(
|
||||||
tool_result.clone(),
|
tool_result.clone(),
|
||||||
));
|
));
|
||||||
}
|
|
||||||
messages.push(LanguageModelRequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content: tool_results,
|
|
||||||
cache: false,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
if !assistant_message.content.is_empty() {
|
||||||
|
messages.push(assistant_message);
|
||||||
|
}
|
||||||
|
if !user_message.content.is_empty() {
|
||||||
|
messages.push(user_message);
|
||||||
|
}
|
||||||
messages
|
messages
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -413,11 +420,12 @@ pub struct Thread {
|
||||||
running_turn: Option<Task<()>>,
|
running_turn: Option<Task<()>>,
|
||||||
pending_message: Option<AgentMessage>,
|
pending_message: Option<AgentMessage>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
|
tool_use_limit_reached: bool,
|
||||||
context_server_registry: Entity<ContextServerRegistry>,
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
profile_id: AgentProfileId,
|
profile_id: AgentProfileId,
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
pub selected_model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
}
|
}
|
||||||
|
@ -429,7 +437,7 @@ impl Thread {
|
||||||
context_server_registry: Entity<ContextServerRegistry>,
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
default_model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||||
|
@ -439,11 +447,12 @@ impl Thread {
|
||||||
running_turn: None,
|
running_turn: None,
|
||||||
pending_message: None,
|
pending_message: None,
|
||||||
tools: BTreeMap::default(),
|
tools: BTreeMap::default(),
|
||||||
|
tool_use_limit_reached: false,
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
profile_id,
|
profile_id,
|
||||||
project_context,
|
project_context,
|
||||||
templates,
|
templates,
|
||||||
selected_model: default_model,
|
model,
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
}
|
}
|
||||||
|
@ -457,7 +466,19 @@ impl Thread {
|
||||||
&self.action_log
|
&self.action_log
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
pub fn model(&self) -> &Arc<dyn LanguageModel> {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
||||||
|
self.model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn completion_mode(&self) -> CompletionMode {
|
||||||
|
self.completion_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||||
self.completion_mode = mode;
|
self.completion_mode = mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -499,36 +520,59 @@ impl Thread {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn resume(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
||||||
|
anyhow::ensure!(
|
||||||
|
self.tool_use_limit_reached,
|
||||||
|
"can only resume after tool use limit is reached"
|
||||||
|
);
|
||||||
|
|
||||||
|
self.messages.push(Message::Resume);
|
||||||
|
cx.notify();
|
||||||
|
|
||||||
|
log::info!("Total messages in thread: {}", self.messages.len());
|
||||||
|
Ok(self.run_turn(cx))
|
||||||
|
}
|
||||||
|
|
||||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||||
pub fn send<T>(
|
pub fn send<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
message_id: UserMessageId,
|
id: UserMessageId,
|
||||||
content: impl IntoIterator<Item = T>,
|
content: impl IntoIterator<Item = T>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
|
||||||
where
|
where
|
||||||
T: Into<UserMessageContent>,
|
T: Into<UserMessageContent>,
|
||||||
{
|
{
|
||||||
let model = self.selected_model.clone();
|
log::info!("Thread::send called with model: {:?}", self.model.name());
|
||||||
|
|
||||||
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||||
log::info!("Thread::send called with model: {:?}", model.name());
|
|
||||||
log::debug!("Thread::send content: {:?}", content);
|
log::debug!("Thread::send content: {:?}", content);
|
||||||
|
|
||||||
|
self.messages
|
||||||
|
.push(Message::User(UserMessage { id, content }));
|
||||||
cx.notify();
|
cx.notify();
|
||||||
let (events_tx, events_rx) =
|
|
||||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
|
||||||
let event_stream = AgentResponseEventStream(events_tx);
|
|
||||||
|
|
||||||
self.messages.push(Message::User(UserMessage {
|
|
||||||
id: message_id.clone(),
|
|
||||||
content,
|
|
||||||
}));
|
|
||||||
log::info!("Total messages in thread: {}", self.messages.len());
|
log::info!("Total messages in thread: {}", self.messages.len());
|
||||||
|
self.run_turn(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_turn(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||||
|
let model = self.model.clone();
|
||||||
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||||
|
let event_stream = AgentResponseEventStream(events_tx);
|
||||||
|
let message_ix = self.messages.len().saturating_sub(1);
|
||||||
|
self.tool_use_limit_reached = false;
|
||||||
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
let turn_result = async {
|
let turn_result: Result<()> = async {
|
||||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||||
loop {
|
loop {
|
||||||
log::debug!(
|
log::debug!(
|
||||||
|
@ -543,13 +587,22 @@ impl Thread {
|
||||||
let mut events = model.stream_completion(request, cx).await?;
|
let mut events = model.stream_completion(request, cx).await?;
|
||||||
log::debug!("Stream completion started successfully");
|
log::debug!("Stream completion started successfully");
|
||||||
|
|
||||||
|
let mut tool_use_limit_reached = false;
|
||||||
let mut tool_uses = FuturesUnordered::new();
|
let mut tool_uses = FuturesUnordered::new();
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event? {
|
match event? {
|
||||||
|
LanguageModelCompletionEvent::StatusUpdate(
|
||||||
|
CompletionRequestStatus::ToolUseLimitReached,
|
||||||
|
) => {
|
||||||
|
tool_use_limit_reached = true;
|
||||||
|
}
|
||||||
LanguageModelCompletionEvent::Stop(reason) => {
|
LanguageModelCompletionEvent::Stop(reason) => {
|
||||||
event_stream.send_stop(reason);
|
event_stream.send_stop(reason);
|
||||||
if reason == StopReason::Refusal {
|
if reason == StopReason::Refusal {
|
||||||
this.update(cx, |this, _cx| this.truncate(message_id))??;
|
this.update(cx, |this, _cx| {
|
||||||
|
this.flush_pending_message();
|
||||||
|
this.messages.truncate(message_ix);
|
||||||
|
})?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -567,12 +620,7 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if tool_uses.is_empty() {
|
let used_tools = tool_uses.is_empty();
|
||||||
log::info!("No tool uses found, completing turn");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
|
||||||
|
|
||||||
while let Some(tool_result) = tool_uses.next().await {
|
while let Some(tool_result) = tool_uses.next().await {
|
||||||
log::info!("Tool finished {:?}", tool_result);
|
log::info!("Tool finished {:?}", tool_result);
|
||||||
|
|
||||||
|
@ -596,8 +644,17 @@ impl Thread {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
if tool_use_limit_reached {
|
||||||
completion_intent = CompletionIntent::ToolResults;
|
log::info!("Tool use limit reached, completing turn");
|
||||||
|
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||||
|
return Err(language_model::ToolUseLimitReachedError.into());
|
||||||
|
} else if used_tools {
|
||||||
|
log::info!("No tool uses found, completing turn");
|
||||||
|
return Ok(());
|
||||||
|
} else {
|
||||||
|
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||||
|
completion_intent = CompletionIntent::ToolResults;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
@ -678,10 +735,10 @@ impl Thread {
|
||||||
fn handle_text_event(
|
fn handle_text_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
new_text: String,
|
new_text: String,
|
||||||
events_stream: &AgentResponseEventStream,
|
event_stream: &AgentResponseEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
events_stream.send_text(&new_text);
|
event_stream.send_text(&new_text);
|
||||||
|
|
||||||
let last_message = self.pending_message();
|
let last_message = self.pending_message();
|
||||||
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
||||||
|
@ -798,8 +855,9 @@ impl Thread {
|
||||||
status: Some(acp::ToolCallStatus::InProgress),
|
status: Some(acp::ToolCallStatus::InProgress),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
});
|
});
|
||||||
let supports_images = self.selected_model.supports_images();
|
let supports_images = self.model.supports_images();
|
||||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||||
|
log::info!("Running tool {}", tool_use.name);
|
||||||
Some(cx.foreground_executor().spawn(async move {
|
Some(cx.foreground_executor().spawn(async move {
|
||||||
let tool_result = tool_result.await.and_then(|output| {
|
let tool_result = tool_result.await.and_then(|output| {
|
||||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||||
|
@ -902,7 +960,7 @@ impl Thread {
|
||||||
name: tool_name,
|
name: tool_name,
|
||||||
description: tool.description().to_string(),
|
description: tool.description().to_string(),
|
||||||
input_schema: tool
|
input_schema: tool
|
||||||
.input_schema(self.selected_model.tool_input_format())
|
.input_schema(self.model.tool_input_format())
|
||||||
.log_err()?,
|
.log_err()?,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -917,7 +975,7 @@ impl Thread {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
intent: Some(completion_intent),
|
intent: Some(completion_intent),
|
||||||
mode: Some(self.completion_mode),
|
mode: Some(self.completion_mode.into()),
|
||||||
messages,
|
messages,
|
||||||
tools,
|
tools,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
|
@ -935,7 +993,7 @@ impl Thread {
|
||||||
.profiles
|
.profiles
|
||||||
.get(&self.profile_id)
|
.get(&self.profile_id)
|
||||||
.context("profile not found")?;
|
.context("profile not found")?;
|
||||||
let provider_id = self.selected_model.provider_id();
|
let provider_id = self.model.provider_id();
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.tools
|
.tools
|
||||||
|
@ -971,6 +1029,11 @@ impl Thread {
|
||||||
match message {
|
match message {
|
||||||
Message::User(message) => messages.push(message.to_request()),
|
Message::User(message) => messages.push(message.to_request()),
|
||||||
Message::Agent(message) => messages.extend(message.to_request()),
|
Message::Agent(message) => messages.extend(message.to_request()),
|
||||||
|
Message::Resume => messages.push(LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec!["Continue where you left off".into()],
|
||||||
|
cache: false,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1123,9 +1186,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct AgentResponseEventStream(
|
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
|
||||||
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
impl AgentResponseEventStream {
|
impl AgentResponseEventStream {
|
||||||
fn send_text(&self, text: &str) {
|
fn send_text(&self, text: &str) {
|
||||||
|
@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_error(&self, error: LanguageModelCompletionError) {
|
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
||||||
self.0.unbounded_send(Err(error)).ok();
|
self.0.unbounded_send(Err(error.into())).ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
|
||||||
impl ToolCallEventStream {
|
impl ToolCallEventStream {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||||
let (events_tx, events_rx) =
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
|
||||||
|
|
||||||
let stream = ToolCallEventStream::new(
|
let stream = ToolCallEventStream::new(
|
||||||
&LanguageModelToolUse {
|
&LanguageModelToolUse {
|
||||||
|
@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub struct ToolCallEventStreamReceiver(
|
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
|
||||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
impl ToolCallEventStreamReceiver {
|
impl ToolCallEventStreamReceiver {
|
||||||
|
@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
|
|
|
@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
|
||||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||||
});
|
});
|
||||||
let thread = self.thread.read(cx);
|
let thread = self.thread.read(cx);
|
||||||
let model = thread.selected_model.clone();
|
let model = thread.model().clone();
|
||||||
let action_log = thread.action_log().clone();
|
let action_log = thread.action_log().clone();
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = self.authorize(&input, &event_stream, cx);
|
||||||
|
|
|
@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
|
||||||
|
@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection {
|
||||||
})
|
})
|
||||||
.detach_and_log_err(cx)
|
.detach_and_log_err(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,9 +3,9 @@ use anyhow::anyhow;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
use std::{any::Any, cell::RefCell};
|
||||||
|
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
|
@ -191,6 +191,10 @@ impl AgentConnection for AcpConnection {
|
||||||
.spawn(async move { conn.cancel(params).await })
|
.spawn(async move { conn.cancel(params).await })
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ClientDelegate {
|
struct ClientDelegate {
|
||||||
|
|
|
@ -6,6 +6,7 @@ use context_server::listener::McpServerTool;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::process::Child;
|
use smol::process::Child;
|
||||||
|
use std::any::Any;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
@ -289,6 +290,10 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
|
|
|
@ -7,20 +7,21 @@ use action_log::ActionLog;
|
||||||
use agent::{TextThreadStore, ThreadStore};
|
use agent::{TextThreadStore, ThreadStore};
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use agent_servers::AgentServer;
|
use agent_servers::AgentServer;
|
||||||
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
|
use agent_settings::{AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
|
||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use audio::{Audio, Sound};
|
use audio::{Audio, Sound};
|
||||||
use buffer_diff::BufferDiff;
|
use buffer_diff::BufferDiff;
|
||||||
|
use client::zed_urls;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use editor::scroll::Autoscroll;
|
use editor::scroll::Autoscroll;
|
||||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
|
||||||
use file_icons::FileIcons;
|
use file_icons::FileIcons;
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
|
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
|
||||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
|
Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
|
||||||
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
|
PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
|
||||||
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
|
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
|
||||||
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||||
};
|
};
|
||||||
use language::Buffer;
|
use language::Buffer;
|
||||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||||
|
@ -32,8 +33,8 @@ use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
|
||||||
use text::Anchor;
|
use text::Anchor;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{
|
use ui::{
|
||||||
Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState,
|
Callout, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle,
|
||||||
Tooltip, prelude::*,
|
Scrollbar, ScrollbarState, Tooltip, prelude::*,
|
||||||
};
|
};
|
||||||
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
|
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
|
||||||
use workspace::{CollaboratorId, Workspace};
|
use workspace::{CollaboratorId, Workspace};
|
||||||
|
@ -44,16 +45,39 @@ use super::entry_view_state::EntryViewState;
|
||||||
use crate::acp::AcpModelSelectorPopover;
|
use crate::acp::AcpModelSelectorPopover;
|
||||||
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||||
use crate::agent_diff::AgentDiff;
|
use crate::agent_diff::AgentDiff;
|
||||||
use crate::ui::{AgentNotification, AgentNotificationEvent};
|
use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip};
|
||||||
use crate::{
|
use crate::{
|
||||||
AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
|
AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow,
|
||||||
|
KeepAll, OpenAgentDiff, RejectAll, ToggleBurnMode,
|
||||||
};
|
};
|
||||||
|
|
||||||
const RESPONSE_PADDING_X: Pixels = px(19.);
|
const RESPONSE_PADDING_X: Pixels = px(19.);
|
||||||
|
|
||||||
pub const MIN_EDITOR_LINES: usize = 4;
|
pub const MIN_EDITOR_LINES: usize = 4;
|
||||||
pub const MAX_EDITOR_LINES: usize = 8;
|
pub const MAX_EDITOR_LINES: usize = 8;
|
||||||
|
|
||||||
|
enum ThreadError {
|
||||||
|
PaymentRequired,
|
||||||
|
ModelRequestLimitReached(cloud_llm_client::Plan),
|
||||||
|
ToolUseLimitReached,
|
||||||
|
Other(SharedString),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadError {
|
||||||
|
fn from_err(error: anyhow::Error) -> Self {
|
||||||
|
if error.is::<language_model::PaymentRequiredError>() {
|
||||||
|
Self::PaymentRequired
|
||||||
|
} else if error.is::<language_model::ToolUseLimitReachedError>() {
|
||||||
|
Self::ToolUseLimitReached
|
||||||
|
} else if let Some(error) =
|
||||||
|
error.downcast_ref::<language_model::ModelRequestLimitReachedError>()
|
||||||
|
{
|
||||||
|
Self::ModelRequestLimitReached(error.plan)
|
||||||
|
} else {
|
||||||
|
Self::Other(error.to_string().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AcpThreadView {
|
pub struct AcpThreadView {
|
||||||
agent: Rc<dyn AgentServer>,
|
agent: Rc<dyn AgentServer>,
|
||||||
workspace: WeakEntity<Workspace>,
|
workspace: WeakEntity<Workspace>,
|
||||||
|
@ -66,7 +90,7 @@ pub struct AcpThreadView {
|
||||||
model_selector: Option<Entity<AcpModelSelectorPopover>>,
|
model_selector: Option<Entity<AcpModelSelectorPopover>>,
|
||||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||||
last_error: Option<Entity<Markdown>>,
|
thread_error: Option<ThreadError>,
|
||||||
list_state: ListState,
|
list_state: ListState,
|
||||||
scrollbar_state: ScrollbarState,
|
scrollbar_state: ScrollbarState,
|
||||||
auth_task: Option<Task<()>>,
|
auth_task: Option<Task<()>>,
|
||||||
|
@ -151,7 +175,7 @@ impl AcpThreadView {
|
||||||
entry_view_state: EntryViewState::default(),
|
entry_view_state: EntryViewState::default(),
|
||||||
list_state: list_state.clone(),
|
list_state: list_state.clone(),
|
||||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||||
last_error: None,
|
thread_error: None,
|
||||||
auth_task: None,
|
auth_task: None,
|
||||||
expanded_tool_calls: HashSet::default(),
|
expanded_tool_calls: HashSet::default(),
|
||||||
expanded_thinking_blocks: HashSet::default(),
|
expanded_thinking_blocks: HashSet::default(),
|
||||||
|
@ -316,7 +340,7 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
|
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
|
||||||
self.last_error.take();
|
self.thread_error.take();
|
||||||
|
|
||||||
if let Some(thread) = self.thread() {
|
if let Some(thread) = self.thread() {
|
||||||
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
||||||
|
@ -371,6 +395,25 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resume_chat(&mut self, cx: &mut Context<Self>) {
|
||||||
|
self.thread_error.take();
|
||||||
|
let Some(thread) = self.thread() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let task = thread.update(cx, |thread, cx| thread.resume(cx));
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
let result = task.await;
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
if let Err(err) = result {
|
||||||
|
this.handle_thread_error(err, cx);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
|
||||||
fn send(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
fn send(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||||
let contents = self
|
let contents = self
|
||||||
.message_editor
|
.message_editor
|
||||||
|
@ -384,7 +427,7 @@ impl AcpThreadView {
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
self.last_error.take();
|
self.thread_error.take();
|
||||||
self.editing_message.take();
|
self.editing_message.take();
|
||||||
|
|
||||||
let Some(thread) = self.thread().cloned() else {
|
let Some(thread) = self.thread().cloned() else {
|
||||||
|
@ -409,11 +452,9 @@ impl AcpThreadView {
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
if let Err(e) = task.await {
|
if let Err(err) = task.await {
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.last_error =
|
this.handle_thread_error(err, cx);
|
||||||
Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx)));
|
|
||||||
cx.notify()
|
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
@ -476,6 +517,16 @@ impl AcpThreadView {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context<Self>) {
|
||||||
|
self.thread_error = Some(ThreadError::from_err(error));
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_thread_error(&mut self, cx: &mut Context<Self>) {
|
||||||
|
self.thread_error = None;
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_thread_event(
|
fn handle_thread_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
thread: &Entity<AcpThread>,
|
thread: &Entity<AcpThread>,
|
||||||
|
@ -551,7 +602,7 @@ impl AcpThreadView {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
self.last_error.take();
|
self.thread_error.take();
|
||||||
let authenticate = connection.authenticate(method, cx);
|
let authenticate = connection.authenticate(method, cx);
|
||||||
self.auth_task = Some(cx.spawn_in(window, {
|
self.auth_task = Some(cx.spawn_in(window, {
|
||||||
let project = self.project.clone();
|
let project = self.project.clone();
|
||||||
|
@ -561,9 +612,7 @@ impl AcpThreadView {
|
||||||
|
|
||||||
this.update_in(cx, |this, window, cx| {
|
this.update_in(cx, |this, window, cx| {
|
||||||
if let Err(err) = result {
|
if let Err(err) = result {
|
||||||
this.last_error = Some(cx.new(|cx| {
|
this.handle_thread_error(err, cx);
|
||||||
Markdown::new(format!("Error: {err}").into(), None, None, cx)
|
|
||||||
}))
|
|
||||||
} else {
|
} else {
|
||||||
this.thread_state = Self::initial_state(
|
this.thread_state = Self::initial_state(
|
||||||
agent,
|
agent,
|
||||||
|
@ -620,9 +669,7 @@ impl AcpThreadView {
|
||||||
.py_4()
|
.py_4()
|
||||||
.px_2()
|
.px_2()
|
||||||
.children(message.id.clone().and_then(|message_id| {
|
.children(message.id.clone().and_then(|message_id| {
|
||||||
message.checkpoint.as_ref()?;
|
message.checkpoint.as_ref()?.show.then(|| {
|
||||||
|
|
||||||
Some(
|
|
||||||
Button::new("restore-checkpoint", "Restore Checkpoint")
|
Button::new("restore-checkpoint", "Restore Checkpoint")
|
||||||
.icon(IconName::Undo)
|
.icon(IconName::Undo)
|
||||||
.icon_size(IconSize::XSmall)
|
.icon_size(IconSize::XSmall)
|
||||||
|
@ -630,8 +677,8 @@ impl AcpThreadView {
|
||||||
.label_size(LabelSize::XSmall)
|
.label_size(LabelSize::XSmall)
|
||||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||||
this.rewind(&message_id, cx);
|
this.rewind(&message_id, cx);
|
||||||
})),
|
}))
|
||||||
)
|
})
|
||||||
}))
|
}))
|
||||||
.child(
|
.child(
|
||||||
v_flex()
|
v_flex()
|
||||||
|
@ -2322,7 +2369,12 @@ impl AcpThreadView {
|
||||||
h_flex()
|
h_flex()
|
||||||
.flex_none()
|
.flex_none()
|
||||||
.justify_between()
|
.justify_between()
|
||||||
.child(self.render_follow_toggle(cx))
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1()
|
||||||
|
.child(self.render_follow_toggle(cx))
|
||||||
|
.children(self.render_burn_mode_toggle(cx)),
|
||||||
|
)
|
||||||
.child(
|
.child(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_1()
|
.gap_1()
|
||||||
|
@ -2333,6 +2385,68 @@ impl AcpThreadView {
|
||||||
.into_any()
|
.into_any()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn as_native_connection(&self, cx: &App) -> Option<Rc<agent2::NativeAgentConnection>> {
|
||||||
|
let acp_thread = self.thread()?.read(cx);
|
||||||
|
acp_thread.connection().clone().downcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_native_thread(&self, cx: &App) -> Option<Entity<agent2::Thread>> {
|
||||||
|
let acp_thread = self.thread()?.read(cx);
|
||||||
|
self.as_native_connection(cx)?
|
||||||
|
.thread(acp_thread.session_id(), cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn toggle_burn_mode(
|
||||||
|
&mut self,
|
||||||
|
_: &ToggleBurnMode,
|
||||||
|
_window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(thread) = self.as_native_thread(cx) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
thread.update(cx, |thread, _cx| {
|
||||||
|
let current_mode = thread.completion_mode();
|
||||||
|
thread.set_completion_mode(match current_mode {
|
||||||
|
CompletionMode::Burn => CompletionMode::Normal,
|
||||||
|
CompletionMode::Normal => CompletionMode::Burn,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||||
|
let thread = self.as_native_thread(cx)?.read(cx);
|
||||||
|
|
||||||
|
if !thread.model().supports_burn_mode() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let active_completion_mode = thread.completion_mode();
|
||||||
|
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
|
||||||
|
let icon = if burn_mode_enabled {
|
||||||
|
IconName::ZedBurnModeOn
|
||||||
|
} else {
|
||||||
|
IconName::ZedBurnMode
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(
|
||||||
|
IconButton::new("burn-mode", icon)
|
||||||
|
.icon_size(IconSize::Small)
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.toggle_state(burn_mode_enabled)
|
||||||
|
.selected_icon_color(Color::Error)
|
||||||
|
.on_click(cx.listener(|this, _event, window, cx| {
|
||||||
|
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
|
||||||
|
}))
|
||||||
|
.tooltip(move |_window, cx| {
|
||||||
|
cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
|
||||||
|
.into()
|
||||||
|
})
|
||||||
|
.into_any_element(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
|
fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
|
||||||
let Some(thread) = self.thread() else {
|
let Some(thread) = self.thread() else {
|
||||||
return;
|
return;
|
||||||
|
@ -3002,6 +3116,187 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AcpThreadView {
|
||||||
|
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
|
||||||
|
let content = match self.thread_error.as_ref()? {
|
||||||
|
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
|
||||||
|
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
|
||||||
|
ThreadError::ModelRequestLimitReached(plan) => {
|
||||||
|
self.render_model_request_limit_reached_error(*plan, cx)
|
||||||
|
}
|
||||||
|
ThreadError::ToolUseLimitReached => {
|
||||||
|
self.render_tool_use_limit_reached_error(window, cx)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(
|
||||||
|
div()
|
||||||
|
.border_t_1()
|
||||||
|
.border_color(cx.theme().colors().border)
|
||||||
|
.child(content),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout {
|
||||||
|
let icon = Icon::new(IconName::XCircle)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.color(Color::Error);
|
||||||
|
|
||||||
|
Callout::new()
|
||||||
|
.icon(icon)
|
||||||
|
.title("Error")
|
||||||
|
.description(error.clone())
|
||||||
|
.secondary_action(self.create_copy_button(error.to_string()))
|
||||||
|
.primary_action(self.dismiss_error_button(cx))
|
||||||
|
.bg_color(self.error_callout_bg(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_payment_required_error(&self, cx: &mut Context<Self>) -> Callout {
|
||||||
|
const ERROR_MESSAGE: &str =
|
||||||
|
"You reached your free usage limit. Upgrade to Zed Pro for more prompts.";
|
||||||
|
|
||||||
|
let icon = Icon::new(IconName::XCircle)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.color(Color::Error);
|
||||||
|
|
||||||
|
Callout::new()
|
||||||
|
.icon(icon)
|
||||||
|
.title("Free Usage Exceeded")
|
||||||
|
.description(ERROR_MESSAGE)
|
||||||
|
.tertiary_action(self.upgrade_button(cx))
|
||||||
|
.secondary_action(self.create_copy_button(ERROR_MESSAGE))
|
||||||
|
.primary_action(self.dismiss_error_button(cx))
|
||||||
|
.bg_color(self.error_callout_bg(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_model_request_limit_reached_error(
|
||||||
|
&self,
|
||||||
|
plan: cloud_llm_client::Plan,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Callout {
|
||||||
|
let error_message = match plan {
|
||||||
|
cloud_llm_client::Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
|
||||||
|
cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => {
|
||||||
|
"Upgrade to Zed Pro for more prompts."
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let icon = Icon::new(IconName::XCircle)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.color(Color::Error);
|
||||||
|
|
||||||
|
Callout::new()
|
||||||
|
.icon(icon)
|
||||||
|
.title("Model Prompt Limit Reached")
|
||||||
|
.description(error_message)
|
||||||
|
.tertiary_action(self.upgrade_button(cx))
|
||||||
|
.secondary_action(self.create_copy_button(error_message))
|
||||||
|
.primary_action(self.dismiss_error_button(cx))
|
||||||
|
.bg_color(self.error_callout_bg(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_tool_use_limit_reached_error(
|
||||||
|
&self,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Option<Callout> {
|
||||||
|
let thread = self.as_native_thread(cx)?;
|
||||||
|
let supports_burn_mode = thread.read(cx).model().supports_burn_mode();
|
||||||
|
|
||||||
|
let focus_handle = self.focus_handle(cx);
|
||||||
|
|
||||||
|
let icon = Icon::new(IconName::Info)
|
||||||
|
.size(IconSize::Small)
|
||||||
|
.color(Color::Info);
|
||||||
|
|
||||||
|
Some(
|
||||||
|
Callout::new()
|
||||||
|
.icon(icon)
|
||||||
|
.title("Consecutive tool use limit reached.")
|
||||||
|
.when(supports_burn_mode, |this| {
|
||||||
|
this.secondary_action(
|
||||||
|
Button::new("continue-burn-mode", "Continue with Burn Mode")
|
||||||
|
.style(ButtonStyle::Filled)
|
||||||
|
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||||
|
.layer(ElevationIndex::ModalSurface)
|
||||||
|
.label_size(LabelSize::Small)
|
||||||
|
.key_binding(
|
||||||
|
KeyBinding::for_action_in(
|
||||||
|
&ContinueWithBurnMode,
|
||||||
|
&focus_handle,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.map(|kb| kb.size(rems_from_px(10.))),
|
||||||
|
)
|
||||||
|
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
|
||||||
|
.on_click({
|
||||||
|
cx.listener(move |this, _, _window, cx| {
|
||||||
|
thread.update(cx, |thread, _cx| {
|
||||||
|
thread.set_completion_mode(CompletionMode::Burn);
|
||||||
|
});
|
||||||
|
this.resume_chat(cx);
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.primary_action(
|
||||||
|
Button::new("continue-conversation", "Continue")
|
||||||
|
.layer(ElevationIndex::ModalSurface)
|
||||||
|
.label_size(LabelSize::Small)
|
||||||
|
.key_binding(
|
||||||
|
KeyBinding::for_action_in(&ContinueThread, &focus_handle, window, cx)
|
||||||
|
.map(|kb| kb.size(rems_from_px(10.))),
|
||||||
|
)
|
||||||
|
.on_click(cx.listener(|this, _, _window, cx| {
|
||||||
|
this.resume_chat(cx);
|
||||||
|
})),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
|
||||||
|
let message = message.into();
|
||||||
|
|
||||||
|
IconButton::new("copy", IconName::Copy)
|
||||||
|
.icon_size(IconSize::Small)
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.tooltip(Tooltip::text("Copy Error Message"))
|
||||||
|
.on_click(move |_, _, cx| {
|
||||||
|
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dismiss_error_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
|
IconButton::new("dismiss", IconName::Close)
|
||||||
|
.icon_size(IconSize::Small)
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.tooltip(Tooltip::text("Dismiss Error"))
|
||||||
|
.on_click(cx.listener({
|
||||||
|
move |this, _, _, cx| {
|
||||||
|
this.clear_thread_error(cx);
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upgrade_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
|
Button::new("upgrade", "Upgrade")
|
||||||
|
.label_size(LabelSize::Small)
|
||||||
|
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||||
|
.on_click(cx.listener({
|
||||||
|
move |this, _, _, cx| {
|
||||||
|
this.clear_thread_error(cx);
|
||||||
|
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx));
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_callout_bg(&self, cx: &Context<Self>) -> Hsla {
|
||||||
|
cx.theme().status().error.opacity(0.08)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Focusable for AcpThreadView {
|
impl Focusable for AcpThreadView {
|
||||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||||
self.message_editor.focus_handle(cx)
|
self.message_editor.focus_handle(cx)
|
||||||
|
@ -3016,6 +3311,7 @@ impl Render for AcpThreadView {
|
||||||
.size_full()
|
.size_full()
|
||||||
.key_context("AcpThread")
|
.key_context("AcpThread")
|
||||||
.on_action(cx.listener(Self::open_agent_diff))
|
.on_action(cx.listener(Self::open_agent_diff))
|
||||||
|
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||||
.bg(cx.theme().colors().panel_background)
|
.bg(cx.theme().colors().panel_background)
|
||||||
.child(match &self.thread_state {
|
.child(match &self.thread_state {
|
||||||
ThreadState::Unauthenticated { connection } => v_flex()
|
ThreadState::Unauthenticated { connection } => v_flex()
|
||||||
|
@ -3100,19 +3396,7 @@ impl Render for AcpThreadView {
|
||||||
}
|
}
|
||||||
_ => this,
|
_ => this,
|
||||||
})
|
})
|
||||||
.when_some(self.last_error.clone(), |el, error| {
|
.children(self.render_thread_error(window, cx))
|
||||||
el.child(
|
|
||||||
div()
|
|
||||||
.p_2()
|
|
||||||
.text_xs()
|
|
||||||
.border_t_1()
|
|
||||||
.border_color(cx.theme().colors().border)
|
|
||||||
.bg(cx.theme().status().error_background)
|
|
||||||
.child(
|
|
||||||
self.render_markdown(error, default_markdown_style(false, window, cx)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.child(self.render_message_editor(window, cx))
|
.child(self.render_message_editor(window, cx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3299,8 +3583,6 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use acp_thread::StubAgentConnection;
|
use acp_thread::StubAgentConnection;
|
||||||
use agent::{TextThreadStore, ThreadStore};
|
use agent::{TextThreadStore, ThreadStore};
|
||||||
use agent_client_protocol::SessionId;
|
use agent_client_protocol::SessionId;
|
||||||
|
@ -3310,6 +3592,8 @@ pub(crate) mod tests {
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
use std::any::Any;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -3547,6 +3831,10 @@ pub(crate) mod tests {
|
||||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn init_test(cx: &mut TestAppContext) {
|
pub(crate) fn init_test(cx: &mut TestAppContext) {
|
||||||
|
|
|
@ -5,7 +5,6 @@ mod agent_diff;
|
||||||
mod agent_model_selector;
|
mod agent_model_selector;
|
||||||
mod agent_panel;
|
mod agent_panel;
|
||||||
mod buffer_codegen;
|
mod buffer_codegen;
|
||||||
mod burn_mode_tooltip;
|
|
||||||
mod context_picker;
|
mod context_picker;
|
||||||
mod context_server_configuration;
|
mod context_server_configuration;
|
||||||
mod context_strip;
|
mod context_strip;
|
||||||
|
|
|
@ -1,61 +0,0 @@
|
||||||
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
|
||||||
use ui::{prelude::*, tooltip_container};
|
|
||||||
|
|
||||||
pub struct BurnModeTooltip {
|
|
||||||
selected: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BurnModeTooltip {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self { selected: false }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn selected(mut self, selected: bool) -> Self {
|
|
||||||
self.selected = selected;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Render for BurnModeTooltip {
|
|
||||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
|
||||||
let (icon, color) = if self.selected {
|
|
||||||
(IconName::ZedBurnModeOn, Color::Error)
|
|
||||||
} else {
|
|
||||||
(IconName::ZedBurnMode, Color::Default)
|
|
||||||
};
|
|
||||||
|
|
||||||
let turned_on = h_flex()
|
|
||||||
.h_4()
|
|
||||||
.px_1()
|
|
||||||
.border_1()
|
|
||||||
.border_color(cx.theme().colors().border)
|
|
||||||
.bg(cx.theme().colors().text_accent.opacity(0.1))
|
|
||||||
.rounded_sm()
|
|
||||||
.child(
|
|
||||||
Label::new("ON")
|
|
||||||
.size(LabelSize::XSmall)
|
|
||||||
.weight(FontWeight::SEMIBOLD)
|
|
||||||
.color(Color::Accent),
|
|
||||||
);
|
|
||||||
|
|
||||||
let title = h_flex()
|
|
||||||
.gap_1p5()
|
|
||||||
.child(Icon::new(icon).size(IconSize::Small).color(color))
|
|
||||||
.child(Label::new("Burn Mode"))
|
|
||||||
.when(self.selected, |title| title.child(turned_on));
|
|
||||||
|
|
||||||
tooltip_container(window, cx, |this, _, _| {
|
|
||||||
this
|
|
||||||
.child(title)
|
|
||||||
.child(
|
|
||||||
div()
|
|
||||||
.max_w_64()
|
|
||||||
.child(
|
|
||||||
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
|
|
||||||
.size(LabelSize::Small)
|
|
||||||
.color(Color::Muted)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread;
|
||||||
use crate::agent_model_selector::AgentModelSelector;
|
use crate::agent_model_selector::AgentModelSelector;
|
||||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||||
use crate::ui::{
|
use crate::ui::{
|
||||||
MaxModeTooltip,
|
BurnModeTooltip,
|
||||||
preview::{AgentPreview, UsageCallout},
|
preview::{AgentPreview, UsageCallout},
|
||||||
};
|
};
|
||||||
use agent::history_store::HistoryStore;
|
use agent::history_store::HistoryStore;
|
||||||
|
@ -605,7 +605,7 @@ impl MessageEditor {
|
||||||
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
|
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
|
||||||
}))
|
}))
|
||||||
.tooltip(move |_window, cx| {
|
.tooltip(move |_window, cx| {
|
||||||
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
|
cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
|
||||||
.into()
|
.into()
|
||||||
})
|
})
|
||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
burn_mode_tooltip::BurnModeTooltip,
|
|
||||||
language_model_selector::{LanguageModelSelector, language_model_selector},
|
language_model_selector::{LanguageModelSelector, language_model_selector},
|
||||||
|
ui::BurnModeTooltip,
|
||||||
};
|
};
|
||||||
use agent_settings::{AgentSettings, CompletionMode};
|
use agent_settings::{AgentSettings, CompletionMode};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
|
@ -2,11 +2,11 @@ use crate::ToggleBurnMode;
|
||||||
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
||||||
use ui::{KeyBinding, prelude::*, tooltip_container};
|
use ui::{KeyBinding, prelude::*, tooltip_container};
|
||||||
|
|
||||||
pub struct MaxModeTooltip {
|
pub struct BurnModeTooltip {
|
||||||
selected: bool,
|
selected: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MaxModeTooltip {
|
impl BurnModeTooltip {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self { selected: false }
|
Self { selected: false }
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ impl MaxModeTooltip {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Render for MaxModeTooltip {
|
impl Render for BurnModeTooltip {
|
||||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||||
let (icon, color) = if self.selected {
|
let (icon, color) = if self.selected {
|
||||||
(IconName::ZedBurnModeOn, Color::Error)
|
(IconName::ZedBurnModeOn, Color::Error)
|
||||||
|
|
|
@ -42,6 +42,18 @@ impl fmt::Display for ModelRequestLimitReachedError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub struct ToolUseLimitReachedError;
|
||||||
|
|
||||||
|
impl fmt::Display for ToolUseLimitReachedError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue