Saving history with thread titles
This commit is contained in:
parent
cc196427f0
commit
5d88de13da
10 changed files with 241 additions and 183 deletions
|
@ -691,6 +691,7 @@ pub struct AcpThread {
|
|||
|
||||
pub enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
TitleUpdated,
|
||||
EntryUpdated(usize),
|
||||
EntriesRemoved(Range<usize>),
|
||||
ToolAuthorizationRequired,
|
||||
|
@ -934,6 +935,12 @@ impl AcpThread {
|
|||
cx.emit(AcpThreadEvent::NewEntry);
|
||||
}
|
||||
|
||||
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
|
||||
self.title = title;
|
||||
cx.emit(AcpThreadEvent::TitleUpdated);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&mut self,
|
||||
update: impl Into<ToolCallUpdate>,
|
||||
|
|
|
@ -255,6 +255,9 @@ impl NativeAgent {
|
|||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
cx.observe(&thread, |this, thread, cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.generate_title_if_needed(cx);
|
||||
});
|
||||
this.save_thread(thread.clone(), cx)
|
||||
}),
|
||||
],
|
||||
|
@ -262,13 +265,14 @@ impl NativeAgent {
|
|||
);
|
||||
}
|
||||
|
||||
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
||||
let id = thread.read(cx).id().clone();
|
||||
fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
|
||||
let thread = thread_handle.read(cx);
|
||||
let id = thread.id().clone();
|
||||
let Some(session) = self.sessions.get_mut(&id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let thread = thread.downgrade();
|
||||
let thread = thread_handle.downgrade();
|
||||
let thread_database = self.thread_database.clone();
|
||||
session.save_task = cx.spawn(async move |this, cx| {
|
||||
cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
|
||||
|
@ -507,7 +511,7 @@ impl NativeAgent {
|
|||
|
||||
fn handle_models_updated_event(
|
||||
&mut self,
|
||||
_registry: Entity<LanguageModelRegistry>,
|
||||
registry: Entity<LanguageModelRegistry>,
|
||||
_event: &language_model::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -518,6 +522,11 @@ impl NativeAgent {
|
|||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.set_model(model.clone(), cx);
|
||||
}
|
||||
let summarization_model = registry
|
||||
.read(cx)
|
||||
.thread_summary_model()
|
||||
.map(|model| model.model.clone());
|
||||
thread.set_summarization_model(summarization_model, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -641,6 +650,10 @@ impl NativeAgentConnection {
|
|||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
ThreadEvent::TitleUpdate(title) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||
}
|
||||
ThreadEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
|
@ -821,6 +834,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
)
|
||||
})?;
|
||||
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
session_id.clone(),
|
||||
|
@ -830,6 +845,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
summarization_model,
|
||||
cx,
|
||||
);
|
||||
Self::register_tools(&mut thread, project, action_log, cx);
|
||||
|
@ -894,7 +910,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
// Create Thread
|
||||
let thread = agent.update(cx, |agent, cx| {
|
||||
let configured_model = LanguageModelRegistry::global(cx)
|
||||
let language_model_registry = LanguageModelRegistry::global(cx);
|
||||
let configured_model = language_model_registry
|
||||
.update(cx, |registry, cx| {
|
||||
db_thread
|
||||
.model
|
||||
|
@ -915,6 +932,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
.model_from_id(&LanguageModels::model_id(&configured_model.model))
|
||||
.context("no model by id")?;
|
||||
|
||||
let summarization_model = language_model_registry
|
||||
.read(cx)
|
||||
.thread_summary_model()
|
||||
.map(|c| c.model);
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::from_db(
|
||||
session_id,
|
||||
|
@ -925,6 +947,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
model,
|
||||
summarization_model,
|
||||
cx,
|
||||
);
|
||||
Self::register_tools(&mut thread, project, action_log, cx);
|
||||
|
@ -1047,12 +1070,13 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{HistoryEntry, HistoryStore};
|
||||
use crate::HistoryStore;
|
||||
|
||||
use super::*;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
@ -1245,13 +1269,6 @@ mod tests {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let model = cx.update(|cx| {
|
||||
LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
});
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
let history_store = cx.new(|cx| {
|
||||
let mut store = HistoryStore::new(cx);
|
||||
|
@ -1268,6 +1285,16 @@ mod tests {
|
|||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
let selector = connection.model_selector().unwrap();
|
||||
|
||||
let summarization_model: Arc<dyn LanguageModel> =
|
||||
Arc::new(FakeLanguageModel::default()) as _;
|
||||
|
||||
agent.update(cx, |agent, cx| {
|
||||
let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summarization_model(Some(summarization_model.clone()), cx);
|
||||
})
|
||||
});
|
||||
|
||||
let model = cx
|
||||
.update(|cx| selector.selected_model(&session_id, cx))
|
||||
.await
|
||||
|
@ -1283,11 +1310,16 @@ mod tests {
|
|||
model.send_last_completion_stream_text_chunk("Hey");
|
||||
model.end_last_completion_stream();
|
||||
send.await.unwrap();
|
||||
|
||||
summarization_model
|
||||
.as_fake()
|
||||
.send_last_completion_stream_text_chunk("Saying Hello");
|
||||
summarization_model.as_fake().end_last_completion_stream();
|
||||
cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
|
||||
|
||||
let history = history_store.update(cx, |store, cx| store.entries(cx));
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].title(), "Hi");
|
||||
assert_eq!(history[0].title(), "Saying Hello");
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
|
|
@ -386,8 +386,6 @@ impl ThreadsDatabase {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::NativeAgent;
|
||||
use crate::Templates;
|
||||
|
||||
use super::*;
|
||||
use agent::MessageSegment;
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use assistant_context::SavedContextMetadata;
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use gpui::{SharedString, Task, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
|
||||
|
|
|
@ -1506,6 +1506,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
|
|
@ -5,7 +5,7 @@ use acp_thread::{MentionUri, UserMessageId};
|
|||
use action_log::ActionLog;
|
||||
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
|
@ -24,7 +24,7 @@ use language_model::{
|
|||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::{
|
||||
Project,
|
||||
|
@ -75,6 +75,18 @@ impl Message {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
match self {
|
||||
Message::User(message) => vec![message.to_request()],
|
||||
Message::Agent(message) => message.to_request(),
|
||||
Message::Resume => vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
cache: false,
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
match self {
|
||||
Message::User(message) => message.to_markdown(),
|
||||
|
@ -82,6 +94,13 @@ impl Message {
|
|||
Message::Resume => "[resumed after tool use limit was reached]".into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn role(&self) -> Role {
|
||||
match self {
|
||||
Message::User(_) | Message::Resume => Role::User,
|
||||
Message::Agent(_) => Role::Assistant,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
@ -426,6 +445,7 @@ pub enum ThreadEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
TitleUpdate(SharedString),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -475,6 +495,7 @@ pub struct Thread {
|
|||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
@ -488,6 +509,7 @@ impl Thread {
|
|||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
|
@ -516,11 +538,37 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
model,
|
||||
summarization_model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(
|
||||
model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
use crate::generate_session_id;
|
||||
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
|
||||
Self::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &acp::SessionId {
|
||||
&self.id
|
||||
}
|
||||
|
@ -534,6 +582,7 @@ impl Thread {
|
|||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = db_thread
|
||||
|
@ -558,6 +607,7 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
model,
|
||||
summarization_model,
|
||||
project,
|
||||
action_log,
|
||||
updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
|
||||
|
@ -807,6 +857,15 @@ impl Thread {
|
|||
cx.notify()
|
||||
}
|
||||
|
||||
pub fn set_summarization_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.summarization_model = model;
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
@ -1018,6 +1077,86 @@ impl Thread {
|
|||
events_rx
|
||||
}
|
||||
|
||||
pub fn generate_title_if_needed(&mut self, cx: &mut Context<Self>) {
|
||||
if !matches!(self.title, ThreadTitle::None) {
|
||||
return;
|
||||
}
|
||||
|
||||
// todo!() copy logic from agent1 re: tool calls, etc.?
|
||||
if self.messages.len() < 2 {
|
||||
return;
|
||||
}
|
||||
|
||||
self.generate_title(cx);
|
||||
}
|
||||
|
||||
fn generate_title(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(model) = self.summarization_model.clone() else {
|
||||
println!("No thread summary model");
|
||||
return;
|
||||
};
|
||||
let mut request = LanguageModelRequest {
|
||||
intent: Some(CompletionIntent::ThreadSummarization),
|
||||
temperature: AgentSettings::temperature_for_model(&model, cx),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
for message in &self.messages {
|
||||
request.messages.extend(message.to_request());
|
||||
}
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())],
|
||||
cache: false,
|
||||
});
|
||||
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
let result = async {
|
||||
let mut messages = model.stream_completion(request, &cx).await?;
|
||||
|
||||
let mut new_summary = String::new();
|
||||
while let Some(event) = messages.next().await {
|
||||
let Ok(event) = event else {
|
||||
continue;
|
||||
};
|
||||
let text = match event {
|
||||
LanguageModelCompletionEvent::Text(text) => text,
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { .. },
|
||||
) => {
|
||||
// this.update(cx, |thread, cx| {
|
||||
// thread.update_model_request_usage(amount as u32, limit, cx);
|
||||
// })?;
|
||||
// todo!()? not sure if this is the right place to do this.
|
||||
continue;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let mut lines = text.lines();
|
||||
new_summary.extend(lines.next());
|
||||
|
||||
// Stop if the LLM generated multiple lines.
|
||||
if lines.next().is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(new_summary.into())
|
||||
}
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.title = ThreadTitle::Done(result);
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
});
|
||||
|
||||
self.title = ThreadTitle::Pending(task);
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
|
@ -1373,15 +1512,7 @@ impl Thread {
|
|||
);
|
||||
let mut messages = vec![self.build_system_message()];
|
||||
for message in &self.messages {
|
||||
match message {
|
||||
Message::User(message) => messages.push(message.to_request()),
|
||||
Message::Agent(message) => messages.extend(message.to_request()),
|
||||
Message::Resume => messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
cache: false,
|
||||
}),
|
||||
}
|
||||
messages.extend(message.to_request());
|
||||
}
|
||||
|
||||
if let Some(message) = self.pending_message.as_ref() {
|
||||
|
@ -1924,7 +2055,7 @@ impl From<UserMessageContent> for acp::ContentBlock {
|
|||
annotations: None,
|
||||
uri: None,
|
||||
}),
|
||||
UserMessageContent::Mention { uri, content } => {
|
||||
UserMessageContent::Mention { .. } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -521,7 +521,6 @@ fn resolve_path(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates, generate_session_id};
|
||||
use action_log::ActionLog;
|
||||
use client::TelemetrySettings;
|
||||
use fs::Fs;
|
||||
|
@ -529,7 +528,6 @@ mod tests {
|
|||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::rc::Rc;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -541,21 +539,8 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
model,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
|
@ -743,21 +728,8 @@ mod tests {
|
|||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx));
|
||||
|
||||
// First, test with format_on_save enabled
|
||||
cx.update(|cx| {
|
||||
|
@ -885,22 +857,9 @@ mod tests {
|
|||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx));
|
||||
|
||||
// First, test with remove_trailing_whitespace_on_save enabled
|
||||
cx.update(|cx| {
|
||||
|
@ -1015,22 +974,10 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
|
||||
|
@ -1154,22 +1101,10 @@ mod tests {
|
|||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test global config paths - these should require confirmation if they exist and are outside the project
|
||||
|
@ -1266,21 +1201,9 @@ mod tests {
|
|||
.await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test files in different worktrees
|
||||
|
@ -1349,21 +1272,9 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test edge cases
|
||||
|
@ -1435,21 +1346,9 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
// Test different EditFileMode values
|
||||
|
@ -1518,21 +1417,9 @@ mod tests {
|
|||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
generate_session_id(),
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
|
||||
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
use crate::{AgentPanel, RemoveSelectedThread};
|
||||
use crate::RemoveSelectedThread;
|
||||
use agent_servers::AgentServer;
|
||||
use agent2::{
|
||||
NativeAgentServer,
|
||||
history_store::{HistoryEntry, HistoryStore},
|
||||
};
|
||||
use agent2::{HistoryEntry, HistoryStore, NativeAgentServer};
|
||||
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
|
||||
use editor::{Editor, EditorEvent};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
|
||||
UniformListScrollHandle, WeakEntity, Window, uniform_list,
|
||||
UniformListScrollHandle, Window, uniform_list,
|
||||
};
|
||||
use project::Project;
|
||||
use std::{fmt::Display, ops::Range, sync::Arc};
|
||||
|
@ -72,7 +69,7 @@ impl AcpThreadHistory {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx));
|
||||
let history_store = cx.new(|cx| agent2::HistoryStore::new(cx));
|
||||
|
||||
let agent = NativeAgentServer::new(project.read(cx).fs().clone());
|
||||
|
||||
|
|
|
@ -687,6 +687,7 @@ impl AcpThreadView {
|
|||
AcpThreadEvent::ServerExited(status) => {
|
||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||
}
|
||||
AcpThreadEvent::TitleUpdated => {}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
|
|
@ -199,24 +199,21 @@ impl AgentDiffPane {
|
|||
let action_log = thread.action_log(cx).clone();
|
||||
|
||||
let mut this = Self {
|
||||
_subscriptions: [
|
||||
Some(
|
||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||
this.update_excerpts(window, cx)
|
||||
}),
|
||||
),
|
||||
_subscriptions: vec![
|
||||
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
|
||||
this.update_excerpts(window, cx)
|
||||
}),
|
||||
match &thread {
|
||||
AgentDiffThread::Native(thread) => {
|
||||
Some(cx.subscribe(&thread, |this, _thread, event, cx| {
|
||||
this.handle_thread_event(event, cx)
|
||||
}))
|
||||
}
|
||||
AgentDiffThread::AcpThread(_) => None,
|
||||
AgentDiffThread::Native(thread) => cx
|
||||
.subscribe(&thread, |this, _thread, event, cx| {
|
||||
this.handle_native_thread_event(event, cx)
|
||||
}),
|
||||
AgentDiffThread::AcpThread(thread) => cx
|
||||
.subscribe(&thread, |this, _thread, event, cx| {
|
||||
this.handle_acp_thread_event(event, cx)
|
||||
}),
|
||||
},
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
],
|
||||
title: SharedString::default(),
|
||||
multibuffer,
|
||||
editor,
|
||||
|
@ -324,13 +321,20 @@ impl AgentDiffPane {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
|
||||
match event {
|
||||
ThreadEvent::SummaryGenerated => self.update_title(cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
|
||||
match event {
|
||||
AcpThreadEvent::TitleUpdated => self.update_title(cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
|
||||
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
|
@ -1521,7 +1525,8 @@ impl AgentDiff {
|
|||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::EntriesRemoved(_)
|
||||
AcpThreadEvent::TitleUpdated
|
||||
| AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue