use std::io::Write; use std::ops::Range; use std::sync::Arc; use std::time::Instant; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use editor::display_map::CreaseMetadata; use feature_flags::{self, FeatureFlagAppExt}; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; use git::repository::DiffType; use gpui::{ AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::Project; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; use prompt_store::{ModelContext, PromptBuilder}; use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use thiserror::Error; use ui::Window; use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; use crate::ThreadStore; use crate::agent_profile::AgentProfile; use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext, }; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}; #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] pub struct ThreadId(Arc); impl ThreadId { pub fn new() -> Self { Self(Uuid::new_v4().to_string().into()) } } impl std::fmt::Display for ThreadId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } impl From<&str> for ThreadId { fn from(value: &str) -> Self { Self(value.into()) } } /// The ID of the user prompt that initiated a request. /// /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] pub struct PromptId(Arc); impl PromptId { pub fn new() -> Self { Self(Uuid::new_v4().to_string().into()) } } impl std::fmt::Display for PromptId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] pub struct MessageId(pub(crate) usize); impl MessageId { fn post_inc(&mut self) -> Self { Self(post_inc(&mut self.0)) } } /// Stored information that can be used to resurrect a context crease when creating an editor for a past message. #[derive(Clone, Debug)] pub struct MessageCrease { pub range: Range, pub metadata: CreaseMetadata, /// None for a deserialized message, Some otherwise. pub context: Option, } /// A message in a [`Thread`]. 
#[derive(Debug, Clone)] pub struct Message { pub id: MessageId, pub role: Role, pub segments: Vec, pub loaded_context: LoadedContext, pub creases: Vec, pub is_hidden: bool, } impl Message { /// Returns whether the message contains any meaningful text that should be displayed /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`]) pub fn should_display_content(&self) -> bool { self.segments.iter().all(|segment| segment.should_display()) } pub fn push_thinking(&mut self, text: &str, signature: Option) { if let Some(MessageSegment::Thinking { text: segment, signature: current_signature, }) = self.segments.last_mut() { if let Some(signature) = signature { *current_signature = Some(signature); } segment.push_str(text); } else { self.segments.push(MessageSegment::Thinking { text: text.to_string(), signature, }); } } pub fn push_text(&mut self, text: &str) { if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() { segment.push_str(text); } else { self.segments.push(MessageSegment::Text(text.to_string())); } } pub fn to_string(&self) -> String { let mut result = String::new(); if !self.loaded_context.text.is_empty() { result.push_str(&self.loaded_context.text); } for segment in &self.segments { match segment { MessageSegment::Text(text) => result.push_str(text), MessageSegment::Thinking { text, .. } => { result.push_str("\n"); result.push_str(text); result.push_str("\n"); } MessageSegment::RedactedThinking(_) => {} } } result } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum MessageSegment { Text(String), Thinking { text: String, signature: Option, }, RedactedThinking(Vec), } impl MessageSegment { pub fn should_display(&self) -> bool { match self { Self::Text(text) => text.is_empty(), Self::Thinking { text, .. } => text.is_empty(), Self::RedactedThinking(_) => false, } } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProjectSnapshot { pub worktree_snapshots: Vec, pub unsaved_buffer_paths: Vec, pub timestamp: DateTime, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct WorktreeSnapshot { pub worktree_path: String, pub git_state: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct GitState { pub remote_url: Option, pub head_sha: Option, pub current_branch: Option, pub diff: Option, } #[derive(Clone, Debug)] pub struct ThreadCheckpoint { message_id: MessageId, git_checkpoint: GitStoreCheckpoint, } #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ThreadFeedback { Positive, Negative, } pub enum LastRestoreCheckpoint { Pending { message_id: MessageId, }, Error { message_id: MessageId, error: String, }, } impl LastRestoreCheckpoint { pub fn message_id(&self) -> MessageId { match self { LastRestoreCheckpoint::Pending { message_id } => *message_id, LastRestoreCheckpoint::Error { message_id, .. } => *message_id, } } } #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub enum DetailedSummaryState { #[default] NotGenerated, Generating { message_id: MessageId, }, Generated { text: SharedString, message_id: MessageId, }, } impl DetailedSummaryState { fn text(&self) -> Option { if let Self::Generated { text, .. 
} = self { Some(text.clone()) } else { None } } } #[derive(Default, Debug)] pub struct TotalTokenUsage { pub total: u64, pub max: u64, } impl TotalTokenUsage { pub fn ratio(&self) -> TokenUsageRatio { #[cfg(debug_assertions)] let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") .unwrap_or("0.8".to_string()) .parse() .unwrap(); #[cfg(not(debug_assertions))] let warning_threshold: f32 = 0.8; // When the maximum is unknown because there is no selected model, // avoid showing the token limit warning. if self.max == 0 { TokenUsageRatio::Normal } else if self.total >= self.max { TokenUsageRatio::Exceeded } else if self.total as f32 / self.max as f32 >= warning_threshold { TokenUsageRatio::Warning } else { TokenUsageRatio::Normal } } pub fn add(&self, tokens: u64) -> TotalTokenUsage { TotalTokenUsage { total: self.total + tokens, max: self.max, } } } #[derive(Debug, Default, PartialEq, Eq)] pub enum TokenUsageRatio { #[default] Normal, Warning, Exceeded, } #[derive(Debug, Clone, Copy)] pub enum QueueState { Sending, Queued { position: usize }, Started, } /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, updated_at: DateTime, summary: ThreadSummary, pending_summary: Task>, detailed_summary_task: Task>, detailed_summary_tx: postage::watch::Sender, detailed_summary_rx: postage::watch::Receiver, completion_mode: agent_settings::CompletionMode, messages: Vec, next_message_id: MessageId, last_prompt_id: PromptId, project_context: SharedProjectContext, checkpoints_by_message: HashMap, completion_count: usize, pending_completions: Vec, project: Entity, prompt_builder: Arc, tools: Entity, tool_use: ToolUseState, action_log: Entity, last_restore_checkpoint: Option, pending_checkpoint: Option, initial_project_snapshot: Shared>>>, request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, last_usage: Option, tool_use_limit_reached: bool, feedback: Option, message_feedback: HashMap, last_auto_capture_at: Option, last_received_chunk_at: Option, request_callback: Option< Box])>, >, remaining_turns: u32, configured_model: Option, profile: AgentProfile, } #[derive(Clone, Debug, PartialEq, Eq)] pub enum ThreadSummary { Pending, Generating, Ready(SharedString), Error, } impl ThreadSummary { pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); pub fn or_default(&self) -> SharedString { self.unwrap_or(Self::DEFAULT) } pub fn unwrap_or(&self, message: impl Into) -> SharedString { self.ready().unwrap_or_else(|| message.into()) } pub fn ready(&self) -> Option { match self { ThreadSummary::Ready(summary) => Some(summary.clone()), ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, } } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ExceededWindowError { /// Model used when last message exceeded context window model_id: LanguageModelId, /// Token count including last message token_count: u64, } impl Thread { pub fn new( project: Entity, tools: Entity, prompt_builder: Arc, system_prompt: SharedProjectContext, cx: &mut Context, ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { id: ThreadId::new(), updated_at: Utc::now(), summary: ThreadSummary::Pending, pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, detailed_summary_rx, 
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, messages: Vec::new(), next_message_id: MessageId(0), last_prompt_id: PromptId::new(), project_context: system_prompt, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), project: project.clone(), prompt_builder, tools: tools.clone(), last_restore_checkpoint: None, pending_checkpoint: None, tool_use: ToolUseState::new(tools.clone()), action_log: cx.new(|_| ActionLog::new(project.clone())), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project, cx); cx.foreground_executor() .spawn(async move { Some(project_snapshot.await) }) .shared() }, request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, last_usage: None, tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, configured_model, profile: AgentProfile::new(profile_id, tools), } } pub fn deserialize( id: ThreadId, serialized: SerializedThread, project: Entity, tools: Entity, prompt_builder: Arc, project_context: SharedProjectContext, window: Option<&mut Window>, // None in headless mode cx: &mut Context, ) -> Self { let next_message_id = MessageId( serialized .messages .last() .map(|message| message.id.0 + 1) .unwrap_or(0), ); let tool_use = ToolUseState::from_serialized_messages( tools.clone(), &serialized.messages, project.clone(), window, cx, ); let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel_with(serialized.detailed_summary_state); let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { serialized .model .and_then(|model| { let model = SelectedModel { provider: model.provider.clone().into(), model: model.model.clone().into(), }; registry.select_model(&model, cx) }) .or_else(|| registry.default_model()) }); let completion_mode = serialized .completion_mode .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); let profile_id = serialized .profile .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); Self { id, updated_at: serialized.updated_at, summary: ThreadSummary::Ready(serialized.summary), pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, detailed_summary_rx, completion_mode, messages: serialized .messages .into_iter() .map(|message| Message { id: message.id, role: message.role, segments: message .segments .into_iter() .map(|segment| match segment { SerializedMessageSegment::Text { text } => MessageSegment::Text(text), SerializedMessageSegment::Thinking { text, signature } => { MessageSegment::Thinking { text, signature } } SerializedMessageSegment::RedactedThinking { data } => { MessageSegment::RedactedThinking(data) } }) .collect(), loaded_context: LoadedContext { contexts: Vec::new(), text: message.context, images: Vec::new(), }, creases: message .creases .into_iter() .map(|crease| MessageCrease { range: crease.start..crease.end, metadata: CreaseMetadata { icon_path: crease.icon_path, label: crease.label, }, context: None, }) .collect(), is_hidden: message.is_hidden, }) .collect(), next_message_id, last_prompt_id: PromptId::new(), project_context, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), last_restore_checkpoint: None, pending_checkpoint: None, project: project.clone(), prompt_builder, tools: 
tools.clone(), tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, last_usage: None, tool_use_limit_reached: serialized.tool_use_limit_reached, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, configured_model, profile: AgentProfile::new(profile_id, tools), } } pub fn set_request_callback( &mut self, callback: impl 'static + FnMut(&LanguageModelRequest, &[Result]), ) { self.request_callback = Some(Box::new(callback)); } pub fn id(&self) -> &ThreadId { &self.id } pub fn profile(&self) -> &AgentProfile { &self.profile } pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { if &id != self.profile.id() { self.profile = AgentProfile::new(id, self.tools.clone()); cx.emit(ThreadEvent::ProfileChanged); } } pub fn is_empty(&self) -> bool { self.messages.is_empty() } pub fn updated_at(&self) -> DateTime { self.updated_at } pub fn touch_updated_at(&mut self) { self.updated_at = Utc::now(); } pub fn advance_prompt_id(&mut self) { self.last_prompt_id = PromptId::new(); } pub fn project_context(&self) -> SharedProjectContext { self.project_context.clone() } pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option { if self.configured_model.is_none() { self.configured_model = LanguageModelRegistry::read_global(cx).default_model(); } self.configured_model.clone() } pub fn configured_model(&self) -> Option { self.configured_model.clone() } pub fn set_configured_model(&mut self, model: Option, cx: &mut Context) { self.configured_model = model; cx.notify(); } pub fn summary(&self) -> &ThreadSummary { &self.summary } pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { let current_summary = match &self.summary { ThreadSummary::Pending | ThreadSummary::Generating => return, ThreadSummary::Ready(summary) => summary, ThreadSummary::Error => &ThreadSummary::DEFAULT, }; let mut new_summary = new_summary.into(); if new_summary.is_empty() { new_summary = ThreadSummary::DEFAULT; } if current_summary != &new_summary { self.summary = ThreadSummary::Ready(new_summary); cx.emit(ThreadEvent::SummaryChanged); } } pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } pub fn set_completion_mode(&mut self, mode: CompletionMode) { self.completion_mode = mode; } pub fn message(&self, id: MessageId) -> Option<&Message> { let index = self .messages .binary_search_by(|message| message.id.cmp(&id)) .ok()?; self.messages.get(index) } pub fn messages(&self) -> impl ExactSizeIterator { self.messages.iter() } pub fn is_generating(&self) -> bool { !self.pending_completions.is_empty() || !self.all_tools_finished() } /// Indicates whether streaming of language model events is stale. /// When `is_generating()` is false, this method returns `None`. 
pub fn is_generation_stale(&self) -> Option { const STALE_THRESHOLD: u128 = 250; self.last_received_chunk_at .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD) } fn received_chunk(&mut self) { self.last_received_chunk_at = Some(Instant::now()); } pub fn queue_state(&self) -> Option { self.pending_completions .first() .map(|pending_completion| pending_completion.queue_state) } pub fn tools(&self) -> &Entity { &self.tools } pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> { self.tool_use .pending_tool_uses() .into_iter() .find(|tool_use| &tool_use.id == id) } pub fn tools_needing_confirmation(&self) -> impl Iterator { self.tool_use .pending_tool_uses() .into_iter() .filter(|tool_use| tool_use.status.needs_confirmation()) } pub fn has_pending_tool_uses(&self) -> bool { !self.tool_use.pending_tool_uses().is_empty() } pub fn checkpoint_for_message(&self, id: MessageId) -> Option { self.checkpoints_by_message.get(&id).cloned() } pub fn restore_checkpoint( &mut self, checkpoint: ThreadCheckpoint, cx: &mut Context, ) -> Task> { self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending { message_id: checkpoint.message_id, }); cx.emit(ThreadEvent::CheckpointChanged); cx.notify(); let git_store = self.project().read(cx).git_store().clone(); let restore = git_store.update(cx, |git_store, cx| { git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx) }); cx.spawn(async move |this, cx| { let result = restore.await; this.update(cx, |this, cx| { if let Err(err) = result.as_ref() { this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error { message_id: checkpoint.message_id, error: err.to_string(), }); } else { this.truncate(checkpoint.message_id, cx); this.last_restore_checkpoint = None; } this.pending_checkpoint = None; cx.emit(ThreadEvent::CheckpointChanged); cx.notify(); })?; result }) } fn finalize_pending_checkpoint(&mut self, cx: &mut Context) { let pending_checkpoint = if self.is_generating() { return; } else if let Some(checkpoint) = self.pending_checkpoint.take() { checkpoint } else { return; }; self.finalize_checkpoint(pending_checkpoint, cx); } fn finalize_checkpoint( &mut self, pending_checkpoint: ThreadCheckpoint, cx: &mut Context, ) { let git_store = self.project.read(cx).git_store().clone(); let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); cx.spawn(async move |this, cx| match final_checkpoint.await { Ok(final_checkpoint) => { let equal = git_store .update(cx, |store, cx| { store.compare_checkpoints( pending_checkpoint.git_checkpoint.clone(), final_checkpoint.clone(), cx, ) })? .await .unwrap_or(false); if !equal { this.update(cx, |this, cx| { this.insert_checkpoint(pending_checkpoint, cx) })?; } Ok(()) } Err(_) => this.update(cx, |this, cx| { this.insert_checkpoint(pending_checkpoint, cx) }), }) .detach(); } fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context) { self.checkpoints_by_message .insert(checkpoint.message_id, checkpoint); cx.emit(ThreadEvent::CheckpointChanged); cx.notify(); } pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { self.last_restore_checkpoint.as_ref() } pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context) { let Some(message_ix) = self .messages .iter() .rposition(|message| message.id == message_id) else { return; }; for deleted_message in self.messages.drain(message_ix..) 
{ self.checkpoints_by_message.remove(&deleted_message.id); } cx.notify(); } pub fn context_for_message(&self, id: MessageId) -> impl Iterator { self.messages .iter() .find(|message| message.id == id) .into_iter() .flat_map(|message| message.loaded_context.contexts.iter()) } pub fn is_turn_end(&self, ix: usize) -> bool { if self.messages.is_empty() { return false; } if !self.is_generating() && ix == self.messages.len() - 1 { return true; } let Some(message) = self.messages.get(ix) else { return false; }; if message.role != Role::Assistant { return false; } self.messages .get(ix + 1) .and_then(|message| { self.message(message.id) .map(|next_message| next_message.role == Role::User && !next_message.is_hidden) }) .unwrap_or(false) } pub fn last_usage(&self) -> Option { self.last_usage } pub fn tool_use_limit_reached(&self) -> bool { self.tool_use_limit_reached } /// Returns whether all of the tool uses have finished running. pub fn all_tools_finished(&self) -> bool { // If the only pending tool uses left are the ones with errors, then // that means that we've finished running all of the pending tools. self.tool_use .pending_tool_uses() .iter() .all(|pending_tool_use| pending_tool_use.status.is_error()) } /// Returns whether any pending tool uses may perform edits pub fn has_pending_edit_tool_uses(&self) -> bool { self.tool_use .pending_tool_uses() .iter() .filter(|pending_tool_use| !pending_tool_use.status.is_error()) .any(|pending_tool_use| pending_tool_use.may_perform_edits) } pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { self.tool_use.tool_uses_for_message(id, cx) } pub fn tool_results_for_message( &self, assistant_message_id: MessageId, ) -> Vec<&LanguageModelToolResult> { self.tool_use.tool_results_for_message(assistant_message_id) } pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> { self.tool_use.tool_result(id) } pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { match &self.tool_use.tool_result(id)?.content { LanguageModelToolResultContent::Text(text) => Some(text), LanguageModelToolResultContent::Image(_) => { // TODO: We should display image None } } } pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option { self.tool_use.tool_result_card(id).cloned() } /// Return tools that are both enabled and supported by the model pub fn available_tools( &self, cx: &App, model: Arc, ) -> Vec { if model.supports_tools() { self.profile .enabled_tools(cx) .into_iter() .filter_map(|tool| { // Skip tools that cannot be supported let input_schema = tool.input_schema(model.tool_input_format()).ok()?; Some(LanguageModelRequestTool { name: tool.name(), description: tool.description(), input_schema, }) }) .collect() } else { Vec::default() } } pub fn insert_user_message( &mut self, text: impl Into, loaded_context: ContextLoadResult, git_checkpoint: Option, creases: Vec, cx: &mut Context, ) -> MessageId { if !loaded_context.referenced_buffers.is_empty() { self.action_log.update(cx, |log, cx| { for buffer in loaded_context.referenced_buffers { log.buffer_read(buffer, cx); } }); } let message_id = self.insert_message( Role::User, vec![MessageSegment::Text(text.into())], loaded_context.loaded_context, creases, false, cx, ); if let Some(git_checkpoint) = git_checkpoint { self.pending_checkpoint = Some(ThreadCheckpoint { message_id, git_checkpoint, }); } self.auto_capture_telemetry(cx); message_id } pub fn insert_invisible_continue_message(&mut self, cx: &mut Context) -> MessageId { let id = 
self.insert_message( Role::User, vec![MessageSegment::Text("Continue where you left off".into())], LoadedContext::default(), vec![], true, cx, ); self.pending_checkpoint = None; id } pub fn insert_assistant_message( &mut self, segments: Vec, cx: &mut Context, ) -> MessageId { self.insert_message( Role::Assistant, segments, LoadedContext::default(), Vec::new(), false, cx, ) } pub fn insert_message( &mut self, role: Role, segments: Vec, loaded_context: LoadedContext, creases: Vec, is_hidden: bool, cx: &mut Context, ) -> MessageId { let id = self.next_message_id.post_inc(); self.messages.push(Message { id, role, segments, loaded_context, creases, is_hidden, }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); id } pub fn edit_message( &mut self, id: MessageId, new_role: Role, new_segments: Vec, creases: Vec, loaded_context: Option, checkpoint: Option, cx: &mut Context, ) -> bool { let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { return false; }; message.role = new_role; message.segments = new_segments; message.creases = creases; if let Some(context) = loaded_context { message.loaded_context = context; } if let Some(git_checkpoint) = checkpoint { self.checkpoints_by_message.insert( id, ThreadCheckpoint { message_id: id, git_checkpoint, }, ); } self.touch_updated_at(); cx.emit(ThreadEvent::MessageEdited(id)); true } pub fn delete_message(&mut self, id: MessageId, cx: &mut Context) -> bool { let Some(index) = self.messages.iter().position(|message| message.id == id) else { return false; }; self.messages.remove(index); self.touch_updated_at(); cx.emit(ThreadEvent::MessageDeleted(id)); true } /// Returns the representation of this [`Thread`] in a textual form. /// /// This is the representation we use when attaching a thread as context to another thread. pub fn text(&self) -> String { let mut text = String::new(); for message in &self.messages { text.push_str(match message.role { language_model::Role::User => "User:", language_model::Role::Assistant => "Agent:", language_model::Role::System => "System:", }); text.push('\n'); for segment in &message.segments { match segment { MessageSegment::Text(content) => text.push_str(content), MessageSegment::Thinking { text: content, .. } => { text.push_str(&format!("{}", content)) } MessageSegment::RedactedThinking(_) => {} } } text.push('\n'); } text } /// Serializes this thread into a format for storage or telemetry. 
pub fn serialize(&self, cx: &mut Context) -> Task> { let initial_project_snapshot = self.initial_project_snapshot.clone(); cx.spawn(async move |this, cx| { let initial_project_snapshot = initial_project_snapshot.await; this.read_with(cx, |this, cx| SerializedThread { version: SerializedThread::VERSION.to_string(), summary: this.summary().or_default(), updated_at: this.updated_at(), messages: this .messages() .map(|message| SerializedMessage { id: message.id, role: message.role, segments: message .segments .iter() .map(|segment| match segment { MessageSegment::Text(text) => { SerializedMessageSegment::Text { text: text.clone() } } MessageSegment::Thinking { text, signature } => { SerializedMessageSegment::Thinking { text: text.clone(), signature: signature.clone(), } } MessageSegment::RedactedThinking(data) => { SerializedMessageSegment::RedactedThinking { data: data.clone(), } } }) .collect(), tool_uses: this .tool_uses_for_message(message.id, cx) .into_iter() .map(|tool_use| SerializedToolUse { id: tool_use.id, name: tool_use.name, input: tool_use.input, }) .collect(), tool_results: this .tool_results_for_message(message.id) .into_iter() .map(|tool_result| SerializedToolResult { tool_use_id: tool_result.tool_use_id.clone(), is_error: tool_result.is_error, content: tool_result.content.clone(), output: tool_result.output.clone(), }) .collect(), context: message.loaded_context.text.clone(), creases: message .creases .iter() .map(|crease| SerializedCrease { start: crease.range.start, end: crease.range.end, icon_path: crease.metadata.icon_path.clone(), label: crease.metadata.label.clone(), }) .collect(), is_hidden: message.is_hidden, }) .collect(), initial_project_snapshot, cumulative_token_usage: this.cumulative_token_usage, request_token_usage: this.request_token_usage.clone(), detailed_summary_state: this.detailed_summary_rx.borrow().clone(), exceeded_window_error: this.exceeded_window_error.clone(), model: this .configured_model .as_ref() .map(|model| SerializedLanguageModel { provider: model.provider.id().0.to_string(), model: model.model.id().0.to_string(), }), completion_mode: Some(this.completion_mode), tool_use_limit_reached: this.tool_use_limit_reached, profile: Some(this.profile.id().clone()), }) }) } pub fn remaining_turns(&self) -> u32 { self.remaining_turns } pub fn set_remaining_turns(&mut self, remaining_turns: u32) { self.remaining_turns = remaining_turns; } pub fn send_to_model( &mut self, model: Arc, intent: CompletionIntent, window: Option, cx: &mut Context, ) { if self.remaining_turns == 0 { return; } self.remaining_turns -= 1; let request = self.to_completion_request(model.clone(), intent, cx); self.stream_completion(request, model, window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { for message in self.messages.iter().rev() { if self.tool_use.message_has_tool_results(message.id) { return true; } else if message.role == Role::User { return false; } } false } pub fn to_completion_request( &self, model: Arc, intent: CompletionIntent, cx: &mut Context, ) -> LanguageModelRequest { let mut request = LanguageModelRequest { thread_id: Some(self.id.to_string()), prompt_id: Some(self.last_prompt_id.to_string()), intent: Some(intent), mode: None, messages: vec![], tools: Vec::new(), tool_choice: None, stop: Vec::new(), temperature: AgentSettings::temperature_for_model(&model, cx), }; let available_tools = self.available_tools(cx, model.clone()); let available_tool_names = available_tools .iter() .map(|tool| tool.name.clone()) .collect(); let model_context = 
&ModelContext { available_tools: available_tool_names, }; if let Some(project_context) = self.project_context.borrow().as_ref() { match self .prompt_builder .generate_assistant_system_prompt(project_context, model_context) { Err(err) => { let message = format!("{err:?}").into(); log::error!("{message}"); cx.emit(ThreadEvent::ShowError(ThreadError::Message { header: "Error generating system prompt".into(), message, })); } Ok(system_prompt) => { request.messages.push(LanguageModelRequestMessage { role: Role::System, content: vec![MessageContent::Text(system_prompt)], cache: true, }); } } } else { let message = "Context for system prompt unexpectedly not ready.".into(); log::error!("{message}"); cx.emit(ThreadEvent::ShowError(ThreadError::Message { header: "Error generating system prompt".into(), message, })); } let mut message_ix_to_cache = None; for message in &self.messages { let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), cache: false, }; message .loaded_context .add_to_request_message(&mut request_message); for segment in &message.segments { match segment { MessageSegment::Text(text) => { if !text.is_empty() { request_message .content .push(MessageContent::Text(text.into())); } } MessageSegment::Thinking { text, signature } => { if !text.is_empty() { request_message.content.push(MessageContent::Thinking { text: text.into(), signature: signature.clone(), }); } } MessageSegment::RedactedThinking(data) => { request_message .content .push(MessageContent::RedactedThinking(data.clone())); } }; } let mut cache_message = true; let mut tool_results_message = LanguageModelRequestMessage { role: Role::User, content: Vec::new(), cache: false, }; for (tool_use, tool_result) in self.tool_use.tool_results(message.id) { if let Some(tool_result) = tool_result { request_message .content .push(MessageContent::ToolUse(tool_use.clone())); tool_results_message .content .push(MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_use.id.clone(), tool_name: tool_result.tool_name.clone(), is_error: tool_result.is_error, content: if tool_result.content.is_empty() { // Surprisingly, the API fails if we return an empty string here. // It thinks we are sending a tool use without a tool result. 
"".into() } else { tool_result.content.clone() }, output: None, })); } else { cache_message = false; log::debug!( "skipped tool use {:?} because it is still pending", tool_use ); } } if cache_message { message_ix_to_cache = Some(request.messages.len()); } request.messages.push(request_message); if !tool_results_message.content.is_empty() { if cache_message { message_ix_to_cache = Some(request.messages.len()); } request.messages.push(tool_results_message); } } // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching if let Some(message_ix_to_cache) = message_ix_to_cache { request.messages[message_ix_to_cache].cache = true; } request.tools = available_tools; request.mode = if model.supports_max_mode() { Some(self.completion_mode.into()) } else { Some(CompletionMode::Normal.into()) }; request } fn to_summarize_request( &self, model: &Arc, intent: CompletionIntent, added_user_message: String, cx: &App, ) -> LanguageModelRequest { let mut request = LanguageModelRequest { thread_id: None, prompt_id: None, intent: Some(intent), mode: None, messages: vec![], tools: Vec::new(), tool_choice: None, stop: Vec::new(), temperature: AgentSettings::temperature_for_model(model, cx), }; for message in &self.messages { let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), cache: false, }; for segment in &message.segments { match segment { MessageSegment::Text(text) => request_message .content .push(MessageContent::Text(text.clone())), MessageSegment::Thinking { .. } => {} MessageSegment::RedactedThinking(_) => {} } } if request_message.content.is_empty() { continue; } request.messages.push(request_message); } request.messages.push(LanguageModelRequestMessage { role: Role::User, content: vec![MessageContent::Text(added_user_message)], cache: false, }); request } pub fn stream_completion( &mut self, request: LanguageModelRequest, model: Arc, window: Option, cx: &mut Context, ) { self.tool_use_limit_reached = false; let pending_completion_id = post_inc(&mut self.completion_count); let mut request_callback_parameters = if self.request_callback.is_some() { Some((request.clone(), Vec::new())) } else { None }; let prompt_id = self.last_prompt_id.clone(); let tool_use_metadata = ToolUseMetadata { model: model.clone(), thread_id: self.id.clone(), prompt_id: prompt_id.clone(), }; self.last_received_chunk_at = Some(Instant::now()); let task = cx.spawn(async move |thread, cx| { let stream_completion_future = model.stream_completion(request, &cx); let initial_token_usage = thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { let mut events = stream_completion_future.await?; let mut stop_reason = StopReason::EndTurn; let mut current_token_usage = TokenUsage::default(); thread .update(cx, |_thread, cx| { cx.emit(ThreadEvent::NewRequest); }) .ok(); let mut request_assistant_message_id = None; while let Some(event) = events.next().await { if let Some((_, response_events)) = request_callback_parameters.as_mut() { response_events .push(event.as_ref().map_err(|error| error.to_string()).cloned()); } thread.update(cx, |thread, cx| { let event = match event { Ok(event) => event, Err(LanguageModelCompletionError::BadInputJson { id, tool_name, raw_input: invalid_input_json, json_parse_error, }) => { thread.receive_invalid_tool_json( id, tool_name, invalid_input_json, json_parse_error, window, cx, ); return Ok(()); } Err(LanguageModelCompletionError::Other(error)) => { return Err(error); } Err(err @ 
LanguageModelCompletionError::RateLimit(..)) => { return Err(err.into()); } }; match event { LanguageModelCompletionEvent::StartMessage { .. } => { request_assistant_message_id = Some(thread.insert_assistant_message( vec![MessageSegment::Text(String::new())], cx, )); } LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; } LanguageModelCompletionEvent::UsageUpdate(token_usage) => { thread.update_token_usage_at_last_message(token_usage); thread.cumulative_token_usage = thread.cumulative_token_usage + token_usage - current_token_usage; current_token_usage = token_usage; } LanguageModelCompletionEvent::Text(chunk) => { thread.received_chunk(); cx.emit(ThreadEvent::ReceivedTextChunk); if let Some(last_message) = thread.messages.last_mut() { if last_message.role == Role::Assistant && !thread.tool_use.has_tool_results(last_message.id) { last_message.push_text(&chunk); cx.emit(ThreadEvent::StreamedAssistantText( last_message.id, chunk, )); } else { // If we won't have an Assistant message yet, assume this chunk marks the beginning // of a new Assistant response. // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. request_assistant_message_id = Some(thread.insert_assistant_message( vec![MessageSegment::Text(chunk.to_string())], cx, )); }; } } LanguageModelCompletionEvent::Thinking { text: chunk, signature, } => { thread.received_chunk(); if let Some(last_message) = thread.messages.last_mut() { if last_message.role == Role::Assistant && !thread.tool_use.has_tool_results(last_message.id) { last_message.push_thinking(&chunk, signature); cx.emit(ThreadEvent::StreamedAssistantThinking( last_message.id, chunk, )); } else { // If we won't have an Assistant message yet, assume this chunk marks the beginning // of a new Assistant response. // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. request_assistant_message_id = Some(thread.insert_assistant_message( vec![MessageSegment::Thinking { text: chunk.to_string(), signature, }], cx, )); }; } } LanguageModelCompletionEvent::ToolUse(tool_use) => { let last_assistant_message_id = request_assistant_message_id .unwrap_or_else(|| { let new_assistant_message_id = thread.insert_assistant_message(vec![], cx); request_assistant_message_id = Some(new_assistant_message_id); new_assistant_message_id }); let tool_use_id = tool_use.id.clone(); let streamed_input = if tool_use.is_input_complete { None } else { Some((&tool_use.input).clone()) }; let ui_text = thread.tool_use.request_tool_use( last_assistant_message_id, tool_use, tool_use_metadata.clone(), cx, ); if let Some(input) = streamed_input { cx.emit(ThreadEvent::StreamedToolUse { tool_use_id, ui_text, input, }); } } LanguageModelCompletionEvent::StatusUpdate(status_update) => { if let Some(completion) = thread .pending_completions .iter_mut() .find(|completion| completion.id == pending_completion_id) { match status_update { CompletionRequestStatus::Queued { position, } => { completion.queue_state = QueueState::Queued { position }; } CompletionRequestStatus::Started => { completion.queue_state = QueueState::Started; } CompletionRequestStatus::Failed { code, message, request_id } => { anyhow::bail!("completion request failed. 
request_id: {request_id}, code: {code}, message: {message}"); } CompletionRequestStatus::UsageUpdated { amount, limit } => { let usage = RequestUsage { limit, amount: amount as i32 }; thread.last_usage = Some(usage); } CompletionRequestStatus::ToolUseLimitReached => { thread.tool_use_limit_reached = true; cx.emit(ThreadEvent::ToolUseLimitReached); } } } } } thread.touch_updated_at(); cx.emit(ThreadEvent::StreamedCompletion); cx.notify(); thread.auto_capture_telemetry(cx); Ok(()) })??; smol::future::yield_now().await; } thread.update(cx, |thread, cx| { thread.last_received_chunk_at = None; thread .pending_completions .retain(|completion| completion.id != pending_completion_id); // If there is a response without tool use, summarize the message. Otherwise, // allow two tool uses before summarizing. if matches!(thread.summary, ThreadSummary::Pending) && thread.messages.len() >= 2 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) { thread.summarize(cx); } })?; anyhow::Ok(stop_reason) }; let result = stream_completion.await; thread .update(cx, |thread, cx| { thread.finalize_pending_checkpoint(cx); match result.as_ref() { Ok(stop_reason) => match stop_reason { StopReason::ToolUse => { let tool_uses = thread.use_pending_tools(window, cx, model.clone()); cx.emit(ThreadEvent::UsePendingTools { tool_uses }); } StopReason::EndTurn | StopReason::MaxTokens => { thread.project.update(cx, |project, cx| { project.set_agent_location(None, cx); }); } StopReason::Refusal => { thread.project.update(cx, |project, cx| { project.set_agent_location(None, cx); }); // Remove the turn that was refused. // // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal { let mut messages_to_remove = Vec::new(); for (ix, message) in thread.messages.iter().enumerate().rev() { messages_to_remove.push(message.id); if message.role == Role::User { if ix == 0 { break; } if let Some(prev_message) = thread.messages.get(ix - 1) { if prev_message.role == Role::Assistant { break; } } } } for message_id in messages_to_remove { thread.delete_message(message_id, cx); } } cx.emit(ThreadEvent::ShowError(ThreadError::Message { header: "Language model refusal".into(), message: "Model refused to generate content for safety reasons.".into(), })); } }, Err(error) => { thread.project.update(cx, |project, cx| { project.set_agent_location(None, cx); }); if error.is::() { cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); } else if let Some(error) = error.downcast_ref::() { cx.emit(ThreadEvent::ShowError( ThreadError::ModelRequestLimitReached { plan: error.plan }, )); } else if let Some(known_error) = error.downcast_ref::() { match known_error { LanguageModelKnownError::ContextWindowLimitExceeded { tokens, } => { thread.exceeded_window_error = Some(ExceededWindowError { model_id: model.id(), token_count: *tokens, }); cx.notify(); } } } else { let error_message = error .chain() .map(|err| err.to_string()) .collect::>() .join("\n"); cx.emit(ThreadEvent::ShowError(ThreadError::Message { header: "Error interacting with language model".into(), message: SharedString::from(error_message.clone()), })); } thread.cancel_last_completion(window, cx); } } cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); if let Some((request_callback, (request, response_events))) = thread .request_callback .as_mut() .zip(request_callback_parameters.as_ref()) { request_callback(request, response_events); } thread.auto_capture_telemetry(cx); if let Ok(initial_usage) = 
initial_token_usage { let usage = thread.cumulative_token_usage - initial_usage; telemetry::event!( "Assistant Thread Completion", thread_id = thread.id().to_string(), prompt_id = prompt_id, model = model.telemetry_id(), model_provider = model.provider_id().to_string(), input_tokens = usage.input_tokens, output_tokens = usage.output_tokens, cache_creation_input_tokens = usage.cache_creation_input_tokens, cache_read_input_tokens = usage.cache_read_input_tokens, ); } }) .ok(); }); self.pending_completions.push(PendingCompletion { id: pending_completion_id, queue_state: QueueState::Sending, _task: task, }); } pub fn summarize(&mut self, cx: &mut Context) { let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { println!("No thread summary model"); return; }; if !model.provider.is_authenticated(cx) { return; } let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt"); let request = self.to_summarize_request( &model.model, CompletionIntent::ThreadSummarization, added_user_message.into(), cx, ); self.summary = ThreadSummary::Generating; self.pending_summary = cx.spawn(async move |this, cx| { let result = async { let mut messages = model.model.stream_completion(request, &cx).await?; let mut new_summary = String::new(); while let Some(event) = messages.next().await { let Ok(event) = event else { continue; }; let text = match event { LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::StatusUpdate( CompletionRequestStatus::UsageUpdated { amount, limit }, ) => { this.update(cx, |thread, _cx| { thread.last_usage = Some(RequestUsage { limit, amount: amount as i32, }); })?; continue; } _ => continue, }; let mut lines = text.lines(); new_summary.extend(lines.next()); // Stop if the LLM generated multiple lines. if lines.next().is_some() { break; } } anyhow::Ok(new_summary) } .await; this.update(cx, |this, cx| { match result { Ok(new_summary) => { if new_summary.is_empty() { this.summary = ThreadSummary::Error; } else { this.summary = ThreadSummary::Ready(new_summary.into()); } } Err(err) => { this.summary = ThreadSummary::Error; log::error!("Failed to generate thread summary: {}", err); } } cx.emit(ThreadEvent::SummaryGenerated); }) .log_err()?; Some(()) }); } pub fn start_generating_detailed_summary_if_needed( &mut self, thread_store: WeakEntity, cx: &mut Context, ) { let Some(last_message_id) = self.messages.last().map(|message| message.id) else { return; }; match &*self.detailed_summary_rx.borrow() { DetailedSummaryState::Generating { message_id, .. } | DetailedSummaryState::Generated { message_id, .. } if *message_id == last_message_id => { // Already up-to-date return; } _ => {} } let Some(ConfiguredModel { model, provider }) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { return; }; if !provider.is_authenticated(cx) { return; } let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt"); let request = self.to_summarize_request( &model, CompletionIntent::ThreadContextSummarization, added_user_message.into(), cx, ); *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating { message_id: last_message_id, }; // Replace the detailed summarization task if there is one, cancelling it. It would probably // be better to allow the old task to complete, but this would require logic for choosing // which result to prefer (the old task could complete after the new one, resulting in a // stale summary). 
self.detailed_summary_task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion_text(request, &cx); let Some(mut messages) = stream.await.log_err() else { thread .update(cx, |thread, _cx| { *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::NotGenerated; }) .ok()?; return None; }; let mut new_detailed_summary = String::new(); while let Some(chunk) = messages.stream.next().await { if let Some(chunk) = chunk.log_err() { new_detailed_summary.push_str(&chunk); } } thread .update(cx, |thread, _cx| { *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated { text: new_detailed_summary.into(), message_id: last_message_id, }; }) .ok()?; // Save thread so its summary can be reused later if let Some(thread) = thread.upgrade() { if let Ok(Ok(save_task)) = cx.update(|cx| { thread_store .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) }) { save_task.await.log_err(); } } Some(()) }); } pub async fn wait_for_detailed_summary_or_text( this: &Entity, cx: &mut AsyncApp, ) -> Option { let mut detailed_summary_rx = this .read_with(cx, |this, _cx| this.detailed_summary_rx.clone()) .ok()?; loop { match detailed_summary_rx.recv().await? { DetailedSummaryState::Generating { .. } => {} DetailedSummaryState::NotGenerated => { return this.read_with(cx, |this, _cx| this.text().into()).ok(); } DetailedSummaryState::Generated { text, .. } => return Some(text), } } } pub fn latest_detailed_summary_or_text(&self) -> SharedString { self.detailed_summary_rx .borrow() .text() .unwrap_or_else(|| self.text().into()) } pub fn is_generating_detailed_summary(&self) -> bool { matches!( &*self.detailed_summary_rx.borrow(), DetailedSummaryState::Generating { .. } ) } pub fn use_pending_tools( &mut self, window: Option, cx: &mut Context, model: Arc, ) -> Vec { self.auto_capture_telemetry(cx); let request = Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx)); let pending_tool_uses = self .tool_use .pending_tool_uses() .into_iter() .filter(|tool_use| tool_use.status.is_idle()) .cloned() .collect::>(); for tool_use in pending_tool_uses.iter() { if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { if tool.needs_confirmation(&tool_use.input, cx) && !AgentSettings::get_global(cx).always_allow_tool_actions { self.tool_use.confirm_tool_use( tool_use.id.clone(), tool_use.ui_text.clone(), tool_use.input.clone(), request.clone(), tool, ); cx.emit(ThreadEvent::ToolConfirmationNeeded); } else { self.run_tool( tool_use.id.clone(), tool_use.ui_text.clone(), tool_use.input.clone(), request.clone(), tool, model.clone(), window, cx, ); } } else { self.handle_hallucinated_tool_use( tool_use.id.clone(), tool_use.name.clone(), window, cx, ); } } pending_tool_uses } pub fn handle_hallucinated_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, hallucinated_tool_name: Arc, window: Option, cx: &mut Context, ) { let available_tools = self.profile.enabled_tools(cx); let tool_list = available_tools .iter() .map(|tool| format!("- {}: {}", tool.name(), tool.description())) .collect::>() .join("\n"); let error_message = format!( "The tool '{}' doesn't exist or is not enabled. 
Available tools:\n{}", hallucinated_tool_name, tool_list ); let pending_tool_use = self.tool_use.insert_tool_output( tool_use_id.clone(), hallucinated_tool_name, Err(anyhow!("Missing tool call: {error_message}")), self.configured_model.as_ref(), ); cx.emit(ThreadEvent::MissingToolUse { tool_use_id: tool_use_id.clone(), ui_text: error_message.into(), }); self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); } pub fn receive_invalid_tool_json( &mut self, tool_use_id: LanguageModelToolUseId, tool_name: Arc, invalid_json: Arc, error: String, window: Option, cx: &mut Context, ) { log::error!("The model returned invalid input JSON: {invalid_json}"); let pending_tool_use = self.tool_use.insert_tool_output( tool_use_id.clone(), tool_name, Err(anyhow!("Error parsing input JSON: {error}")), self.configured_model.as_ref(), ); let ui_text = if let Some(pending_tool_use) = &pending_tool_use { pending_tool_use.ui_text.clone() } else { log::error!( "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)." ); format!("Unknown tool {}", tool_use_id).into() }; cx.emit(ThreadEvent::InvalidToolInput { tool_use_id: tool_use_id.clone(), ui_text, invalid_input_json: invalid_json, }); self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); } pub fn run_tool( &mut self, tool_use_id: LanguageModelToolUseId, ui_text: impl Into, input: serde_json::Value, request: Arc, tool: Arc, model: Arc, window: Option, cx: &mut Context, ) { let task = self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx); self.tool_use .run_pending_tool(tool_use_id, ui_text.into(), task); } fn spawn_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, request: Arc, input: serde_json::Value, tool: Arc, model: Arc, window: Option, cx: &mut Context, ) -> Task<()> { let tool_name: Arc = tool.name().into(); let tool_result = tool.run( input, request, self.project.clone(), self.action_log.clone(), model, window, cx, ); // Store the card separately if it exists if let Some(card) = tool_result.card.clone() { self.tool_use .insert_tool_result_card(tool_use_id.clone(), card); } cx.spawn({ async move |thread: WeakEntity, cx| { let output = tool_result.output.await; thread .update(cx, |thread, cx| { let pending_tool_use = thread.tool_use.insert_tool_output( tool_use_id.clone(), tool_name, output, thread.configured_model.as_ref(), ); thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx); }) .ok(); } }) } fn tool_finished( &mut self, tool_use_id: LanguageModelToolUseId, pending_tool_use: Option, canceled: bool, window: Option, cx: &mut Context, ) { if self.all_tools_finished() { if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() { if !canceled { self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); } self.auto_capture_telemetry(cx); } } cx.emit(ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, }); } /// Cancels the last pending completion, if there are any pending. /// /// Returns whether a completion was canceled. pub fn cancel_last_completion( &mut self, window: Option, cx: &mut Context, ) -> bool { let mut canceled = self.pending_completions.pop().is_some(); for pending_tool_use in self.tool_use.cancel_pending() { canceled = true; self.tool_finished( pending_tool_use.id.clone(), Some(pending_tool_use), true, window, cx, ); } if canceled { cx.emit(ThreadEvent::CompletionCanceled); // When canceled, we always want to insert the checkpoint. 
// (We skip over finalize_pending_checkpoint, because it // would conclude we didn't have anything to insert here.) if let Some(checkpoint) = self.pending_checkpoint.take() { self.insert_checkpoint(checkpoint, cx); } } else { self.finalize_pending_checkpoint(cx); } canceled } /// Signals that any in-progress editing should be canceled. /// /// This method is used to notify listeners (like ActiveThread) that /// they should cancel any editing operations. pub fn cancel_editing(&mut self, cx: &mut Context) { cx.emit(ThreadEvent::CancelEditing); } pub fn feedback(&self) -> Option { self.feedback } pub fn message_feedback(&self, message_id: MessageId) -> Option { self.message_feedback.get(&message_id).copied() } pub fn report_message_feedback( &mut self, message_id: MessageId, feedback: ThreadFeedback, cx: &mut Context, ) -> Task> { if self.message_feedback.get(&message_id) == Some(&feedback) { return Task::ready(Ok(())); } let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); let serialized_thread = self.serialize(cx); let thread_id = self.id().clone(); let client = self.project.read(cx).client(); let enabled_tool_names: Vec = self .profile .enabled_tools(cx) .iter() .map(|tool| tool.name()) .collect(); self.message_feedback.insert(message_id, feedback); cx.notify(); let message_content = self .message(message_id) .map(|msg| msg.to_string()) .unwrap_or_default(); cx.background_spawn(async move { let final_project_snapshot = final_project_snapshot.await; let serialized_thread = serialized_thread.await?; let thread_data = serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null); let rating = match feedback { ThreadFeedback::Positive => "positive", ThreadFeedback::Negative => "negative", }; telemetry::event!( "Assistant Thread Rated", rating, thread_id, enabled_tool_names, message_id = message_id.0, message_content, thread_data, final_project_snapshot ); client.telemetry().flush_events().await; Ok(()) }) } pub fn report_feedback( &mut self, feedback: ThreadFeedback, cx: &mut Context, ) -> Task> { let last_assistant_message_id = self .messages .iter() .rev() .find(|msg| msg.role == Role::Assistant) .map(|msg| msg.id); if let Some(message_id) = last_assistant_message_id { self.report_message_feedback(message_id, feedback, cx) } else { let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); let serialized_thread = self.serialize(cx); let thread_id = self.id().clone(); let client = self.project.read(cx).client(); self.feedback = Some(feedback); cx.notify(); cx.background_spawn(async move { let final_project_snapshot = final_project_snapshot.await; let serialized_thread = serialized_thread.await?; let thread_data = serde_json::to_value(serialized_thread) .unwrap_or_else(|_| serde_json::Value::Null); let rating = match feedback { ThreadFeedback::Positive => "positive", ThreadFeedback::Negative => "negative", }; telemetry::event!( "Assistant Thread Rated", rating, thread_id, thread_data, final_project_snapshot ); client.telemetry().flush_events().await; Ok(()) }) } } /// Create a snapshot of the current project state including git information and unsaved buffers. 
fn project_snapshot( project: Entity, cx: &mut Context, ) -> Task> { let git_store = project.read(cx).git_store().clone(); let worktree_snapshots: Vec<_> = project .read(cx) .visible_worktrees(cx) .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) .collect(); cx.spawn(async move |_, cx| { let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; let mut unsaved_buffers = Vec::new(); cx.update(|app_cx| { let buffer_store = project.read(app_cx).buffer_store(); for buffer_handle in buffer_store.read(app_cx).buffers() { let buffer = buffer_handle.read(app_cx); if buffer.is_dirty() { if let Some(file) = buffer.file() { let path = file.path().to_string_lossy().to_string(); unsaved_buffers.push(path); } } } }) .ok(); Arc::new(ProjectSnapshot { worktree_snapshots, unsaved_buffer_paths: unsaved_buffers, timestamp: Utc::now(), }) }) } fn worktree_snapshot( worktree: Entity, git_store: Entity, cx: &App, ) -> Task { cx.spawn(async move |cx| { // Get worktree path and snapshot let worktree_info = cx.update(|app_cx| { let worktree = worktree.read(app_cx); let path = worktree.abs_path().to_string_lossy().to_string(); let snapshot = worktree.snapshot(); (path, snapshot) }); let Ok((worktree_path, _snapshot)) = worktree_info else { return WorktreeSnapshot { worktree_path: String::new(), git_state: None, }; }; let git_state = git_store .update(cx, |git_store, cx| { git_store .repositories() .values() .find(|repo| { repo.read(cx) .abs_path_to_repo_path(&worktree.read(cx).abs_path()) .is_some() }) .cloned() }) .ok() .flatten() .map(|repo| { repo.update(cx, |repo, _| { let current_branch = repo.branch.as_ref().map(|branch| branch.name().to_owned()); repo.send_job(None, |state, _| async move { let RepositoryState::Local { backend, .. } = state else { return GitState { remote_url: None, head_sha: None, current_branch, diff: None, }; }; let remote_url = backend.remote_url("origin"); let head_sha = backend.head_sha().await; let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); GitState { remote_url, head_sha, current_branch, diff, } }) }) }); let git_state = match git_state { Some(git_state) => match git_state.ok() { Some(git_state) => git_state.await.ok(), None => None, }, None => None, }; WorktreeSnapshot { worktree_path, git_state, } }) } pub fn to_markdown(&self, cx: &App) -> Result { let mut markdown = Vec::new(); let summary = self.summary().or_default(); writeln!(markdown, "# {summary}\n")?; for message in self.messages() { writeln!( markdown, "## {role}\n", role = match message.role { Role::User => "User", Role::Assistant => "Agent", Role::System => "System", } )?; if !message.loaded_context.text.is_empty() { writeln!(markdown, "{}", message.loaded_context.text)?; } if !message.loaded_context.images.is_empty() { writeln!( markdown, "\n{} images attached as context.\n", message.loaded_context.images.len() )?; } for segment in &message.segments { match segment { MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, MessageSegment::Thinking { text, .. } => { writeln!(markdown, "\n{}\n\n", text)? } MessageSegment::RedactedThinking(_) => {} } } for tool_use in self.tool_uses_for_message(message.id, cx) { writeln!( markdown, "**Use Tool: {} ({})**", tool_use.name, tool_use.id )?; writeln!(markdown, "```json")?; writeln!( markdown, "{}", serde_json::to_string_pretty(&tool_use.input)? 
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    // NOTE: the generic parameters on these edit-tracking signatures were lost in
    // extraction; `language::Anchor` ranges are assumed here.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        // The flag's type parameter was lost in extraction; `ThreadAutoCapture`
        // is assumed here.
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );
                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
        let Some(model) = self.configured_model.as_ref() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = &self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens(),
            max,
        }
    }

    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        let model = self.configured_model.as_ref()?;

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return Some(TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                });
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens();

        Some(TotalTokenUsage { total, max })
    }

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            err,
            self.configured_model.as_ref(),
        );
        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
    }
}

#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
    use assistant_tool::ToolRegistry;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
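        // (Worktree-relative paths are rendered with the host platform's
        // separator, so the expected context differs on Windows.)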
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
The following items were attached by the user. They are up-to-date and don't need to be re-read.

```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
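        // (new_context_for_thread filters out contexts already attached to an
        // earlier message, so only file3.rs should remain new at this point.)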
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
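        // (ContextLoadResult::default() carries no attached files or images,
        // so the message should have an empty context block.)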
        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }

    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }

    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }

    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        cx.run_until_parked();

        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();

        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }

    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }

    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }

    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate the summary request ending without producing any text
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and the default message is used
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }

    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
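            // Register the globals and per-crate settings the thread machinery
            // reads; a missing registration typically panics on first access.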
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }

    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}