Support profiles in agent2 (#36034)
We still need a profile selector. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
13bf45dd4a
commit
2444321756
15 changed files with 587 additions and 108 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -196,6 +196,7 @@ dependencies = [
|
||||||
"clock",
|
"clock",
|
||||||
"cloud_llm_client",
|
"cloud_llm_client",
|
||||||
"collections",
|
"collections",
|
||||||
|
"context_server",
|
||||||
"ctor",
|
"ctor",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
|
|
|
@ -254,6 +254,15 @@ impl ToolCall {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(raw_output) = raw_output {
|
if let Some(raw_output) = raw_output {
|
||||||
|
if self.content.is_empty() {
|
||||||
|
if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
|
||||||
|
{
|
||||||
|
self.content
|
||||||
|
.push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
|
||||||
|
markdown,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
self.raw_output = Some(raw_output);
|
self.raw_output = Some(raw_output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1266,6 +1275,48 @@ impl AcpThread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn markdown_for_raw_output(
|
||||||
|
raw_output: &serde_json::Value,
|
||||||
|
language_registry: &Arc<LanguageRegistry>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Option<Entity<Markdown>> {
|
||||||
|
match raw_output {
|
||||||
|
serde_json::Value::Null => None,
|
||||||
|
serde_json::Value::Bool(value) => Some(cx.new(|cx| {
|
||||||
|
Markdown::new(
|
||||||
|
value.to_string().into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
|
serde_json::Value::Number(value) => Some(cx.new(|cx| {
|
||||||
|
Markdown::new(
|
||||||
|
value.to_string().into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
|
serde_json::Value::String(value) => Some(cx.new(|cx| {
|
||||||
|
Markdown::new(
|
||||||
|
value.clone().into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
|
value => Some(cx.new(|cx| {
|
||||||
|
Markdown::new(
|
||||||
|
format!("```json\n{}\n```", value).into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -23,6 +23,7 @@ assistant_tools.workspace = true
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
cloud_llm_client.workspace = true
|
cloud_llm_client.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
|
context_server.workspace = true
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
@ -60,6 +61,7 @@ workspace-hack.workspace = true
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
client = { workspace = true, "features" = ["test-support"] }
|
client = { workspace = true, "features" = ["test-support"] }
|
||||||
clock = { workspace = true, "features" = ["test-support"] }
|
clock = { workspace = true, "features" = ["test-support"] }
|
||||||
|
context_server = { workspace = true, "features" = ["test-support"] }
|
||||||
editor = { workspace = true, "features" = ["test-support"] }
|
editor = { workspace = true, "features" = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
fs = { workspace = true, "features" = ["test-support"] }
|
fs = { workspace = true, "features" = ["test-support"] }
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||||
use crate::{
|
use crate::{
|
||||||
CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
|
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||||
GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool,
|
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||||
ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
||||||
};
|
};
|
||||||
use acp_thread::ModelSelector;
|
use acp_thread::ModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -55,6 +55,7 @@ pub struct NativeAgent {
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
project_context_needs_refresh: watch::Sender<()>,
|
project_context_needs_refresh: watch::Sender<()>,
|
||||||
_maintain_project_context: Task<Result<()>>,
|
_maintain_project_context: Task<Result<()>>,
|
||||||
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
/// Shared templates for all threads
|
/// Shared templates for all threads
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -90,6 +91,9 @@ impl NativeAgent {
|
||||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||||
}),
|
}),
|
||||||
|
context_server_registry: cx.new(|cx| {
|
||||||
|
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||||
|
}),
|
||||||
templates,
|
templates,
|
||||||
project,
|
project,
|
||||||
prompt_store,
|
prompt_store,
|
||||||
|
@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
// Create AcpThread
|
// Create AcpThread
|
||||||
let acp_thread = cx.update(|cx| {
|
let acp_thread = cx.update(|cx| {
|
||||||
cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
|
acp_thread::AcpThread::new(
|
||||||
|
"agent2",
|
||||||
|
self.clone(),
|
||||||
|
project.clone(),
|
||||||
|
session_id.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||||
|
@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
})
|
})
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
log::warn!("No default model configured in settings");
|
log::warn!("No default model configured in settings");
|
||||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
anyhow!(
|
||||||
|
"No default model. Please configure a default model in settings."
|
||||||
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
let mut thread = Thread::new(
|
||||||
|
project.clone(),
|
||||||
|
agent.project_context.clone(),
|
||||||
|
agent.context_server_registry.clone(),
|
||||||
|
action_log.clone(),
|
||||||
|
agent.templates.clone(),
|
||||||
|
default_model,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||||
|
@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
acp_thread: acp_thread.downgrade(),
|
acp_thread: acp_thread.downgrade(),
|
||||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||||
this.sessions.remove(acp_thread.session_id());
|
this.sessions.remove(acp_thread.session_id());
|
||||||
})
|
}),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
})?;
|
})?;
|
||||||
|
|
|
@ -2,6 +2,7 @@ use super::*;
|
||||||
use acp_thread::AgentConnection;
|
use acp_thread::AgentConnection;
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
|
use agent_settings::AgentProfileId;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
use fs::{FakeFs, Fs};
|
use fs::{FakeFs, Fs};
|
||||||
|
@ -165,7 +166,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
})
|
}),
|
||||||
|
"{}",
|
||||||
|
thread.to_markdown()
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -469,6 +472,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest {
|
||||||
|
model, thread, fs, ..
|
||||||
|
} = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
thread.update(cx, |thread, _cx| {
|
||||||
|
thread.add_tool(DelayTool);
|
||||||
|
thread.add_tool(EchoTool);
|
||||||
|
thread.add_tool(InfiniteTool);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Override profiles and wait for settings to be loaded.
|
||||||
|
fs.insert_file(
|
||||||
|
paths::settings_file(),
|
||||||
|
json!({
|
||||||
|
"agent": {
|
||||||
|
"profiles": {
|
||||||
|
"test-1": {
|
||||||
|
"name": "Test Profile 1",
|
||||||
|
"tools": {
|
||||||
|
EchoTool.name(): true,
|
||||||
|
DelayTool.name(): true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"test-2": {
|
||||||
|
"name": "Test Profile 2",
|
||||||
|
"tools": {
|
||||||
|
InfiniteTool.name(): true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
.into_bytes(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
// Test that test-1 profile (default) has echo and delay tools
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_profile(AgentProfileId("test-1".into()));
|
||||||
|
thread.send("test", cx);
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
|
assert_eq!(pending_completions.len(), 1);
|
||||||
|
let completion = pending_completions.pop().unwrap();
|
||||||
|
let tool_names: Vec<String> = completion
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| tool.name.clone())
|
||||||
|
.collect();
|
||||||
|
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_profile(AgentProfileId("test-2".into()));
|
||||||
|
thread.send("test2", cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
|
assert_eq!(pending_completions.len(), 1);
|
||||||
|
let completion = pending_completions.pop().unwrap();
|
||||||
|
let tool_names: Vec<String> = completion
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| tool.name.clone())
|
||||||
|
.collect();
|
||||||
|
assert_eq!(tool_names, vec![InfiniteTool.name()]);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
#[ignore = "can't run on CI yet"]
|
#[ignore = "can't run on CI yet"]
|
||||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
|
@ -595,6 +674,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
language_models::init(user_store.clone(), client.clone(), cx);
|
language_models::init(user_store.clone(), client.clone(), cx);
|
||||||
Project::init_settings(cx);
|
Project::init_settings(cx);
|
||||||
LanguageModelRegistry::test(cx);
|
LanguageModelRegistry::test(cx);
|
||||||
|
agent_settings::init(cx);
|
||||||
});
|
});
|
||||||
cx.executor().forbid_parking();
|
cx.executor().forbid_parking();
|
||||||
|
|
||||||
|
@ -790,6 +870,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
id: acp::ToolCallId("1".into()),
|
id: acp::ToolCallId("1".into()),
|
||||||
fields: acp::ToolCallUpdateFields {
|
fields: acp::ToolCallUpdateFields {
|
||||||
status: Some(acp::ToolCallStatus::Completed),
|
status: Some(acp::ToolCallStatus::Completed),
|
||||||
|
raw_output: Some("Finished thinking.".into()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -813,6 +894,7 @@ struct ThreadTest {
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
|
fs: Arc<FakeFs>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum TestModel {
|
enum TestModel {
|
||||||
|
@ -835,30 +917,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
cx.executor().allow_parking();
|
cx.executor().allow_parking();
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.background_executor.clone());
|
let fs = FakeFs::new(cx.background_executor.clone());
|
||||||
|
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
fs.insert_file(
|
||||||
|
paths::settings_file(),
|
||||||
|
json!({
|
||||||
|
"agent": {
|
||||||
|
"default_profile": "test-profile",
|
||||||
|
"profiles": {
|
||||||
|
"test-profile": {
|
||||||
|
"name": "Test Profile",
|
||||||
|
"tools": {
|
||||||
|
EchoTool.name(): true,
|
||||||
|
DelayTool.name(): true,
|
||||||
|
WordListTool.name(): true,
|
||||||
|
ToolRequiringPermission.name(): true,
|
||||||
|
InfiniteTool.name(): true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
.into_bytes(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
settings::init(cx);
|
settings::init(cx);
|
||||||
watch_settings(fs.clone(), cx);
|
|
||||||
Project::init_settings(cx);
|
Project::init_settings(cx);
|
||||||
agent_settings::init(cx);
|
agent_settings::init(cx);
|
||||||
|
gpui_tokio::init(cx);
|
||||||
|
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||||
|
cx.set_http_client(Arc::new(http_client));
|
||||||
|
|
||||||
|
client::init_settings(cx);
|
||||||
|
let client = Client::production(cx);
|
||||||
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
|
language_model::init(client.clone(), cx);
|
||||||
|
language_models::init(user_store.clone(), client.clone(), cx);
|
||||||
|
|
||||||
|
watch_settings(fs.clone(), cx);
|
||||||
});
|
});
|
||||||
|
|
||||||
let templates = Templates::new();
|
let templates = Templates::new();
|
||||||
|
|
||||||
fs.insert_tree(path!("/test"), json!({})).await;
|
fs.insert_tree(path!("/test"), json!({})).await;
|
||||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||||
|
|
||||||
let model = cx
|
let model = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
gpui_tokio::init(cx);
|
|
||||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
|
||||||
cx.set_http_client(Arc::new(http_client));
|
|
||||||
|
|
||||||
client::init_settings(cx);
|
|
||||||
let client = Client::production(cx);
|
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
||||||
language_model::init(client.clone(), cx);
|
|
||||||
language_models::init(user_store.clone(), client.clone(), cx);
|
|
||||||
|
|
||||||
if let TestModel::Fake = model {
|
if let TestModel::Fake = model {
|
||||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||||
} else {
|
} else {
|
||||||
|
@ -881,20 +990,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
project_context.clone(),
|
project_context.clone(),
|
||||||
|
context_server_registry,
|
||||||
action_log,
|
action_log,
|
||||||
templates,
|
templates,
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
ThreadTest {
|
ThreadTest {
|
||||||
model,
|
model,
|
||||||
thread,
|
thread,
|
||||||
project_context,
|
project_context,
|
||||||
|
fs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::{SystemPromptTemplate, Template, Templates};
|
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::AgentSettings;
|
use agent_settings::{AgentProfileId, AgentSettings};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||||
|
@ -126,6 +126,8 @@ pub struct Thread {
|
||||||
running_turn: Option<Task<()>>,
|
running_turn: Option<Task<()>>,
|
||||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
|
profile_id: AgentProfileId,
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
pub selected_model: Arc<dyn LanguageModel>,
|
pub selected_model: Arc<dyn LanguageModel>,
|
||||||
|
@ -137,16 +139,21 @@ impl Thread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
default_model: Arc<dyn LanguageModel>,
|
default_model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||||
Self {
|
Self {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
completion_mode: CompletionMode::Normal,
|
completion_mode: CompletionMode::Normal,
|
||||||
running_turn: None,
|
running_turn: None,
|
||||||
pending_tool_uses: HashMap::default(),
|
pending_tool_uses: HashMap::default(),
|
||||||
tools: BTreeMap::default(),
|
tools: BTreeMap::default(),
|
||||||
|
context_server_registry,
|
||||||
|
profile_id,
|
||||||
project_context,
|
project_context,
|
||||||
templates,
|
templates,
|
||||||
selected_model: default_model,
|
selected_model: default_model,
|
||||||
|
@ -179,6 +186,10 @@ impl Thread {
|
||||||
self.tools.remove(name).is_some()
|
self.tools.remove(name).is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||||
|
self.profile_id = profile_id;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn cancel(&mut self) {
|
pub fn cancel(&mut self) {
|
||||||
self.running_turn.take();
|
self.running_turn.take();
|
||||||
|
|
||||||
|
@ -298,6 +309,7 @@ impl Thread {
|
||||||
} else {
|
} else {
|
||||||
acp::ToolCallStatus::Completed
|
acp::ToolCallStatus::Completed
|
||||||
}),
|
}),
|
||||||
|
raw_output: tool_result.output.clone(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
@ -604,21 +616,23 @@ impl Thread {
|
||||||
let messages = self.build_request_messages();
|
let messages = self.build_request_messages();
|
||||||
log::info!("Request will include {} messages", messages.len());
|
log::info!("Request will include {} messages", messages.len());
|
||||||
|
|
||||||
let tools: Vec<LanguageModelRequestTool> = self
|
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||||
.tools
|
tools
|
||||||
.values()
|
.filter_map(|tool| {
|
||||||
.filter_map(|tool| {
|
let tool_name = tool.name().to_string();
|
||||||
let tool_name = tool.name().to_string();
|
log::trace!("Including tool: {}", tool_name);
|
||||||
log::trace!("Including tool: {}", tool_name);
|
Some(LanguageModelRequestTool {
|
||||||
Some(LanguageModelRequestTool {
|
name: tool_name,
|
||||||
name: tool_name,
|
description: tool.description().to_string(),
|
||||||
description: tool.description(cx).to_string(),
|
input_schema: tool
|
||||||
input_schema: tool
|
.input_schema(self.selected_model.tool_input_format())
|
||||||
.input_schema(self.selected_model.tool_input_format())
|
.log_err()?,
|
||||||
.log_err()?,
|
})
|
||||||
})
|
})
|
||||||
})
|
.collect()
|
||||||
.collect();
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
log::info!("Request includes {} tools", tools.len());
|
log::info!("Request includes {} tools", tools.len());
|
||||||
|
|
||||||
|
@ -639,6 +653,35 @@ impl Thread {
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||||
|
let profile = AgentSettings::get_global(cx)
|
||||||
|
.profiles
|
||||||
|
.get(&self.profile_id)
|
||||||
|
.context("profile not found")?;
|
||||||
|
|
||||||
|
Ok(self
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(tool_name, tool)| {
|
||||||
|
if profile.is_tool_enabled(tool_name) {
|
||||||
|
Some(tool)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
||||||
|
|(server_id, tools)| {
|
||||||
|
tools.iter().filter_map(|(tool_name, tool)| {
|
||||||
|
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
||||||
|
Some(tool)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"Building request messages from {} thread messages",
|
"Building request messages from {} thread messages",
|
||||||
|
@ -686,7 +729,7 @@ where
|
||||||
|
|
||||||
fn name(&self) -> SharedString;
|
fn name(&self) -> SharedString;
|
||||||
|
|
||||||
fn description(&self, _cx: &mut App) -> SharedString {
|
fn description(&self) -> SharedString {
|
||||||
let schema = schemars::schema_for!(Self::Input);
|
let schema = schemars::schema_for!(Self::Input);
|
||||||
SharedString::new(
|
SharedString::new(
|
||||||
schema
|
schema
|
||||||
|
@ -722,13 +765,13 @@ where
|
||||||
pub struct Erased<T>(T);
|
pub struct Erased<T>(T);
|
||||||
|
|
||||||
pub struct AgentToolOutput {
|
pub struct AgentToolOutput {
|
||||||
llm_output: LanguageModelToolResultContent,
|
pub llm_output: LanguageModelToolResultContent,
|
||||||
raw_output: serde_json::Value,
|
pub raw_output: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait AnyAgentTool {
|
pub trait AnyAgentTool {
|
||||||
fn name(&self) -> SharedString;
|
fn name(&self) -> SharedString;
|
||||||
fn description(&self, cx: &mut App) -> SharedString;
|
fn description(&self) -> SharedString;
|
||||||
fn kind(&self) -> acp::ToolKind;
|
fn kind(&self) -> acp::ToolKind;
|
||||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||||
|
@ -748,8 +791,8 @@ where
|
||||||
self.0.name()
|
self.0.name()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self, cx: &mut App) -> SharedString {
|
fn description(&self) -> SharedString {
|
||||||
self.0.description(cx)
|
self.0.description()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kind(&self) -> agent_client_protocol::ToolKind {
|
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
mod context_server_registry;
|
||||||
mod copy_path_tool;
|
mod copy_path_tool;
|
||||||
mod create_directory_tool;
|
mod create_directory_tool;
|
||||||
mod delete_path_tool;
|
mod delete_path_tool;
|
||||||
|
@ -15,6 +16,7 @@ mod terminal_tool;
|
||||||
mod thinking_tool;
|
mod thinking_tool;
|
||||||
mod web_search_tool;
|
mod web_search_tool;
|
||||||
|
|
||||||
|
pub use context_server_registry::*;
|
||||||
pub use copy_path_tool::*;
|
pub use copy_path_tool::*;
|
||||||
pub use create_directory_tool::*;
|
pub use create_directory_tool::*;
|
||||||
pub use delete_path_tool::*;
|
pub use delete_path_tool::*;
|
||||||
|
|
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
|
@ -0,0 +1,231 @@
|
||||||
|
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
|
||||||
|
use agent_client_protocol::ToolKind;
|
||||||
|
use anyhow::{Result, anyhow, bail};
|
||||||
|
use collections::{BTreeMap, HashMap};
|
||||||
|
use context_server::ContextServerId;
|
||||||
|
use gpui::{App, Context, Entity, SharedString, Task};
|
||||||
|
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
pub struct ContextServerRegistry {
|
||||||
|
server_store: Entity<ContextServerStore>,
|
||||||
|
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
|
||||||
|
_subscription: gpui::Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RegisteredContextServer {
|
||||||
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
|
load_tools: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContextServerRegistry {
|
||||||
|
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
|
||||||
|
let mut this = Self {
|
||||||
|
server_store: server_store.clone(),
|
||||||
|
registered_servers: HashMap::default(),
|
||||||
|
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
|
||||||
|
};
|
||||||
|
for server in server_store.read(cx).running_servers() {
|
||||||
|
this.reload_tools_for_server(server.id(), cx);
|
||||||
|
}
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn servers(
|
||||||
|
&self,
|
||||||
|
) -> impl Iterator<
|
||||||
|
Item = (
|
||||||
|
&ContextServerId,
|
||||||
|
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
|
),
|
||||||
|
> {
|
||||||
|
self.registered_servers
|
||||||
|
.iter()
|
||||||
|
.map(|(id, server)| (id, &server.tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
|
||||||
|
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let Some(client) = server.client() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
if !client.capable(context_server::protocol::ServerCapability::Tools) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let registered_server =
|
||||||
|
self.registered_servers
|
||||||
|
.entry(server_id.clone())
|
||||||
|
.or_insert(RegisteredContextServer {
|
||||||
|
tools: BTreeMap::default(),
|
||||||
|
load_tools: Task::ready(Ok(())),
|
||||||
|
});
|
||||||
|
registered_server.load_tools = cx.spawn(async move |this, cx| {
|
||||||
|
let response = client
|
||||||
|
.request::<context_server::types::requests::ListTools>(())
|
||||||
|
.await;
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
registered_server.tools.clear();
|
||||||
|
if let Some(response) = response.log_err() {
|
||||||
|
for tool in response.tools {
|
||||||
|
let tool = Arc::new(ContextServerTool::new(
|
||||||
|
this.server_store.clone(),
|
||||||
|
server.id(),
|
||||||
|
tool,
|
||||||
|
));
|
||||||
|
registered_server.tools.insert(tool.name(), tool);
|
||||||
|
}
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_context_server_store_event(
|
||||||
|
&mut self,
|
||||||
|
_: Entity<ContextServerStore>,
|
||||||
|
event: &project::context_server_store::Event,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
match event {
|
||||||
|
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||||
|
match status {
|
||||||
|
ContextServerStatus::Starting => {}
|
||||||
|
ContextServerStatus::Running => {
|
||||||
|
self.reload_tools_for_server(server_id.clone(), cx);
|
||||||
|
}
|
||||||
|
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||||
|
self.registered_servers.remove(&server_id);
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ContextServerTool {
|
||||||
|
store: Entity<ContextServerStore>,
|
||||||
|
server_id: ContextServerId,
|
||||||
|
tool: context_server::types::Tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContextServerTool {
|
||||||
|
fn new(
|
||||||
|
store: Entity<ContextServerStore>,
|
||||||
|
server_id: ContextServerId,
|
||||||
|
tool: context_server::types::Tool,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
server_id,
|
||||||
|
tool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnyAgentTool for ContextServerTool {
|
||||||
|
fn name(&self) -> SharedString {
|
||||||
|
self.tool.name.clone().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> SharedString {
|
||||||
|
self.tool.description.clone().unwrap_or_default().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kind(&self) -> ToolKind {
|
||||||
|
ToolKind::Other
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
|
||||||
|
format!("Run MCP tool `{}`", self.tool.name).into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(
|
||||||
|
&self,
|
||||||
|
format: language_model::LanguageModelToolSchemaFormat,
|
||||||
|
) -> Result<serde_json::Value> {
|
||||||
|
let mut schema = self.tool.input_schema.clone();
|
||||||
|
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||||
|
Ok(match schema {
|
||||||
|
serde_json::Value::Null => {
|
||||||
|
serde_json::json!({ "type": "object", "properties": [] })
|
||||||
|
}
|
||||||
|
serde_json::Value::Object(map) if map.is_empty() => {
|
||||||
|
serde_json::json!({ "type": "object", "properties": [] })
|
||||||
|
}
|
||||||
|
_ => schema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
_event_stream: ToolCallEventStream,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<AgentToolOutput>> {
|
||||||
|
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
|
||||||
|
return Task::ready(Err(anyhow!("Context server not found")));
|
||||||
|
};
|
||||||
|
let tool_name = self.tool.name.clone();
|
||||||
|
let server_clone = server.clone();
|
||||||
|
let input_clone = input.clone();
|
||||||
|
|
||||||
|
cx.spawn(async move |_cx| {
|
||||||
|
let Some(protocol) = server_clone.client() else {
|
||||||
|
bail!("Context server not initialized");
|
||||||
|
};
|
||||||
|
|
||||||
|
let arguments = if let serde_json::Value::Object(map) = input_clone {
|
||||||
|
Some(map.into_iter().collect())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"Running tool: {} with arguments: {:?}",
|
||||||
|
tool_name,
|
||||||
|
arguments
|
||||||
|
);
|
||||||
|
let response = protocol
|
||||||
|
.request::<context_server::types::requests::CallTool>(
|
||||||
|
context_server::types::CallToolParams {
|
||||||
|
name: tool_name,
|
||||||
|
arguments,
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut result = String::new();
|
||||||
|
for content in response.content {
|
||||||
|
match content {
|
||||||
|
context_server::types::ToolResponseContent::Text { text } => {
|
||||||
|
result.push_str(&text);
|
||||||
|
}
|
||||||
|
context_server::types::ToolResponseContent::Image { .. } => {
|
||||||
|
log::warn!("Ignoring image content from tool response");
|
||||||
|
}
|
||||||
|
context_server::types::ToolResponseContent::Audio { .. } => {
|
||||||
|
log::warn!("Ignoring audio content from tool response");
|
||||||
|
}
|
||||||
|
context_server::types::ToolResponseContent::Resource { .. } => {
|
||||||
|
log::warn!("Ignoring resource content from tool response");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(AgentToolOutput {
|
||||||
|
raw_output: result.clone().into(),
|
||||||
|
llm_output: result.into(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
event_stream: ToolCallEventStream,
|
_event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
match input.path {
|
match input.path {
|
||||||
|
@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool {
|
||||||
range.start.row + 1,
|
range.start.row + 1,
|
||||||
entry.diagnostic.message
|
entry.diagnostic.message
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![output.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if output.is_empty() {
|
if output.is_empty() {
|
||||||
|
@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
if has_diagnostics {
|
if has_diagnostics {
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![output.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
Task::ready(Ok(output))
|
Task::ready(Ok(output))
|
||||||
} else {
|
} else {
|
||||||
let text = "No errors or warnings found in the project.";
|
Task::ready(Ok("No errors or warnings found in the project.".into()))
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![text.into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
Task::ready(Ok(text.into()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -454,9 +454,8 @@ fn resolve_path(
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::Templates;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::{ContextServerRegistry, Templates};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use client::TelemetrySettings;
|
use client::TelemetrySettings;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
@ -475,9 +474,20 @@ mod tests {
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread =
|
let thread = cx.new(|cx| {
|
||||||
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
|
Thread::new(
|
||||||
|
project,
|
||||||
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
|
action_log,
|
||||||
|
Templates::new(),
|
||||||
|
model,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
let result = cx
|
let result = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let input = EditFileToolInput {
|
let input = EditFileToolInput {
|
||||||
|
@ -661,14 +671,18 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -792,15 +806,19 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -914,15 +932,19 @@ mod tests {
|
||||||
init_test(cx);
|
init_test(cx);
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
@ -1041,15 +1063,19 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
fs.insert_tree("/project", json!({})).await;
|
fs.insert_tree("/project", json!({})).await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project,
|
project,
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
@ -1148,14 +1174,18 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project.clone(),
|
project.clone(),
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
@ -1225,14 +1255,18 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project.clone(),
|
project.clone(),
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
@ -1305,14 +1339,18 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project.clone(),
|
project.clone(),
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
@ -1382,14 +1420,18 @@ mod tests {
|
||||||
let fs = project::FakeFs::new(cx.executor());
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
let model = Arc::new(FakeLanguageModel::default());
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
let thread = cx.new(|_| {
|
let thread = cx.new(|cx| {
|
||||||
Thread::new(
|
Thread::new(
|
||||||
project.clone(),
|
project.clone(),
|
||||||
Rc::default(),
|
Rc::default(),
|
||||||
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool { thread });
|
||||||
|
|
|
@ -136,7 +136,7 @@ impl AgentTool for FetchTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
event_stream: ToolCallEventStream,
|
_event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
let text = cx.background_spawn({
|
let text = cx.background_spawn({
|
||||||
|
@ -149,12 +149,6 @@ impl AgentTool for FetchTool {
|
||||||
if text.trim().is_empty() {
|
if text.trim().is_empty() {
|
||||||
bail!("no textual content found");
|
bail!("no textual content found");
|
||||||
}
|
}
|
||||||
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![text.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(text)
|
Ok(text)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
),
|
),
|
||||||
raw_output: Some(serde_json::json!({
|
|
||||||
"paths": &matches,
|
|
||||||
})),
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
event_stream: ToolCallEventStream,
|
_event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
const CONTEXT_LINES: u32 = 2;
|
const CONTEXT_LINES: u32 = 2;
|
||||||
|
@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![output.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
matches_found += 1;
|
matches_found += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = if matches_found == 0 {
|
if matches_found == 0 {
|
||||||
"No matches found".to_string()
|
Ok("No matches found".into())
|
||||||
} else if has_more_matches {
|
} else if has_more_matches {
|
||||||
format!(
|
Ok(format!(
|
||||||
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
|
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
|
||||||
input.offset + 1,
|
input.offset + 1,
|
||||||
input.offset + matches_found,
|
input.offset + matches_found,
|
||||||
input.offset + RESULTS_PER_PAGE,
|
input.offset + RESULTS_PER_PAGE,
|
||||||
)
|
))
|
||||||
} else {
|
} else {
|
||||||
format!("Found {matches_found} matches:\n{output}")
|
Ok(format!("Found {matches_found} matches:\n{output}"))
|
||||||
};
|
}
|
||||||
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![output.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,20 +47,13 @@ impl AgentTool for NowTool {
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
event_stream: ToolCallEventStream,
|
_event_stream: ToolCallEventStream,
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
let now = match input.timezone {
|
let now = match input.timezone {
|
||||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||||
Timezone::Local => Local::now().to_rfc3339(),
|
Timezone::Local => Local::now().to_rfc3339(),
|
||||||
};
|
};
|
||||||
let content = format!("The current datetime is {now}.");
|
Task::ready(Ok(format!("The current datetime is {now}.")))
|
||||||
|
|
||||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
|
||||||
content: Some(vec![content.clone().into()]),
|
|
||||||
..Default::default()
|
|
||||||
});
|
|
||||||
|
|
||||||
Task::ready(Ok(content))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
|
||||||
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AgentProfileSettings {
|
||||||
|
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
|
||||||
|
self.tools.get(tool_name) == Some(&true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
|
||||||
|
self.enable_all_context_servers
|
||||||
|
|| self
|
||||||
|
.context_servers
|
||||||
|
.get(server_id)
|
||||||
|
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct ContextServerPreset {
|
pub struct ContextServerPreset {
|
||||||
pub tools: IndexMap<Arc<str>, bool>,
|
pub tools: IndexMap<Arc<str>, bool>,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue