Introduce rating for assistant threads (#26780)

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Antonio Scandurra 2025-03-14 15:41:50 +01:00 committed by GitHub
parent c62210b178
commit f68a475eca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 378 additions and 117 deletions

2
Cargo.lock generated
View file

@ -467,6 +467,7 @@ dependencies = [
"fs", "fs",
"futures 0.3.31", "futures 0.3.31",
"fuzzy", "fuzzy",
"git",
"gpui", "gpui",
"heed", "heed",
"html_to_markdown", "html_to_markdown",
@ -496,6 +497,7 @@ dependencies = [
"settings", "settings",
"smol", "smol",
"streaming_diff", "streaming_diff",
"telemetry",
"telemetry_events", "telemetry_events",
"terminal", "terminal",
"terminal_view", "terminal_view",

View file

@ -38,6 +38,7 @@ file_icons.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
fuzzy.workspace = true fuzzy.workspace = true
git.workspace = true
gpui.workspace = true gpui.workspace = true
heed.workspace = true heed.workspace = true
html_to_markdown.workspace = true html_to_markdown.workspace = true
@ -65,6 +66,7 @@ serde_json.workspace = true
settings.workspace = true settings.workspace = true
smol.workspace = true smol.workspace = true
streaming_diff.workspace = true streaming_diff.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true telemetry_events.workspace = true
terminal.workspace = true terminal.workspace = true
terminal_view.workspace = true terminal_view.workspace = true

View file

@ -1,6 +1,7 @@
use std::sync::Arc; use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use std::time::Duration; use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use collections::HashMap; use collections::HashMap;
use editor::{Editor, MultiBuffer}; use editor::{Editor, MultiBuffer};
use gpui::{ use gpui::{
@ -14,15 +15,13 @@ use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle}; use markdown::{Markdown, MarkdownStyle};
use scripting_tool::{ScriptingTool, ScriptingToolInput}; use scripting_tool::{ScriptingTool, ScriptingToolInput};
use settings::Settings as _; use settings::Settings as _;
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::Color;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _; use util::ResultExt as _;
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
pub struct ActiveThread { pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
@ -498,7 +497,7 @@ impl ActiveThread {
}; };
let thread = self.thread.read(cx); let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures
let context = thread.context_for_message(message_id); let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
@ -653,28 +652,27 @@ impl ActiveThread {
) )
.child(message_content), .child(message_content),
), ),
Role::Assistant => v_flex() Role::Assistant => {
.id(("message-container", ix)) v_flex()
.child(message_content) .id(("message-container", ix))
.map(|parent| { .child(message_content)
if tool_uses.is_empty() && scripting_tool_uses.is_empty() { .when(
return parent; !tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
} |parent| {
parent.child(
parent.child( v_flex()
v_flex() .children(
.children( tool_uses
tool_uses .into_iter()
.into_iter() .map(|tool_use| self.render_tool_use(tool_use, cx)),
.map(|tool_use| self.render_tool_use(tool_use, cx)), )
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, cx)
})),
) )
.children( },
scripting_tool_uses
.into_iter()
.map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
),
) )
}), }
Role::System => div().id(("message-container", ix)).py_1().px_2().child( Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex() v_flex()
.bg(colors.editor_background) .bg(colors.editor_background)

View file

@ -2,10 +2,10 @@ use assistant_context_editor::SavedContextMetadata;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use gpui::{prelude::*, Entity}; use gpui::{prelude::*, Entity};
use crate::thread_store::{SavedThreadMetadata, ThreadStore}; use crate::thread_store::{SerializedThreadMetadata, ThreadStore};
pub enum HistoryEntry { pub enum HistoryEntry {
Thread(SavedThreadMetadata), Thread(SerializedThreadMetadata),
Context(SavedContextMetadata), Context(SavedContextMetadata),
} }

View file

@ -20,7 +20,8 @@ use ui::{
Tooltip, Tooltip,
}; };
use vim_mode_setting::VimModeSetting; use vim_mode_setting::VimModeSetting;
use workspace::Workspace; use workspace::notifications::{NotificationId, NotifyTaskExt};
use workspace::{Toast, Workspace};
use crate::assistant_model_selector::AssistantModelSelector; use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ConfirmBehavior, ContextPicker}; use crate::context_picker::{ConfirmBehavior, ContextPicker};
@ -34,6 +35,7 @@ use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
pub struct MessageEditor { pub struct MessageEditor {
thread: Entity<Thread>, thread: Entity<Thread>,
editor: Entity<Editor>, editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
context_strip: Entity<ContextStrip>, context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>, context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@ -106,6 +108,7 @@ impl MessageEditor {
Self { Self {
thread, thread,
editor: editor.clone(), editor: editor.clone(),
workspace,
context_store, context_store,
context_strip, context_strip,
context_picker_menu_handle, context_picker_menu_handle,
@ -280,6 +283,34 @@ impl MessageEditor {
self.context_strip.focus_handle(cx).focus(window); self.context_strip.focus_handle(cx).focus(window);
} }
} }
fn handle_feedback_click(
&mut self,
is_positive: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
let workspace = self.workspace.clone();
let report = self
.thread
.update(cx, |thread, cx| thread.report_feedback(is_positive, cx));
cx.spawn(|_, mut cx| async move {
report.await?;
workspace.update(&mut cx, |workspace, cx| {
let message = if is_positive {
"Positive feedback recorded. Thank you!"
} else {
"Negative feedback recorded. Thank you for helping us improve!"
};
struct ThreadFeedback;
let id = NotificationId::unique::<ThreadFeedback>();
workspace.show_toast(Toast::new(id, message).autohide(), cx)
})
})
.detach_and_notify_err(window, cx);
}
} }
impl Focusable for MessageEditor { impl Focusable for MessageEditor {
@ -497,7 +528,45 @@ impl Render for MessageEditor {
.bg(bg_color) .bg(bg_color)
.border_t_1() .border_t_1()
.border_color(cx.theme().colors().border) .border_color(cx.theme().colors().border)
.child(self.context_strip.clone()) .child(
h_flex()
.justify_between()
.child(self.context_strip.clone())
.when(!self.thread.read(cx).is_empty(), |this| {
this.child(
h_flex()
.gap_2()
.child(
IconButton::new(
"feedback-thumbs-up",
IconName::ThumbsUp,
)
.style(ButtonStyle::Subtle)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Helpful"))
.on_click(
cx.listener(|this, _, window, cx| {
this.handle_feedback_click(true, window, cx);
}),
),
)
.child(
IconButton::new(
"feedback-thumbs-down",
IconName::ThumbsDown,
)
.style(ButtonStyle::Subtle)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(
cx.listener(|this, _, window, cx| {
this.handle_feedback_click(false, window, cx);
}),
),
),
)
}),
)
.child( .child(
v_flex() v_flex()
.gap_5() .gap_5()

View file

@ -5,7 +5,9 @@ use anyhow::{Context as _, Result};
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _; use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task}; use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
@ -21,7 +23,9 @@ use util::{post_inc, ResultExt, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::SavedThread; use crate::thread_store::{
SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -63,6 +67,27 @@ pub struct Message {
pub text: String, pub text: String,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>,
pub unsaved_buffer_paths: Vec<String>,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
pub worktree_path: String,
pub git_state: Option<GitState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
pub remote_url: Option<String>,
pub head_sha: Option<String>,
pub current_branch: Option<String>,
pub diff: Option<String>,
}
/// A thread of conversation with the LLM. /// A thread of conversation with the LLM.
pub struct Thread { pub struct Thread {
id: ThreadId, id: ThreadId,
@ -81,6 +106,7 @@ pub struct Thread {
tool_use: ToolUseState, tool_use: ToolUseState,
scripting_session: Entity<ScriptingSession>, scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState, scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage, cumulative_token_usage: TokenUsage,
} }
@ -91,8 +117,6 @@ impl Thread {
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
Self { Self {
id: ThreadId::new(), id: ThreadId::new(),
updated_at: Utc::now(), updated_at: Utc::now(),
@ -104,43 +128,52 @@ impl Thread {
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project, project: project.clone(),
prompt_builder, prompt_builder,
tools, tools,
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(),
scripting_session, scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(), scripting_tool_use: ToolUseState::new(),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
cx.foreground_executor()
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
cumulative_token_usage: TokenUsage::default(), cumulative_token_usage: TokenUsage::default(),
} }
} }
pub fn from_saved( pub fn deserialize(
id: ThreadId, id: ThreadId,
saved: SavedThread, serialized: SerializedThread,
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let next_message_id = MessageId( let next_message_id = MessageId(
saved serialized
.messages .messages
.last() .last()
.map(|message| message.id.0 + 1) .map(|message| message.id.0 + 1)
.unwrap_or(0), .unwrap_or(0),
); );
let tool_use = let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME); name != ScriptingTool::NAME
});
let scripting_tool_use = let scripting_tool_use =
ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME); ToolUseState::from_serialized_messages(&serialized.messages, |name| {
name == ScriptingTool::NAME
});
let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
Self { Self {
id, id,
updated_at: saved.updated_at, updated_at: serialized.updated_at,
summary: Some(saved.summary), summary: Some(serialized.summary),
pending_summary: Task::ready(None), pending_summary: Task::ready(None),
messages: saved messages: serialized
.messages .messages
.into_iter() .into_iter()
.map(|message| Message { .map(|message| Message {
@ -160,6 +193,7 @@ impl Thread {
tool_use, tool_use,
scripting_session, scripting_session,
scripting_tool_use, scripting_tool_use,
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
// TODO: persist token usage? // TODO: persist token usage?
cumulative_token_usage: TokenUsage::default(), cumulative_token_usage: TokenUsage::default(),
} }
@ -349,6 +383,47 @@ impl Thread {
text text
} }
/// Serializes this thread into a format for storage or telemetry.
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
cx.spawn(|this, cx| async move {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(&cx, |this, _| SerializedThread {
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
.messages()
.map(|message| SerializedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: this
.tool_uses_for_message(message.id)
.into_iter()
.chain(this.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SerializedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
.collect(),
tool_results: this
.tool_results_for_message(message.id)
.into_iter()
.chain(this.scripting_tool_results_for_message(message.id))
.map(|tool_result| SerializedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
.collect(),
})
.collect(),
initial_project_snapshot,
})
})
}
pub fn send_to_model( pub fn send_to_model(
&mut self, &mut self,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
@ -807,6 +882,133 @@ impl Thread {
} }
} }
/// Reports feedback about the thread and stores it in our telemetry backend.
pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
let thread_data =
serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
let rating = if is_positive { "positive" } else { "negative" };
telemetry::event!(
"Assistant Thread Rated",
rating,
thread_id,
thread_data,
final_project_snapshot
);
client.telemetry().flush_events();
Ok(())
})
}
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Arc<ProjectSnapshot>> {
let worktree_snapshots: Vec<_> = project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| Self::worktree_snapshot(worktree, cx))
.collect();
cx.spawn(move |_, cx| async move {
let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
let mut unsaved_buffers = Vec::new();
cx.update(|app_cx| {
let buffer_store = project.read(app_cx).buffer_store();
for buffer_handle in buffer_store.read(app_cx).buffers() {
let buffer = buffer_handle.read(app_cx);
if buffer.is_dirty() {
if let Some(file) = buffer.file() {
let path = file.path().to_string_lossy().to_string();
unsaved_buffers.push(path);
}
}
}
})
.ok();
Arc::new(ProjectSnapshot {
worktree_snapshots,
unsaved_buffer_paths: unsaved_buffers,
timestamp: Utc::now(),
})
})
}
fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
cx.spawn(move |cx| async move {
// Get worktree path and snapshot
let worktree_info = cx.update(|app_cx| {
let worktree = worktree.read(app_cx);
let path = worktree.abs_path().to_string_lossy().to_string();
let snapshot = worktree.snapshot();
(path, snapshot)
});
let Ok((worktree_path, snapshot)) = worktree_info else {
return WorktreeSnapshot {
worktree_path: String::new(),
git_state: None,
};
};
// Extract git information
let git_state = match snapshot.repositories().first() {
None => None,
Some(repo_entry) => {
// Get branch information
let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
// Get repository info
let repo_result = worktree.read_with(&cx, |worktree, _cx| {
if let project::Worktree::Local(local_worktree) = &worktree {
local_worktree.get_local_repo(repo_entry).map(|local_repo| {
let repo = local_repo.repo();
(repo.remote_url("origin"), repo.head_sha(), repo.clone())
})
} else {
None
}
});
match repo_result {
Ok(Some((remote_url, head_sha, repository))) => {
// Get diff asynchronously
let diff = repository
.diff(git::repository::DiffType::HeadToWorktree, cx)
.await
.ok();
Some(GitState {
remote_url,
head_sha,
current_branch,
diff,
})
}
Err(_) | Ok(None) => None,
}
}
};
WorktreeSnapshot {
worktree_path,
git_state,
}
})
}
pub fn to_markdown(&self) -> Result<String> { pub fn to_markdown(&self) -> Result<String> {
let mut markdown = Vec::new(); let mut markdown = Vec::new();

View file

@ -7,7 +7,7 @@ use time::{OffsetDateTime, UtcOffset};
use ui::{prelude::*, IconButtonShape, ListItem, ListItemSpacing, Tooltip}; use ui::{prelude::*, IconButtonShape, ListItem, ListItemSpacing, Tooltip};
use crate::history_store::{HistoryEntry, HistoryStore}; use crate::history_store::{HistoryEntry, HistoryStore};
use crate::thread_store::SavedThreadMetadata; use crate::thread_store::SerializedThreadMetadata;
use crate::{AssistantPanel, RemoveSelectedThread}; use crate::{AssistantPanel, RemoveSelectedThread};
pub struct ThreadHistory { pub struct ThreadHistory {
@ -221,14 +221,14 @@ impl Render for ThreadHistory {
#[derive(IntoElement)] #[derive(IntoElement)]
pub struct PastThread { pub struct PastThread {
thread: SavedThreadMetadata, thread: SerializedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>, assistant_panel: WeakEntity<AssistantPanel>,
selected: bool, selected: bool,
} }
impl PastThread { impl PastThread {
pub fn new( pub fn new(
thread: SavedThreadMetadata, thread: SerializedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>, assistant_panel: WeakEntity<AssistantPanel>,
selected: bool, selected: bool,
) -> Self { ) -> Self {

View file

@ -20,7 +20,7 @@ use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::ResultExt as _; use util::ResultExt as _;
use crate::thread::{MessageId, Thread, ThreadId}; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
pub fn init(cx: &mut App) { pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx); ThreadsDatabase::init(cx);
@ -32,7 +32,7 @@ pub struct ThreadStore {
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>, context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SavedThreadMetadata>, threads: Vec<SerializedThreadMetadata>,
} }
impl ThreadStore { impl ThreadStore {
@ -70,13 +70,13 @@ impl ThreadStore {
self.threads.len() self.threads.len()
} }
pub fn threads(&self) -> Vec<SavedThreadMetadata> { pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>(); let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at)); threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads threads
} }
pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> { pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
self.threads().into_iter().take(limit).collect() self.threads().into_iter().take(limit).collect()
} }
@ -107,7 +107,7 @@ impl ThreadStore {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
cx.new(|cx| { cx.new(|cx| {
Thread::from_saved( Thread::deserialize(
id.clone(), id.clone(),
thread, thread,
this.project.clone(), this.project.clone(),
@ -121,53 +121,14 @@ impl ThreadStore {
} }
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
let (metadata, thread) = thread.update(cx, |thread, _cx| { let (metadata, serialized_thread) =
let id = thread.id().clone(); thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
let thread = SavedThread {
summary: thread.summary_or_default(),
updated_at: thread.updated_at(),
messages: thread
.messages()
.map(|message| {
let all_tool_uses = thread
.tool_uses_for_message(message.id)
.into_iter()
.chain(thread.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SavedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
.collect();
let all_tool_results = thread
.tool_results_for_message(message.id)
.into_iter()
.chain(thread.scripting_tool_results_for_message(message.id))
.map(|tool_result| SavedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
.collect();
SavedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: all_tool_uses,
tool_results: all_tool_results,
}
})
.collect(),
};
(id, thread)
});
let database_future = ThreadsDatabase::global_future(cx); let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let serialized_thread = serialized_thread.await?;
let database = database_future.await.map_err(|err| anyhow!(err))?; let database = database_future.await.map_err(|err| anyhow!(err))?;
database.save_thread(metadata, thread).await?; database.save_thread(metadata, serialized_thread).await?;
this.update(&mut cx, |this, cx| this.reload(cx))?.await this.update(&mut cx, |this, cx| this.reload(cx))?.await
}) })
@ -270,39 +231,41 @@ impl ThreadStore {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata { pub struct SerializedThreadMetadata {
pub id: ThreadId, pub id: ThreadId,
pub summary: SharedString, pub summary: SharedString,
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct SavedThread { pub struct SerializedThread {
pub summary: SharedString, pub summary: SharedString,
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
pub messages: Vec<SavedMessage>, pub messages: Vec<SerializedMessage>,
#[serde(default)]
pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage { pub struct SerializedMessage {
pub id: MessageId, pub id: MessageId,
pub role: Role, pub role: Role,
pub text: String, pub text: String,
#[serde(default)] #[serde(default)]
pub tool_uses: Vec<SavedToolUse>, pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)] #[serde(default)]
pub tool_results: Vec<SavedToolResult>, pub tool_results: Vec<SerializedToolResult>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse { pub struct SerializedToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
pub name: SharedString, pub name: SharedString,
pub input: serde_json::Value, pub input: serde_json::Value,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult { pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool, pub is_error: bool,
pub content: Arc<str>, pub content: Arc<str>,
@ -317,7 +280,7 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase { pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor, executor: BackgroundExecutor,
env: heed::Env, env: heed::Env,
threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>, threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
} }
impl ThreadsDatabase { impl ThreadsDatabase {
@ -364,7 +327,7 @@ impl ThreadsDatabase {
}) })
} }
pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> { pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
let env = self.env.clone(); let env = self.env.clone();
let threads = self.threads; let threads = self.threads;
@ -373,7 +336,7 @@ impl ThreadsDatabase {
let mut iter = threads.iter(&txn)?; let mut iter = threads.iter(&txn)?;
let mut threads = Vec::new(); let mut threads = Vec::new();
while let Some((key, value)) = iter.next().transpose()? { while let Some((key, value)) = iter.next().transpose()? {
threads.push(SavedThreadMetadata { threads.push(SerializedThreadMetadata {
id: key, id: key,
summary: value.summary, summary: value.summary,
updated_at: value.updated_at, updated_at: value.updated_at,
@ -384,7 +347,7 @@ impl ThreadsDatabase {
}) })
} }
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> { pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
let env = self.env.clone(); let env = self.env.clone();
let threads = self.threads; let threads = self.threads;
@ -395,7 +358,7 @@ impl ThreadsDatabase {
}) })
} }
pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> { pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
let env = self.env.clone(); let env = self.env.clone();
let threads = self.threads; let threads = self.threads;

View file

@ -11,7 +11,7 @@ use language_model::{
}; };
use crate::thread::MessageId; use crate::thread::MessageId;
use crate::thread_store::SavedMessage; use crate::thread_store::SerializedMessage;
#[derive(Debug)] #[derive(Debug)]
pub struct ToolUse { pub struct ToolUse {
@ -46,11 +46,11 @@ impl ToolUseState {
} }
} }
/// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s. /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
/// ///
/// Accepts a function to filter the tools that should be used to populate the state. /// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_saved_messages( pub fn from_serialized_messages(
messages: &[SavedMessage], messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool, mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self { ) -> Self {
let mut this = Self::new(); let mut this = Self::new();

View file

@ -660,6 +660,10 @@ fn for_snowflake(
e.event_type.clone(), e.event_type.clone(),
serde_json::to_value(&e.event_properties).unwrap(), serde_json::to_value(&e.event_properties).unwrap(),
), ),
Event::AssistantThreadFeedback(e) => (
"Assistant Feedback".to_string(),
serde_json::to_value(&e).unwrap(),
),
}; };
if let serde_json::Value::Object(ref mut map) = event_properties { if let serde_json::Value::Object(ref mut map) = event_properties {

View file

@ -1046,7 +1046,7 @@ impl App {
&self.foreground_executor &self.foreground_executor
} }
/// Spawns the future returned by the given function on the thread pool. The closure will be invoked /// Spawns the future returned by the given function on the main thread. The closure will be invoked
/// with [AsyncApp], which allows the application state to be accessed across await points. /// with [AsyncApp], which allows the application state to be accessed across await points.
#[track_caller] #[track_caller]
pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncApp) -> Fut) -> Task<R> pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncApp) -> Fut) -> Task<R>

View file

@ -97,6 +97,7 @@ pub enum Event {
InlineCompletionRating(InlineCompletionRatingEvent), InlineCompletionRating(InlineCompletionRatingEvent),
Call(CallEvent), Call(CallEvent),
Assistant(AssistantEvent), Assistant(AssistantEvent),
AssistantThreadFeedback(AssistantThreadFeedbackEvent),
Cpu(CpuEvent), Cpu(CpuEvent),
Memory(MemoryEvent), Memory(MemoryEvent),
App(AppEvent), App(AppEvent),
@ -230,6 +231,26 @@ pub struct ReplEvent {
pub repl_session_id: String, pub repl_session_id: String,
} }
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ThreadFeedbackRating {
Positive,
Negative,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AssistantThreadFeedbackEvent {
/// Unique identifier for the thread
pub thread_id: String,
/// The feedback rating (thumbs up or thumbs down)
pub rating: ThreadFeedbackRating,
/// The serialized thread data containing messages, tool calls, etc.
pub thread_data: serde_json::Value,
/// The initial project snapshot taken when the thread was created
pub initial_project_snapshot: serde_json::Value,
/// The final project snapshot taken when the thread was first saved
pub final_project_snapshot: serde_json::Value,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BacktraceFrame { pub struct BacktraceFrame {
pub ip: usize, pub ip: usize,