ZIm/crates/agent2/src/agent.rs
2025-08-18 14:29:33 -06:00

1332 lines
48 KiB
Rust

use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
};
use crate::{ThreadsDatabase, generate_session_id};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use util::ResultExt;
/// File names, relative to a worktree root, that are recognized as rules files.
///
/// Order matters: when a worktree is scanned for its rules file, the first name in
/// this list that exists as a file is the one that gets loaded.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
/// Delay that `NativeAgent::save_thread` waits before writing a changed thread to the
/// threads database.
const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500);
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
/// The formatted text of the underlying error.
pub message: SharedString,
}
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
/// The internal thread that processes messages
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication.
/// Held weakly; the session entry is removed once the ACP thread is released
/// (see `NativeAgent::insert_session`).
acp_thread: WeakEntity<acp_thread::AcpThread>,
/// The most recently scheduled debounced save of `thread`; overwritten by every call to
/// `NativeAgent::save_thread`.
save_task: Task<Result<()>>,
/// Subscriptions created in `NativeAgent::insert_session`, stored here so that they are
/// owned by the session entry.
_subscriptions: Vec<Subscription>,
}
/// Cache of the language models offered by the authenticated providers, kept both as an
/// ID lookup table and as a grouped list.
pub struct LanguageModels {
/// Access language model by ID
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
/// Cached list for returning language model information
model_list: acp_thread::AgentModelList,
/// Receiving half of the refresh notification channel; cloned and handed out by `watch`.
refresh_models_rx: watch::Receiver<()>,
/// Sending half of the refresh notification channel; signalled at the end of `refresh_list`.
refresh_models_tx: watch::Sender<()>,
}
impl LanguageModels {
    /// Builds the cache and fills it once from the global [`LanguageModelRegistry`].
    fn new(cx: &App) -> Self {
        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
        let mut this = Self {
            models: HashMap::default(),
            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
            refresh_models_rx,
            refresh_models_tx,
        };
        this.refresh_list(cx);
        this
    }

    /// Rebuilds the ID lookup table and the grouped model list, then notifies watchers.
    ///
    /// Only authenticated providers are considered. The recommended models of all
    /// providers are collected into a leading "Recommended" group; every provider then
    /// gets its own group holding the models that were not already recommended. Groups
    /// that would be empty are omitted. Every provided model, recommended or not, is
    /// inserted into the lookup table used by [`Self::model_from_id`].
    fn refresh_list(&mut self, cx: &App) {
        let providers = LanguageModelRegistry::global(cx)
            .read(cx)
            .providers()
            .into_iter()
            .filter(|provider| provider.is_authenticated(cx))
            .collect::<Vec<_>>();

        let mut language_model_list = IndexMap::default();
        let mut recommended_models = HashSet::default();
        let mut recommended = Vec::new();
        for provider in &providers {
            for model in provider.recommended_models(cx) {
                recommended_models.insert(model.id());
                recommended.push(Self::map_language_model_to_info(&model, provider));
            }
        }
        if !recommended.is_empty() {
            language_model_list.insert(
                acp_thread::AgentModelGroupName("Recommended".into()),
                recommended,
            );
        }

        let mut models = HashMap::default();
        for provider in providers {
            let mut provider_models = Vec::new();
            for model in provider.provided_models(cx) {
                let model_info = Self::map_language_model_to_info(&model, &provider);
                let model_id = model_info.id.clone();
                // Recommended models are already listed in the first group, but they must
                // still be resolvable by ID, so only the list entry is skipped.
                if !recommended_models.contains(&model.id()) {
                    provider_models.push(model_info);
                }
                models.insert(model_id, model);
            }
            if !provider_models.is_empty() {
                language_model_list.insert(
                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
                    provider_models,
                );
            }
        }

        self.models = models;
        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
        // A failed send is deliberately ignored.
        self.refresh_models_tx.send(()).ok();
    }

    /// Returns a receiver that is signalled after each [`Self::refresh_list`].
    fn watch(&self) -> watch::Receiver<()> {
        self.refresh_models_rx.clone()
    }

    /// Looks up a cached model by its `provider/model` identifier.
    ///
    /// Returns `None` when no model with that ID is in the cache.
    pub fn model_from_id(
        &self,
        model_id: &acp_thread::AgentModelId,
    ) -> Option<Arc<dyn LanguageModel>> {
        self.models.get(model_id).cloned()
    }

    /// Converts a model and its provider into the info record exposed through ACP:
    /// the agent-level ID, the model's display name and the provider's icon.
    fn map_language_model_to_info(
        model: &Arc<dyn LanguageModel>,
        provider: &Arc<dyn LanguageModelProvider>,
    ) -> acp_thread::AgentModelInfo {
        acp_thread::AgentModelInfo {
            id: Self::model_id(model),
            name: model.name().0,
            icon: Some(provider.icon()),
        }
    }

    /// Builds the agent-level model ID, formatted as `<provider id>/<model id>`.
    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
    }
}
/// The in-process agent: owns all sessions plus the state shared between their threads.
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>,
/// Sending on this channel asks `maintain_project_context` to rebuild `project_context`.
project_context_needs_refresh: watch::Sender<()>,
/// The task running `maintain_project_context`.
_maintain_project_context: Task<Result<()>>,
/// Registry built from the project's context server store; handed to every thread.
context_server_registry: Entity<ContextServerRegistry>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
/// The project this agent was created for.
project: Entity<Project>,
/// Optional prompt store whose default prompts feed the project context.
prompt_store: Option<Entity<PromptStore>>,
/// Database that threads are saved to and loaded from.
thread_database: Arc<ThreadsDatabase>,
/// Latest list of saved-thread metadata; `None` until the first load has completed.
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
/// The task performing the most recent history reload.
load_history: Task<()>,
/// File system handle used when writing the selected model to the settings file.
fs: Arc<dyn Fs>,
/// Event subscriptions set up in `new`.
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
/// Creates the agent entity for `project`.
///
/// Before the entity is constructed this awaits the initial project context and a
/// connection to the threads database; a failed connection is returned as an error.
/// The new agent subscribes to project events, language model registry events and, when
/// a prompt store is given, prompt store events. It also spawns the task that keeps the
/// project context up to date and starts the first history load.
pub async fn new(
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
let project_context = cx
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
.await;
let thread_database = cx
.update(|cx| ThreadsDatabase::connect(cx))?
.await
.map_err(|e| anyhow!(e))?;
cx.new(|cx| {
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
),
];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
// The receiver is moved into the maintenance task; the sender stays on the agent so
// that event handlers can request a refresh.
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
let mut this = Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
context_server_registry: cx.new(|cx| {
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
}),
thread_database,
templates,
models: LanguageModels::new(cx),
project,
prompt_store,
fs,
// History starts out as `None`; `reload_history` below publishes the first value.
history: watch::channel(None).0,
load_history: Task::ready(()),
_subscriptions: subscriptions,
};
this.reload_history(cx);
this
})
}
/// Registers a session made of `thread` and its `acp_thread`, keyed by the thread's ID.
///
/// Two observers are installed alongside it: one drops the session entry when the ACP
/// thread is released, the other reacts to every thread notification by asking the
/// thread to generate a title if needed and scheduling a save.
pub fn insert_session(
    &mut self,
    thread: Entity<Thread>,
    acp_thread: Entity<AcpThread>,
    cx: &mut Context<Self>,
) {
    let session_id = thread.read(cx).id().clone();

    let on_release = cx.observe_release(&acp_thread, |agent, released, _cx| {
        agent.sessions.remove(released.session_id());
    });
    let on_change = cx.observe(&thread, |agent, changed, cx| {
        changed.update(cx, |thread, cx| {
            thread.generate_title_if_needed(cx);
        });
        agent.save_thread(changed.clone(), cx)
    });

    let session = Session {
        thread,
        acp_thread: acp_thread.downgrade(),
        save_task: Task::ready(Ok(())),
        _subscriptions: vec![on_release, on_change],
    };
    self.sessions.insert(session_id, session);
}
/// Schedules a debounced save of `thread_handle` to the threads database.
///
/// Does nothing when the thread has no session. Otherwise the session's `save_task` is
/// replaced by a task that waits `SAVE_THREAD_DEBOUNCE`, serializes the thread, writes
/// it to the database and finally reloads the history.
fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
    let session_id = thread_handle.read(cx).id().clone();
    let Some(session) = self.sessions.get_mut(&session_id) else {
        return;
    };

    let weak_thread = thread_handle.downgrade();
    let database = self.thread_database.clone();
    session.save_task = cx.spawn(async move |agent, cx| {
        cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
        let db_thread = weak_thread
            .update(cx, |thread, cx| thread.to_db(cx))?
            .await;
        database.save_thread(session_id, db_thread).await?;
        agent.update(cx, |agent, cx| agent.reload_history(cx))?;
        Ok(())
    });
}
/// Reloads the list of saved threads from the database and publishes it on `history`.
///
/// The database query and the conversion to `AcpThreadMetadata` are run through
/// `background_spawn`. A failed query is logged and nothing is published. The spawned
/// task is stored in `load_history`, replacing whichever reload was stored there before.
fn reload_history(&mut self, cx: &mut Context<Self>) {
let thread_database = self.thread_database.clone();
self.load_history = cx.spawn(async move |this, cx| {
let results = cx
.background_spawn(async move {
let results = thread_database.list_threads().await?;
anyhow::Ok(
results
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
.collect(),
)
})
.await;
if let Some(results) = results.log_err() {
// Failure to reach the agent through the weak handle is ignored.
this.update(cx, |this, _| this.history.send(Some(results)))
.ok();
}
});
}
/// Returns the agent's cached language model information.
pub fn models(&self) -> &LanguageModels {
&self.models
}
/// Rebuilds the shared project context every time `needs_refresh` reports a change.
///
/// Returns `Ok(())` once `changed()` reports an error (presumably when the sender has
/// been dropped; confirm against the `watch` crate), and returns an error as soon as
/// the agent can no longer be updated through the weak handle.
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
// Swap the new context into the shared `RefCell` in place.
this.update(cx, |this, _| this.project_context.replace(project_context))?;
}
Ok(())
}
/// Builds a `ProjectContext` from the project's visible worktrees and the prompt store's
/// default prompts.
///
/// One loading task is started per visible worktree, and all default prompts are loaded
/// through the prompt store when one is given. Loading errors are currently dropped:
/// a worktree whose rules file failed to load is still included (without rules), and a
/// default prompt that failed to load is skipped.
fn build_project_context(
project: &Entity<Project>,
prompt_store: Option<&Entity<PromptStore>>,
cx: &mut App,
) -> Task<ProjectContext> {
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let worktree_tasks = worktrees
.into_iter()
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
})
.collect::<Vec<_>>();
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
// Pair each load result with its metadata so the title and ID stay available below.
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
})
} else {
Task::ready(vec![])
};
cx.spawn(async move |_cx| {
// Wait for the worktree tasks and the prompt loading task together.
let (worktrees, default_user_rules) =
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = worktrees
.into_iter()
.map(|(worktree, _rules_error)| {
// TODO: show error message
// if let Some(rules_error) = rules_error {
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
// }
worktree
})
.collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
// Only user prompts carry a UUID; the edit-workflow prompt is skipped.
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
},
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(_err) => {
// TODO: show error message
// this.update(cx, |_, cx| {
// cx.emit(RulesLoadingError {
// message: format!("{err:?}").into(),
// });
// })
// .ok();
None
}
})
.collect::<Vec<_>>();
ProjectContext::new(worktrees, default_user_rules)
})
}
fn load_worktree_info_for_system_prompt(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let tree = worktree.read(cx);
let root_name = tree.root_name().into();
let abs_path = tree.abs_path();
let mut context = WorktreeContext {
root_name,
abs_path,
rules_file: None,
};
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
let Some(rules_task) = rules_task else {
return Task::ready((context, None));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
context.rules_file = rules_file;
(context, rules_file_error)
})
}
/// Starts loading the worktree's rules file, if it has one.
///
/// The first entry of `RULES_FILE_NAMES` that exists in the worktree as a file is
/// selected. Returns `None` when no such file exists. Otherwise returns a task that
/// opens the file as a project buffer and resolves to its trimmed text together with
/// its path and project entry ID; the task fails when the buffer cannot be opened or
/// has no associated file.
fn load_worktree_rules_file(
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    cx: &mut App,
) -> Option<Task<Result<RulesFileContext>>> {
    let worktree = worktree.read(cx);
    let worktree_id = worktree.id();
    let selected_rules_file = RULES_FILE_NAMES.into_iter().find_map(|name| {
        worktree
            .entry_for_path(name)
            .filter(|entry| entry.is_file())
            .map(|entry| entry.path.clone())
    });

    // Note that Cline supports `.clinerules` being a directory, but that is not currently
    // supported. This doesn't seem to occur often in GitHub repositories.
    selected_rules_file.map(|path_in_worktree| {
        let project_path = ProjectPath {
            worktree_id,
            path: path_in_worktree.clone(),
        };
        let buffer_task =
            project.update(cx, |project, cx| project.open_buffer(project_path, cx));

        let rope_task = cx.spawn(async move |cx| {
            buffer_task.await?.read_with(cx, |buffer, cx| {
                let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
            })?
        });

        // Build a string from the rope on a background thread.
        cx.background_spawn(async move {
            let (project_entry_id, rope) = rope_task.await?;
            anyhow::Ok(RulesFileContext {
                path_in_worktree,
                text: rope.to_string().trim().to_string(),
                project_entry_id: project_entry_id.to_usize(),
            })
        })
    })
}
/// Requests a project context refresh for project events that can change it.
///
/// That is the case when a worktree is added or removed, or when the updated worktree
/// entries include a path equal to one of `RULES_FILE_NAMES`. All other events are
/// ignored.
fn handle_project_event(
    &mut self,
    _project: Entity<Project>,
    event: &project::Event,
    _cx: &mut Context<Self>,
) {
    let affects_project_context = match event {
        project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => true,
        project::Event::WorktreeUpdatedEntries(_, items) => items.iter().any(|(path, _, _)| {
            RULES_FILE_NAMES
                .iter()
                .any(|name| path.as_ref() == Path::new(name))
        }),
        _ => false,
    };

    if affects_project_context {
        self.project_context_needs_refresh.send(()).ok();
    }
}
/// Requests a project context refresh whenever the prompt store reports updated prompts.
fn handle_prompts_updated_event(
&mut self,
_prompt_store: Entity<PromptStore>,
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
self.project_context_needs_refresh.send(()).ok();
}
/// Reacts to a language model registry event.
///
/// The model cache is refreshed first. Then, for every session, a thread that has no
/// model yet is given the registry's default model (if one is configured), and the
/// thread's summarization model is set to the registry's current thread summary model.
fn handle_models_updated_event(
&mut self,
registry: Entity<LanguageModelRegistry>,
_event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
let default_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|m| m.model.clone());
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
if thread.model().is_none()
&& let Some(model) = default_model.clone()
{
// NOTE(review): `set_model` is called with one argument here but with
// `(model, cx)` in `select_model`; confirm which matches `Thread::set_model`.
thread.set_model(model);
cx.notify();
}
let summarization_model = registry
.read(cx)
.thread_summary_model()
.map(|model| model.model.clone());
thread.set_summarization_model(summarization_model, cx);
});
}
}
}
/// Wrapper struct that implements the AgentConnection trait
/// (and `AgentModelSelector`) on top of a shared `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl NativeAgentConnection {
/// Returns the internal thread of the session with `session_id`, or `None` when the
/// agent has no such session.
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
self.0
.read(cx)
.sessions
.get(session_id)
.map(|session| session.thread.clone())
}
/// Runs one turn on the session's thread.
///
/// `f` receives the session's thread and must return the event stream of the turn it
/// started; that stream is then forwarded to the session's ACP thread by
/// `handle_thread_events`. The returned task resolves to an error when no session with
/// `session_id` exists or when `f` itself fails.
fn run_turn(
&self,
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
}) else {
return Task::ready(Err(anyhow!("Session not found")));
};
log::debug!("Found session for: {}", session_id);
let response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
Self::handle_thread_events(response_stream, acp_thread, cx)
}
/// Forwards every event from `response_stream` to `acp_thread`.
///
/// The returned task resolves to
/// - the stop reason carried by the first `ThreadEvent::Stop`,
/// - `acp::StopReason::EndTurn` when the stream ends without a `Stop` event, or
/// - an error as soon as the stream yields one, the ACP thread can no longer be
///   updated through the weak handle, or a tool call / title update is rejected.
fn handle_thread_events(
    mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
    acp_thread: WeakEntity<AcpThread>,
    cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
    cx.spawn(async move |cx| {
        // Handle response stream and forward to session.acp_thread
        while let Some(result) = response_stream.next().await {
            match result {
                Ok(event) => {
                    log::trace!("Received completion event: {:?}", event);
                    match event {
                        ThreadEvent::UserMessage(message) => {
                            acp_thread.update(cx, |thread, cx| {
                                for content in message.content {
                                    thread.push_user_content_block(
                                        Some(message.id.clone()),
                                        content.into(),
                                        cx,
                                    );
                                }
                            })?;
                        }
                        ThreadEvent::AgentText(text) => {
                            acp_thread.update(cx, |thread, cx| {
                                thread.push_assistant_content_block(
                                    acp::ContentBlock::Text(acp::TextContent {
                                        text,
                                        annotations: None,
                                    }),
                                    false,
                                    cx,
                                )
                            })?;
                        }
                        ThreadEvent::AgentThinking(text) => {
                            // Same as `AgentText`, but with the boolean flag set to `true`.
                            acp_thread.update(cx, |thread, cx| {
                                thread.push_assistant_content_block(
                                    acp::ContentBlock::Text(acp::TextContent {
                                        text,
                                        annotations: None,
                                    }),
                                    true,
                                    cx,
                                )
                            })?;
                        }
                        ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
                            tool_call,
                            options,
                            response,
                        }) => {
                            let recv = acp_thread.update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(tool_call, options, cx)
                            })?;
                            // Await the decision in a detached task so that this loop keeps
                            // forwarding events in the meantime.
                            cx.background_spawn(async move {
                                if let Some(recv) = recv.log_err()
                                    && let Some(option) = recv
                                        .await
                                        .context("authorization sender was dropped")
                                        .log_err()
                                {
                                    // `send` hands the value back on failure; replace it
                                    // with a descriptive error so the failure is what
                                    // gets logged.
                                    response
                                        .send(option)
                                        .map_err(|_| {
                                            anyhow!("authorization receiver was dropped")
                                        })
                                        .log_err();
                                }
                            })
                            .detach();
                        }
                        ThreadEvent::ToolCall(tool_call) => {
                            acp_thread.update(cx, |thread, cx| {
                                thread.upsert_tool_call(tool_call, cx)
                            })??;
                        }
                        ThreadEvent::ToolCallUpdate(update) => {
                            acp_thread.update(cx, |thread, cx| {
                                thread.update_tool_call(update, cx)
                            })??;
                        }
                        ThreadEvent::TitleUpdate(title) => {
                            acp_thread
                                .update(cx, |thread, cx| thread.update_title(title, cx))??;
                        }
                        ThreadEvent::Stop(stop_reason) => {
                            log::debug!("Assistant message complete: {:?}", stop_reason);
                            return Ok(acp::PromptResponse { stop_reason });
                        }
                    }
                }
                Err(e) => {
                    log::error!("Error in model response stream: {:?}", e);
                    return Err(e);
                }
            }
        }

        log::info!("Response stream completed");
        anyhow::Ok(acp::PromptResponse {
            stop_reason: acp::StopReason::EndTurn,
        })
    })
}
/// Adds the agent's built-in tools to `thread`.
///
/// `action_log` is shared with the tools that take it (`DeletePathTool`, `ReadFileTool`).
fn register_tools(
thread: &mut Thread,
project: Entity<Project>,
action_log: Entity<action_log::ActionLog>,
cx: &mut Context<Thread>,
) {
let language_registry = project.read(cx).languages().clone();
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
}
}
impl AgentModelSelector for NativeAgentConnection {
/// Returns a copy of the cached model list, or an error when that list is empty.
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
    log::debug!("NativeAgentConnection::list_models called");
    let model_list = self.0.read(cx).models.model_list.clone();
    if model_list.is_empty() {
        return Task::ready(Err(anyhow::anyhow!("No models available")));
    }
    Task::ready(Ok(model_list))
}
/// Selects `model_id` for the session's thread and writes the choice to the settings file.
///
/// Resolves to an error when the session does not exist or when `model_id` is not in
/// the agent's model cache.
fn select_model(
&self,
session_id: acp::SessionId,
model_id: acp_thread::AgentModelId,
cx: &mut App,
) -> Task<Result<()>> {
log::info!("Setting model for session {}: {}", session_id, model_id);
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, cx| {
// NOTE(review): `set_model` is called with `(model, cx)` here but with a single
// argument in `NativeAgent::handle_models_updated_event`; confirm which matches
// `Thread::set_model`.
thread.set_model(model.clone(), cx);
});
// Write the selection into the agent settings so it is stored in the settings file.
update_settings_file::<AgentSettings>(
self.0.read(cx).fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model);
},
);
Task::ready(Ok(()))
}
/// Returns the info record of the model currently set on the session's thread.
///
/// Resolves to an error when the session does not exist, when the thread has no model,
/// or when the model's provider is not known to the global registry.
fn selected_model(
    &self,
    session_id: &acp::SessionId,
    cx: &mut App,
) -> Task<Result<acp_thread::AgentModelInfo>> {
    let agent = self.0.read(cx);
    let Some(session) = agent.sessions.get(session_id) else {
        return Task::ready(Err(anyhow!("Session not found")));
    };
    let thread = session.thread.clone();

    let Some(model) = thread.read(cx).model() else {
        return Task::ready(Err(anyhow!("Model not found")));
    };

    let registry = LanguageModelRegistry::read_global(cx);
    let Some(provider) = registry.provider(&model.provider_id()) else {
        return Task::ready(Err(anyhow!("Provider not found")));
    };

    let info = LanguageModels::map_language_model_to_info(model, &provider);
    Task::ready(Ok(info))
}
/// Returns a receiver that is signalled whenever the agent's model cache is refreshed.
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
self.0.read(cx).models.watch()
}
}
impl acp_thread::AgentConnection for NativeAgentConnection {
/// Creates a new session: a fresh `AcpThread` plus the internal `Thread` backing it.
///
/// The thread is created with the registry's default model when that model is present
/// in the agent's model cache (otherwise `None` is passed), gets the built-in tools
/// registered, and is stored in the agent's session map before the ACP thread is returned.
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let agent = self.0.clone();
log::info!("Creating new thread for project at: {:?}", cwd);
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = generate_session_id();
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
// The internal thread shares the ACP thread's action log.
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
// Resolve the default model through the agent's own cache; this yields `None`
// when the cache does not contain it.
let default_model = registry.default_model().and_then(|default_model| {
agent
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
});
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| {
let mut thread = Thread::new(
session_id.clone(),
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
default_model,
summarization_model,
cx,
);
Self::register_tools(&mut thread, project, action_log, cx);
thread
});
Ok(thread)
},
)??;
// Store the session
agent.update(cx, |agent, cx| {
agent.insert_session(thread, acp_thread.clone(), cx)
})?;
Ok(acp_thread)
})
}
/// Returns an empty list: this connection offers no authentication methods.
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[] // No auth for in-process
}
/// Always succeeds immediately; the requested method is ignored.
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
/// Returns a receiver for the agent's history channel, which carries the list of
/// saved-thread metadata published by `NativeAgent::reload_history`.
fn list_threads(
&self,
cx: &mut App,
) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
Some(self.0.read(cx).history.receiver())
}
/// Restores a previously saved session from the threads database.
///
/// Fails when the database has no thread for `session_id`, when neither the saved model
/// nor a default model can be resolved, or when the resolved model is missing from the
/// agent's model cache. After the session has been stored, the thread's events are
/// replayed into the new ACP thread, and the replay is awaited before returning.
fn load_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| {
let db_thread = database
.load_thread(session_id.clone())
.await?
.context("no such thread found")?;
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
db_thread.title.clone(),
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
let agent = self.0.clone();
// Create Thread
let thread = agent.update(cx, |agent, cx| {
let language_model_registry = LanguageModelRegistry::global(cx);
// Prefer the model saved with the thread; fall back to the registry's default model.
let configured_model = language_model_registry
.update(cx, |registry, cx| {
db_thread
.model
.as_ref()
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
})
.context("no default model configured")?;
let model = agent
.models
.model_from_id(&LanguageModels::model_id(&configured_model.model))
.context("no model by id")?;
let summarization_model = language_model_registry
.read(cx)
.thread_summary_model()
.map(|c| c.model);
let thread = cx.new(|cx| {
let mut thread = Thread::from_db(
session_id,
db_thread,
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
model,
summarization_model,
cx,
);
Self::register_tools(&mut thread, project, action_log, cx);
thread
});
anyhow::Ok(thread)
})??;
// Store the session
agent.update(cx, |agent, cx| {
agent.insert_session(thread.clone(), acp_thread.clone(), cx)
})?;
// Replay the thread's events into the ACP thread and wait for the replay to finish.
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
.await?;
Ok(acp_thread)
})
}
/// Returns a clone of this connection as its own model selector.
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
}
fn prompt(
&self,
id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
self.run_turn(session_id, cx, |thread, cx| {
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
.map(Into::into)
.collect::<Vec<_>>();
log::info!("Converted prompt to message: {} chars", content.len());
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
thread.update(cx, |thread, cx| thread.send(id, content, cx))
})
}
/// Returns a handle that can resume the given session.
///
/// The handle is created unconditionally; whether the session exists is only checked
/// when the handle is run (through `run_turn`).
fn resume(
&self,
session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
Some(Rc::new(NativeAgentSessionResume {
connection: self.clone(),
session_id: session_id.clone(),
}) as _)
}
/// Cancels whatever the session's thread is doing; unknown sessions are ignored.
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
    log::info!("Cancelling on session: {}", session_id);
    self.0.update(cx, |agent, cx| {
        let Some(session) = agent.sessions.get(session_id) else {
            return;
        };
        session.thread.update(cx, |thread, cx| thread.cancel(cx));
    });
}
/// Returns an editor wrapping the session's thread, or `None` when the agent has no
/// session with `session_id`.
fn session_editor(
&self,
session_id: &agent_client_protocol::SessionId,
cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| {
agent
.sessions
.get(session_id)
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
/// Converts the connection into an `Rc<dyn Any>`.
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
self
}
}
/// `AgentSessionEditor` implementation backed by a session's internal thread.
struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
/// Forwards to `Thread::truncate` with `message_id` and returns its result as an
/// already-completed task.
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(
self.0
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
)
}
}
/// `AgentSessionResume` implementation that resumes one session of a native connection.
struct NativeAgentSessionResume {
/// The connection whose agent owns the session.
connection: NativeAgentConnection,
/// The session to resume.
session_id: acp::SessionId,
}
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
/// Runs a turn that calls `Thread::resume` on the session's thread; resolves to an
/// error when the session no longer exists.
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
self.connection
.run_turn(self.session_id.clone(), cx, |thread, cx| {
thread.update(cx, |thread, cx| thread.resume(cx))
})
}
}
#[cfg(test)]
mod tests {
use crate::HistoryStore;
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
use fs::FakeFs;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
use util::path;
/// Checks that the agent's project context follows worktree and rules-file changes.
#[gpui::test]
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"a": {}
}),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
// The project was created without worktrees, so the context starts out empty.
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
// Adding a worktree updates the project context.
let worktree = project
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
.await
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, _| {
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: None
}]
)
});
// Creating `/a/.rules` updates the project context.
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: Some(RulesFileContext {
path_in_worktree: Path::new(".rules").into(),
text: "".into(),
project_entry_id: rules_entry.id.to_usize()
})
}]
)
});
}
/// Checks that `list_models` returns exactly one group, "Fake", holding the single
/// `fake/fake` model.
#[gpui::test]
async fn test_listing_models(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap(),
);
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
let acp_thread::AgentModelList::Grouped(models) = models else {
panic!("Unexpected model group");
};
assert_eq!(
models,
IndexMap::from_iter([(
AgentModelGroupName("Fake".into()),
vec![AgentModelInfo {
id: AgentModelId("fake/fake".into()),
name: "Fake".into(),
icon: Some(ui::IconName::ZedAssistant),
}]
)])
);
}
/// Checks that `select_model` both updates the session's thread and rewrites the
/// `agent.default_model` entry of the settings file.
#[gpui::test]
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
// Start from a settings file that names a different default model ("foo"/"bar").
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_model": {
"provider": "foo",
"model": "bar"
}
}
})
.to_string()
.into_bytes(),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
})
.await
.unwrap();
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
// Select a model
let model_id = AgentModelId("fake/fake".into());
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
.await
.unwrap();
// Verify the thread has the selected model
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.model().unwrap().id().0, "fake");
});
});
// Let the pending settings update run before reading the file back.
cx.run_until_parked();
// Verify settings file was updated
let settings_content = fs.load(paths::settings_file()).await.unwrap();
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
// Check that the agent settings contain the selected model
assert_eq!(
settings_json["agent"]["default_model"]["model"],
json!("fake")
);
assert_eq!(
settings_json["agent"]["default_model"]["provider"],
json!("fake")
);
}
/// Checks that a thread which received a reply and a generated title shows up in the
/// history store as a single entry carrying that title.
#[gpui::test]
async fn test_history(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
let history_store = cx.new(|cx| {
let mut store = HistoryStore::new(cx);
store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx);
store
});
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx)
})
.await
.unwrap();
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
let selector = connection.model_selector().unwrap();
// Give the thread a dedicated fake summarization model so that its completion
// stream can be driven separately from the main model's.
let summarization_model: Arc<dyn LanguageModel> =
Arc::new(FakeLanguageModel::default()) as _;
agent.update(cx, |agent, cx| {
let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summarization_model.clone()), cx);
})
});
// Resolve the thread's currently selected model to its fake implementation.
let model = cx
.update(|cx| selector.selected_model(&session_id, cx))
.await
.expect("selected_model should succeed");
let model = cx
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
.unwrap();
let model = model.as_fake();
// Send a message and answer it through the fake model.
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
let send = cx.foreground_executor().spawn(send);
cx.run_until_parked();
model.send_last_completion_stream_text_chunk("Hey");
model.end_last_completion_stream();
send.await.unwrap();
// Answer through the summarization model; the assertion below expects this text
// as the history entry's title.
summarization_model
.as_fake()
.send_last_completion_stream_text_chunk("Saying Hello");
summarization_model.as_fake().end_last_completion_stream();
// Advance the clock by the save debounce delay.
cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
let history = history_store.update(cx, |store, cx| store.entries(cx));
assert_eq!(history.len(), 1);
assert_eq!(history[0].title(), "Saying Hello");
}
/// Shared test setup: initializes logging (ignoring repeated initialization), installs
/// a test settings store, and initializes project settings, agent settings, languages
/// and the test language model registry.
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
agent_settings::init(cx);
language::init(cx);
LanguageModelRegistry::test(cx);
});
}
}