Take a weak thread in EditFileTool to avoid cycle
This commit is contained in:
parent
fae5900749
commit
a231fd3ee5
3 changed files with 89 additions and 75 deletions
|
@ -574,7 +574,7 @@ impl NativeAgentConnection {
|
|||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity()));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
|
@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
fn load_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
_cwd: &Path,
|
||||
session_id: acp::SessionId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
|
@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
let agent = self.0.clone();
|
||||
|
||||
// Create Thread
|
||||
let thread = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
let configured_model = LanguageModelRegistry::global(cx)
|
||||
.update(cx, |registry, cx| {
|
||||
db_thread
|
||||
.model
|
||||
.and_then(|model| {
|
||||
let model = SelectedModel {
|
||||
provider: model.provider.clone().into(),
|
||||
model: model.model.clone().into(),
|
||||
};
|
||||
registry.select_model(&model, cx)
|
||||
})
|
||||
.or_else(|| registry.default_model())
|
||||
})
|
||||
.context("no default model configured")?;
|
||||
let thread = agent.update(cx, |agent, cx| {
|
||||
let configured_model = LanguageModelRegistry::global(cx)
|
||||
.update(cx, |registry, cx| {
|
||||
db_thread
|
||||
.model
|
||||
.and_then(|model| {
|
||||
let model = SelectedModel {
|
||||
provider: model.provider.clone().into(),
|
||||
model: model.model.clone().into(),
|
||||
};
|
||||
registry.select_model(&model, cx)
|
||||
})
|
||||
.or_else(|| registry.default_model())
|
||||
})
|
||||
.context("no default model configured")?;
|
||||
|
||||
let model = agent
|
||||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
||||
.context("no model by id")?;
|
||||
let model = agent
|
||||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
||||
.context("no model by id")?;
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
model,
|
||||
cx,
|
||||
);
|
||||
Self::register_tools(&mut thread, project, action_log, cx);
|
||||
thread
|
||||
});
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
model,
|
||||
cx,
|
||||
);
|
||||
Self::register_tools(&mut thread, project, action_log, cx);
|
||||
thread
|
||||
});
|
||||
|
||||
Ok(thread)
|
||||
},
|
||||
)??;
|
||||
anyhow::Ok(thread)
|
||||
})??;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
|
@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
})?;
|
||||
|
||||
// we need to actually deserialize the DbThread.
|
||||
todo!()
|
||||
// todo!()
|
||||
|
||||
Ok(acp_thread)
|
||||
})
|
||||
|
|
|
@ -441,7 +441,7 @@ impl Thread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
let this = Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
|
@ -455,7 +455,7 @@ impl Thread {
|
|||
model,
|
||||
project,
|
||||
action_log,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project(&self) -> &Entity<Project> {
|
||||
|
|
|
@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use indoc::formatdoc;
|
||||
use language::ToPoint;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
|
@ -122,11 +122,11 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
|||
}
|
||||
|
||||
pub struct EditFileTool {
|
||||
thread: Entity<Thread>,
|
||||
thread: WeakEntity<Thread>,
|
||||
}
|
||||
|
||||
impl EditFileTool {
|
||||
pub fn new(thread: Entity<Thread>) -> Self {
|
||||
pub fn new(thread: WeakEntity<Thread>) -> Self {
|
||||
Self { thread }
|
||||
}
|
||||
|
||||
|
@ -167,8 +167,11 @@ impl EditFileTool {
|
|||
|
||||
// Check if path is inside the global config directory
|
||||
// First check if it's already inside project - if not, try to canonicalize
|
||||
let thread = self.thread.read(cx);
|
||||
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
|
||||
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
|
||||
thread.project().read(cx).find_project_path(&input.path, cx)
|
||||
}) else {
|
||||
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||
};
|
||||
|
||||
// If the path is inside the project, and it's not one of the above edge cases,
|
||||
// then no confirmation is necessary. Otherwise, confirmation is necessary.
|
||||
|
@ -221,7 +224,12 @@ impl AgentTool for EditFileTool {
|
|||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let project = self.thread.read(cx).project().clone();
|
||||
let Ok(project) = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.project().clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("thread was dropped")));
|
||||
};
|
||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||
Ok(path) => path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
|
@ -237,17 +245,15 @@ impl AgentTool for EditFileTool {
|
|||
});
|
||||
}
|
||||
|
||||
let request = self.thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||
});
|
||||
let thread = self.thread.read(cx);
|
||||
let model = thread.model().clone();
|
||||
let action_log = thread.action_log().clone();
|
||||
|
||||
let authorize = self.authorize(&input, &event_stream, cx);
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
authorize.await?;
|
||||
|
||||
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
|
||||
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
|
||||
(request, thread.model().clone(), thread.action_log().clone())
|
||||
})?;
|
||||
|
||||
let edit_format = EditFormat::from_model(model.clone())?;
|
||||
let edit_agent = EditAgent::new(
|
||||
model,
|
||||
|
@ -531,7 +537,11 @@ mod tests {
|
|||
path: "root/nonexistent_file.txt".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
};
|
||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
|
@ -744,10 +754,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
|
@ -800,7 +811,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the unformatted content
|
||||
|
@ -881,10 +896,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
|
@ -932,10 +948,11 @@ mod tests {
|
|||
path: "root/src/main.rs".into(),
|
||||
mode: EditFileMode::Overwrite,
|
||||
};
|
||||
Arc::new(EditFileTool {
|
||||
thread: thread.clone(),
|
||||
})
|
||||
.run(input, ToolCallEventStream::test().0, cx)
|
||||
Arc::new(EditFileTool::new(thread.downgrade())).run(
|
||||
input,
|
||||
ToolCallEventStream::test().0,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Stream the content with trailing whitespace
|
||||
|
@ -983,7 +1000,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
|
||||
// Test 1: Path with .zed component should require confirmation
|
||||
|
@ -1114,7 +1131,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
|
||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||
let test_cases = vec![
|
||||
|
@ -1224,7 +1241,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
|
||||
// Test files in different worktrees
|
||||
let test_cases = vec![
|
||||
|
@ -1305,7 +1322,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
|
||||
// Test edge cases
|
||||
let test_cases = vec![
|
||||
|
@ -1389,7 +1406,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
|
||||
// Test different EditFileMode values
|
||||
let modes = vec![
|
||||
|
@ -1470,7 +1487,7 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade()));
|
||||
|
||||
assert_eq!(
|
||||
tool.initial_title(Err(json!({
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue