Take a weak thread in EditFileTool to avoid cycle
This commit is contained in:
parent
fae5900749
commit
a231fd3ee5
3 changed files with 89 additions and 75 deletions
|
@ -574,7 +574,7 @@ impl NativeAgentConnection {
|
||||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
thread.add_tool(EditFileTool::new(cx.weak_entity()));
|
||||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||||
thread.add_tool(FindPathTool::new(project.clone()));
|
thread.add_tool(FindPathTool::new(project.clone()));
|
||||||
thread.add_tool(GrepTool::new(project.clone()));
|
thread.add_tool(GrepTool::new(project.clone()));
|
||||||
|
@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
fn load_thread(
|
fn load_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
cwd: &Path,
|
_cwd: &Path,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||||
|
@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
|
|
||||||
// Create Thread
|
// Create Thread
|
||||||
let thread = agent.update(
|
let thread = agent.update(cx, |agent, cx| {
|
||||||
cx,
|
let configured_model = LanguageModelRegistry::global(cx)
|
||||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
.update(cx, |registry, cx| {
|
||||||
let configured_model = LanguageModelRegistry::global(cx)
|
db_thread
|
||||||
.update(cx, |registry, cx| {
|
.model
|
||||||
db_thread
|
.and_then(|model| {
|
||||||
.model
|
let model = SelectedModel {
|
||||||
.and_then(|model| {
|
provider: model.provider.clone().into(),
|
||||||
let model = SelectedModel {
|
model: model.model.clone().into(),
|
||||||
provider: model.provider.clone().into(),
|
};
|
||||||
model: model.model.clone().into(),
|
registry.select_model(&model, cx)
|
||||||
};
|
})
|
||||||
registry.select_model(&model, cx)
|
.or_else(|| registry.default_model())
|
||||||
})
|
})
|
||||||
.or_else(|| registry.default_model())
|
.context("no default model configured")?;
|
||||||
})
|
|
||||||
.context("no default model configured")?;
|
|
||||||
|
|
||||||
let model = agent
|
let model = agent
|
||||||
.models
|
.models
|
||||||
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
||||||
.context("no model by id")?;
|
.context("no model by id")?;
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(
|
let mut thread = Thread::new(
|
||||||
project.clone(),
|
project.clone(),
|
||||||
agent.project_context.clone(),
|
agent.project_context.clone(),
|
||||||
agent.context_server_registry.clone(),
|
agent.context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
agent.templates.clone(),
|
agent.templates.clone(),
|
||||||
model,
|
model,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
Self::register_tools(&mut thread, project, action_log, cx);
|
Self::register_tools(&mut thread, project, action_log, cx);
|
||||||
thread
|
thread
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(thread)
|
anyhow::Ok(thread)
|
||||||
},
|
})??;
|
||||||
)??;
|
|
||||||
|
|
||||||
// Store the session
|
// Store the session
|
||||||
agent.update(cx, |agent, cx| {
|
agent.update(cx, |agent, cx| {
|
||||||
|
@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// we need to actually deserialize the DbThread.
|
// we need to actually deserialize the DbThread.
|
||||||
todo!()
|
// todo!()
|
||||||
|
|
||||||
Ok(acp_thread)
|
Ok(acp_thread)
|
||||||
})
|
})
|
||||||
|
|
|
@ -441,7 +441,7 @@ impl Thread {
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||||
let this = Self {
|
Self {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
completion_mode: CompletionMode::Normal,
|
completion_mode: CompletionMode::Normal,
|
||||||
running_turn: None,
|
running_turn: None,
|
||||||
|
@ -455,7 +455,7 @@ impl Thread {
|
||||||
model,
|
model,
|
||||||
project,
|
project,
|
||||||
action_log,
|
action_log,
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn project(&self) -> &Entity<Project> {
|
pub fn project(&self) -> &Entity<Project> {
|
||||||
|
|
|
@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||||
use cloud_llm_client::CompletionIntent;
|
use cloud_llm_client::CompletionIntent;
|
||||||
use collections::HashSet;
|
use collections::HashSet;
|
||||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use language::ToPoint;
|
use language::ToPoint;
|
||||||
use language::language_settings::{self, FormatOnSave};
|
use language::language_settings::{self, FormatOnSave};
|
||||||
|
@ -122,11 +122,11 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct EditFileTool {
|
pub struct EditFileTool {
|
||||||
thread: Entity<Thread>,
|
thread: WeakEntity<Thread>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditFileTool {
|
impl EditFileTool {
|
||||||
pub fn new(thread: Entity<Thread>) -> Self {
|
pub fn new(thread: WeakEntity<Thread>) -> Self {
|
||||||
Self { thread }
|
Self { thread }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,8 +167,11 @@ impl EditFileTool {
|
||||||
|
|
||||||
// Check if path is inside the global config directory
|
// Check if path is inside the global config directory
|
||||||
// First check if it's already inside project - if not, try to canonicalize
|
// First check if it's already inside project - if not, try to canonicalize
|
||||||
let thread = self.thread.read(cx);
|
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
|
||||||
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
|
thread.project().read(cx).find_project_path(&input.path, cx)
|
||||||
|
}) else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
|
|
||||||
// If the path is inside the project, and it's not one of the above edge cases,
|
// If the path is inside the project, and it's not one of the above edge cases,
|
||||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||||
|
@ -221,7 +224,12 @@ impl AgentTool for EditFileTool {
|
||||||
event_stream: ToolCallEventStream,
|
event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
let project = self.thread.read(cx).project().clone();
|
let Ok(project) = self
|
||||||
|
.thread
|
||||||
|
.read_with(cx, |thread, _cx| thread.project().clone())
|
||||||
|
else {
|
||||||
|
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||||
|
};
|
||||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||||
Ok(path) => path,
|
Ok(path) => path,
|
||||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||||
|
@ -237,17 +245,15 @@ impl AgentTool for EditFileTool {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = self.thread.update(cx, |thread, cx| {
|
|
||||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
|
||||||
});
|
|
||||||
let thread = self.thread.read(cx);
|
|
||||||
let model = thread.model().clone();
|
|
||||||
let action_log = thread.action_log().clone();
|
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = self.authorize(&input, &event_stream, cx);
|
||||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||||
authorize.await?;
|
authorize.await?;
|
||||||
|
|
||||||
|
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
|
||||||
|
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
|
||||||
|
(request, thread.model().clone(), thread.action_log().clone())
|
||||||
|
})?;
|
||||||
|
|
||||||
let edit_format = EditFormat::from_model(model.clone())?;
|
let edit_format = EditFormat::from_model(model.clone())?;
|
||||||
let edit_agent = EditAgent::new(
|
let edit_agent = EditAgent::new(
|
||||||
model,
|
model,
|
||||||
|
@ -531,7 +537,11 @@ mod tests {
|
||||||
path: "root/nonexistent_file.txt".into(),
|
path: "root/nonexistent_file.txt".into(),
|
||||||
mode: EditFileMode::Edit,
|
mode: EditFileMode::Edit,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -744,10 +754,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||||
thread: thread.clone(),
|
input,
|
||||||
})
|
ToolCallEventStream::test().0,
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the unformatted content
|
// Stream the unformatted content
|
||||||
|
@ -800,7 +811,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||||
|
input,
|
||||||
|
ToolCallEventStream::test().0,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the unformatted content
|
// Stream the unformatted content
|
||||||
|
@ -881,10 +896,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||||
thread: thread.clone(),
|
input,
|
||||||
})
|
ToolCallEventStream::test().0,
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the content with trailing whitespace
|
// Stream the content with trailing whitespace
|
||||||
|
@ -932,10 +948,11 @@ mod tests {
|
||||||
path: "root/src/main.rs".into(),
|
path: "root/src/main.rs".into(),
|
||||||
mode: EditFileMode::Overwrite,
|
mode: EditFileMode::Overwrite,
|
||||||
};
|
};
|
||||||
Arc::new(EditFileTool {
|
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||||
thread: thread.clone(),
|
input,
|
||||||
})
|
ToolCallEventStream::test().0,
|
||||||
.run(input, ToolCallEventStream::test().0, cx)
|
cx,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Stream the content with trailing whitespace
|
// Stream the content with trailing whitespace
|
||||||
|
@ -983,7 +1000,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
fs.insert_tree("/root", json!({})).await;
|
fs.insert_tree("/root", json!({})).await;
|
||||||
|
|
||||||
// Test 1: Path with .zed component should require confirmation
|
// Test 1: Path with .zed component should require confirmation
|
||||||
|
@ -1114,7 +1131,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
|
|
||||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1224,7 +1241,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
|
|
||||||
// Test files in different worktrees
|
// Test files in different worktrees
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1305,7 +1322,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
|
|
||||||
// Test edge cases
|
// Test edge cases
|
||||||
let test_cases = vec![
|
let test_cases = vec![
|
||||||
|
@ -1389,7 +1406,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
|
|
||||||
// Test different EditFileMode values
|
// Test different EditFileMode values
|
||||||
let modes = vec![
|
let modes = vec![
|
||||||
|
@ -1470,7 +1487,7 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
let tool = Arc::new(EditFileTool { thread });
|
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tool.initial_title(Err(json!({
|
tool.initial_title(Err(json!({
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue