Fix mentions roundtrip from/to database and other history bugs (#36575)
Release Notes: - N/A
This commit is contained in:
parent
f80a0ba056
commit
4ee565cd39
2 changed files with 200 additions and 28 deletions
|
@ -577,6 +577,10 @@ impl NativeAgent {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
||||||
|
if thread.read(cx).is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let database_future = ThreadsDatabase::connect(cx);
|
let database_future = ThreadsDatabase::connect(cx);
|
||||||
let (id, db_thread) =
|
let (id, db_thread) =
|
||||||
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
|
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
|
||||||
|
@ -989,12 +993,19 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use crate::HistoryEntryId;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
use acp_thread::{
|
||||||
|
AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
|
||||||
|
};
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
|
use indoc::indoc;
|
||||||
|
use language_model::fake_provider::FakeLanguageModel;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
use util::path;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
||||||
|
@ -1179,6 +1190,163 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
#[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
|
||||||
|
async fn test_save_load_thread(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree(
|
||||||
|
"/",
|
||||||
|
json!({
|
||||||
|
"a": {
|
||||||
|
"b.md": "Lorem"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
||||||
|
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||||
|
let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
|
||||||
|
let agent = NativeAgent::new(
|
||||||
|
project.clone(),
|
||||||
|
history_store.clone(),
|
||||||
|
Templates::new(),
|
||||||
|
None,
|
||||||
|
fs.clone(),
|
||||||
|
&mut cx.to_async(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
||||||
|
|
||||||
|
let acp_thread = cx
|
||||||
|
.update(|cx| {
|
||||||
|
connection
|
||||||
|
.clone()
|
||||||
|
.new_thread(project.clone(), Path::new(""), cx)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||||
|
let thread = agent.read_with(cx, |agent, _| {
|
||||||
|
agent.sessions.get(&session_id).unwrap().thread.clone()
|
||||||
|
});
|
||||||
|
|
||||||
|
// Ensure empty threads are not saved, even if they get mutated.
|
||||||
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
|
let summary_model = Arc::new(FakeLanguageModel::default());
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.set_model(model, cx);
|
||||||
|
thread.set_summarization_model(Some(summary_model), cx);
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
assert_eq!(history_entries(&history_store, cx), vec![]);
|
||||||
|
|
||||||
|
let model = thread.read_with(cx, |thread, _| thread.model().unwrap().clone());
|
||||||
|
let model = model.as_fake();
|
||||||
|
let summary_model = thread.read_with(cx, |thread, _| {
|
||||||
|
thread.summarization_model().unwrap().clone()
|
||||||
|
});
|
||||||
|
let summary_model = summary_model.as_fake();
|
||||||
|
let send = acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(
|
||||||
|
vec![
|
||||||
|
"What does ".into(),
|
||||||
|
acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||||
|
name: "b.md".into(),
|
||||||
|
uri: MentionUri::File {
|
||||||
|
abs_path: path!("/a/b.md").into(),
|
||||||
|
}
|
||||||
|
.to_uri()
|
||||||
|
.to_string(),
|
||||||
|
annotations: None,
|
||||||
|
description: None,
|
||||||
|
mime_type: None,
|
||||||
|
size: None,
|
||||||
|
title: None,
|
||||||
|
}),
|
||||||
|
" mean?".into(),
|
||||||
|
],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
let send = cx.foreground_executor().spawn(send);
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
model.send_last_completion_stream_text_chunk("Lorem.");
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
|
||||||
|
summary_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
send.await.unwrap();
|
||||||
|
acp_thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
What does [@b.md](file:///a/b.md) mean?
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Lorem.
|
||||||
|
|
||||||
|
"}
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Drop the ACP thread, which should cause the session to be dropped as well.
|
||||||
|
cx.update(|_| {
|
||||||
|
drop(thread);
|
||||||
|
drop(acp_thread);
|
||||||
|
});
|
||||||
|
agent.read_with(cx, |agent, _| {
|
||||||
|
assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Ensure the thread can be reloaded from disk.
|
||||||
|
assert_eq!(
|
||||||
|
history_entries(&history_store, cx),
|
||||||
|
vec![(
|
||||||
|
HistoryEntryId::AcpThread(session_id.clone()),
|
||||||
|
"Explaining /a/b.md".into()
|
||||||
|
)]
|
||||||
|
);
|
||||||
|
let acp_thread = agent
|
||||||
|
.update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
acp_thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
What does [@b.md](file:///a/b.md) mean?
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Lorem.
|
||||||
|
|
||||||
|
"}
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn history_entries(
|
||||||
|
history: &Entity<HistoryStore>,
|
||||||
|
cx: &mut TestAppContext,
|
||||||
|
) -> Vec<(HistoryEntryId, String)> {
|
||||||
|
history.read_with(cx, |history, cx| {
|
||||||
|
history
|
||||||
|
.entries(cx)
|
||||||
|
.iter()
|
||||||
|
.map(|e| (e.id(), e.title().to_string()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
env_logger::try_init().ok();
|
env_logger::try_init().ok();
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
|
|
|
@ -720,7 +720,7 @@ impl Thread {
|
||||||
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
|
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
|
||||||
let initial_project_snapshot = self.initial_project_snapshot.clone();
|
let initial_project_snapshot = self.initial_project_snapshot.clone();
|
||||||
let mut thread = DbThread {
|
let mut thread = DbThread {
|
||||||
title: self.title.clone().unwrap_or_default(),
|
title: self.title(),
|
||||||
messages: self.messages.clone(),
|
messages: self.messages.clone(),
|
||||||
updated_at: self.updated_at,
|
updated_at: self.updated_at,
|
||||||
detailed_summary: self.summary.clone(),
|
detailed_summary: self.summary.clone(),
|
||||||
|
@ -870,6 +870,10 @@ impl Thread {
|
||||||
&self.action_log
|
&self.action_log
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.messages.is_empty() && self.title.is_none()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
|
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
|
||||||
self.model.as_ref()
|
self.model.as_ref()
|
||||||
}
|
}
|
||||||
|
@ -884,6 +888,10 @@ impl Thread {
|
||||||
cx.notify()
|
cx.notify()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
|
||||||
|
self.summarization_model.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_summarization_model(
|
pub fn set_summarization_model(
|
||||||
&mut self,
|
&mut self,
|
||||||
model: Option<Arc<dyn LanguageModel>>,
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
@ -1068,6 +1076,7 @@ impl Thread {
|
||||||
event_stream: event_stream.clone(),
|
event_stream: event_stream.clone(),
|
||||||
_task: cx.spawn(async move |this, cx| {
|
_task: cx.spawn(async move |this, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
|
let mut update_title = None;
|
||||||
let turn_result: Result<StopReason> = async {
|
let turn_result: Result<StopReason> = async {
|
||||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||||
loop {
|
loop {
|
||||||
|
@ -1122,10 +1131,15 @@ impl Thread {
|
||||||
this.pending_message()
|
this.pending_message()
|
||||||
.tool_results
|
.tool_results
|
||||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||||
})
|
})?;
|
||||||
.ok();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
if this.title.is_none() && update_title.is_none() {
|
||||||
|
update_title = Some(this.update_title(&event_stream, cx));
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
if tool_use_limit_reached {
|
if tool_use_limit_reached {
|
||||||
log::info!("Tool use limit reached, completing turn");
|
log::info!("Tool use limit reached, completing turn");
|
||||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||||
|
@ -1146,10 +1160,6 @@ impl Thread {
|
||||||
Ok(reason) => {
|
Ok(reason) => {
|
||||||
log::info!("Turn execution completed: {:?}", reason);
|
log::info!("Turn execution completed: {:?}", reason);
|
||||||
|
|
||||||
let update_title = this
|
|
||||||
.update(cx, |this, cx| this.update_title(&event_stream, cx))
|
|
||||||
.ok()
|
|
||||||
.flatten();
|
|
||||||
if let Some(update_title) = update_title {
|
if let Some(update_title) = update_title {
|
||||||
update_title.await.context("update title failed").log_err();
|
update_title.await.context("update title failed").log_err();
|
||||||
}
|
}
|
||||||
|
@ -1593,17 +1603,14 @@ impl Thread {
|
||||||
&mut self,
|
&mut self,
|
||||||
event_stream: &ThreadEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<Result<()>>> {
|
) -> Task<Result<()>> {
|
||||||
if self.title.is_some() {
|
|
||||||
log::debug!("Skipping title generation because we already have one.");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
log::info!(
|
log::info!(
|
||||||
"Generating title with model: {:?}",
|
"Generating title with model: {:?}",
|
||||||
self.summarization_model.as_ref().map(|model| model.name())
|
self.summarization_model.as_ref().map(|model| model.name())
|
||||||
);
|
);
|
||||||
let model = self.summarization_model.clone()?;
|
let Some(model) = self.summarization_model.clone() else {
|
||||||
|
return Task::ready(Ok(()));
|
||||||
|
};
|
||||||
let event_stream = event_stream.clone();
|
let event_stream = event_stream.clone();
|
||||||
let mut request = LanguageModelRequest {
|
let mut request = LanguageModelRequest {
|
||||||
intent: Some(CompletionIntent::ThreadSummarization),
|
intent: Some(CompletionIntent::ThreadSummarization),
|
||||||
|
@ -1620,7 +1627,7 @@ impl Thread {
|
||||||
content: vec![SUMMARIZE_THREAD_PROMPT.into()],
|
content: vec![SUMMARIZE_THREAD_PROMPT.into()],
|
||||||
cache: false,
|
cache: false,
|
||||||
});
|
});
|
||||||
Some(cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
let mut title = String::new();
|
let mut title = String::new();
|
||||||
let mut messages = model.stream_completion(request, cx).await?;
|
let mut messages = model.stream_completion(request, cx).await?;
|
||||||
while let Some(event) = messages.next().await {
|
while let Some(event) = messages.next().await {
|
||||||
|
@ -1655,7 +1662,7 @@ impl Thread {
|
||||||
this.title = Some(title);
|
this.title = Some(title);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})
|
})
|
||||||
}))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn last_user_message(&self) -> Option<&UserMessage> {
|
fn last_user_message(&self) -> Option<&UserMessage> {
|
||||||
|
@ -2457,18 +2464,15 @@ impl From<UserMessageContent> for acp::ContentBlock {
|
||||||
uri: None,
|
uri: None,
|
||||||
}),
|
}),
|
||||||
UserMessageContent::Mention { uri, content } => {
|
UserMessageContent::Mention { uri, content } => {
|
||||||
acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||||
uri: uri.to_uri().to_string(),
|
resource: acp::EmbeddedResourceResource::TextResourceContents(
|
||||||
name: uri.name(),
|
acp::TextResourceContents {
|
||||||
|
mime_type: None,
|
||||||
|
text: content,
|
||||||
|
uri: uri.to_uri().to_string(),
|
||||||
|
},
|
||||||
|
),
|
||||||
annotations: None,
|
annotations: None,
|
||||||
description: if content.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(content)
|
|
||||||
},
|
|
||||||
mime_type: None,
|
|
||||||
size: None,
|
|
||||||
title: None,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue