ZIm/crates/agent/src/thread.rs
Anthony Eid b349a8f34c
ai: Auto select user model when there's no default (#36722)
This PR identifies automatic configuration options that users can select
from the agent panel. If no default provider is set in their settings,
the PR defaults to the first recommended option. Additionally, it
updates the selected provider for a thread when a user changes the
default provider through the settings file, if the thread hasn't had any
queries yet.

Release Notes:

- agent: Automatically select a language model provider if there's no
user-set provider.

---------

Co-authored-by: Michael Sloan <michael@zed.dev>
2025-08-22 01:12:12 -04:00

5446 lines
191 KiB
Rust

use crate::{
agent_profile::AgentProfile,
context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
thread_store::{
SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
ThreadStore,
},
tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use action_log::ActionLog;
use agent_settings::{
AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Result, anyhow};
use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, future::Shared};
use git::repository::DiffType;
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
WeakEntity, Window,
};
use http_client::StatusCode;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
TokenUsage,
};
use postage::stream::Stream as _;
use project::{
Project,
git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
io::Write,
ops::Range,
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
// Retry tuning constants. Their consumers are outside this part of the file.
// NOTE(review): presumably used by the completion retry logic — confirm.
/// Number of retry attempts (4).
const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay for retries (5 seconds).
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Describes how a failed request should be retried.
///
/// NOTE(review): the code that interprets these variants is outside this part of the file;
/// the variant descriptions below follow the field names — confirm against the retry logic.
#[derive(Debug, Clone)]
enum RetryStrategy {
/// Presumably the delay starts at `initial_delay` and grows per attempt, up to `max_attempts`.
ExponentialBackoff {
initial_delay: Duration,
max_attempts: u8,
},
/// Presumably the same `delay` is used before each of up to `max_attempts` attempts.
Fixed {
delay: Duration,
max_attempts: u8,
},
}
/// Identifier of a [`Thread`], stored as a shared string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a fresh identifier from a newly generated v4 UUID.
    pub fn new() -> Self {
        let uuid = Uuid::new_v4().to_string();
        Self(Arc::from(uuid))
    }
}

impl std::fmt::Display for ThreadId {
    /// Formats the identifier exactly like the string it wraps.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl From<&str> for ThreadId {
    /// Wraps an existing string as a thread identifier; the value is not validated.
    fn from(value: &str) -> Self {
        Self(Arc::from(value))
    }
}
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh identifier from a newly generated v4 UUID.
    pub fn new() -> Self {
        let uuid = Uuid::new_v4().to_string();
        Self(Arc::from(uuid))
    }
}

impl std::fmt::Display for PromptId {
    /// Formats the identifier exactly like the string it wraps.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Numeric identifier of a [`Message`] within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub usize);

impl MessageId {
    /// Applies `util::post_inc` to the inner counter and wraps the value it returns.
    ///
    /// NOTE(review): by its name `post_inc` returns the old value and then adds one —
    /// confirm against `util`.
    fn post_inc(&mut self) -> Self {
        let previous = post_inc(&mut self.0);
        Self(previous)
    }

    /// Returns the raw numeric value of the id.
    pub fn as_usize(&self) -> usize {
        let Self(raw) = *self;
        raw
    }
}
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
/// Range covered by the crease.
/// NOTE(review): presumably offsets into the message text — confirm the unit.
pub range: Range<usize>,
/// Path of the icon associated with the crease.
pub icon_path: SharedString,
/// Label associated with the crease.
pub label: SharedString,
/// None for a deserialized message, Some otherwise.
pub context: Option<AgentContextHandle>,
}
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
/// Identifier of the message within its thread.
pub id: MessageId,
/// Who the message is attributed to (user, assistant, or system).
pub role: Role,
/// Ordered content pieces: plain text, thinking, and redacted thinking.
pub segments: Vec<MessageSegment>,
/// Context attached to the message; its `text` is emitted before the segments
/// by `to_message_content`.
pub loaded_context: LoadedContext,
/// Creases stored alongside the message.
pub creases: Vec<MessageCrease>,
/// Whether the message is hidden (set for the invisible "continue" message).
pub is_hidden: bool,
/// UI-only messages are skipped when the thread is serialized.
pub ui_only: bool,
}
impl Message {
    /// Returns `true` only when [`MessageSegment::should_display`] is `true` for every
    /// segment (and therefore also when there are no segments at all).
    ///
    /// The model sometimes runs tool without producing any text or just a marker
    /// (`USING_TOOL_MARKER`).
    ///
    /// NOTE(review): as written in this file, `should_display` is `true` for *empty* text
    /// and thinking segments, so this returns `true` when all content is empty. Confirm
    /// that this matches what callers expect from the name.
    pub fn should_display_content(&self) -> bool {
        !self.segments.iter().any(|segment| !segment.should_display())
    }

    /// Appends thinking text, extending the trailing thinking segment when there is one
    /// and starting a new thinking segment otherwise.
    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
        match self.segments.last_mut() {
            Some(MessageSegment::Thinking {
                text: existing_text,
                signature: existing_signature,
            }) => {
                // Only a newly supplied signature replaces the stored one; `None` keeps it.
                if signature.is_some() {
                    *existing_signature = signature;
                }
                existing_text.push_str(text);
            }
            _ => self.segments.push(MessageSegment::Thinking {
                text: text.to_owned(),
                signature,
            }),
        }
    }

    /// Appends a redacted-thinking segment; redacted segments are never merged.
    pub fn push_redacted_thinking(&mut self, data: String) {
        let segment = MessageSegment::RedactedThinking(data);
        self.segments.push(segment);
    }

    /// Appends text, extending the trailing text segment when there is one and starting
    /// a new text segment otherwise.
    pub fn push_text(&mut self, text: &str) {
        match self.segments.last_mut() {
            Some(MessageSegment::Text(existing)) => existing.push_str(text),
            _ => self.segments.push(MessageSegment::Text(text.to_owned())),
        }
    }

    /// Renders the message as a single string: the loaded context text first, then each
    /// segment in order. Thinking segments are wrapped in `<think>` tags and redacted
    /// thinking is omitted.
    pub fn to_message_content(&self) -> String {
        let mut content = String::new();
        // Appending an empty context is a no-op, so no emptiness check is needed.
        content.push_str(&self.loaded_context.text);
        for segment in self.segments.iter() {
            match segment {
                MessageSegment::Text(text) => content.push_str(text),
                MessageSegment::Thinking { text, .. } => {
                    content.extend(["<think>\n", text.as_str(), "\n</think>"]);
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }
        content
    }
}
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text.
    Text(String),
    /// Thinking text together with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking, kept as an opaque string.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns `true` for text and thinking segments whose text is empty, and `false`
    /// for everything else (including all redacted-thinking segments).
    ///
    /// NOTE(review): the result is `true` for *empty* content, which reads inverted
    /// relative to the method name — confirm against the callers before changing it.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) | Self::Thinking { text, .. } => text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }

    /// Returns the text of a [`MessageSegment::Text`] segment, or `None` for any other
    /// kind of segment.
    pub fn text(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}
/// Serializable snapshot of a project.
///
/// NOTE(review): the code that fills this in is outside this part of the file; field
/// descriptions follow the field names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
/// One entry per worktree.
pub worktree_snapshots: Vec<WorktreeSnapshot>,
/// Paths of buffers with unsaved changes (presumably at snapshot time — confirm).
pub unsaved_buffer_paths: Vec<String>,
/// Timestamp associated with the snapshot.
pub timestamp: DateTime<Utc>,
}
/// Serializable snapshot of a single worktree, as stored in [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
/// Path of the worktree, as a string.
pub worktree_path: String,
/// Git information for the worktree, when available.
pub git_state: Option<GitState>,
}
/// Serializable git information for a worktree; every field is optional.
///
/// NOTE(review): the code that fills this in is outside this part of the file; field
/// meanings follow the field names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
pub remote_url: Option<String>,
pub head_sha: Option<String>,
pub current_branch: Option<String>,
pub diff: Option<String>,
}
/// Associates a git store checkpoint with a message of the thread.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
/// Message the checkpoint belongs to; restoring truncates the thread at this message.
message_id: MessageId,
/// The git store checkpoint that is restored.
git_checkpoint: GitStoreCheckpoint,
}
/// Feedback value; the thread stores one per message in `message_feedback`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
Positive,
Negative,
}
/// State of the most recent checkpoint restore that is still running or has failed.
pub enum LastRestoreCheckpoint {
    /// The restore targeting `message_id` has been started and has not finished yet.
    Pending { message_id: MessageId },
    /// The restore targeting `message_id` failed; `error` holds the error text.
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    /// Returns the id of the message the restore was targeting, whatever the state.
    pub fn message_id(&self) -> MessageId {
        let (Self::Pending { message_id } | Self::Error { message_id, .. }) = self;
        *message_id
    }
}
/// Progress of the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated; `message_id` is recorded with it.
    Generating { message_id: MessageId },
    /// A detailed summary is available as `text`; `message_id` is recorded with it.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

impl DetailedSummaryState {
    /// Returns a clone of the summary text when one has been generated, `None` otherwise.
    fn text(&self) -> Option<SharedString> {
        match self {
            Self::Generated { text, .. } => Some(text.clone()),
            Self::NotGenerated | Self::Generating { .. } => None,
        }
    }
}
/// Token usage measured against a maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Number of tokens used.
    pub total: u64,
    /// Maximum number of tokens; `0` means the maximum is unknown.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Share of `max` at or above which usage is classified as [`TokenUsageRatio::Warning`].
    const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

    /// Classifies `total` relative to `max`.
    ///
    /// Returns [`TokenUsageRatio::Exceeded`] when `total >= max`,
    /// [`TokenUsageRatio::Warning`] when `total / max` reaches the warning threshold, and
    /// [`TokenUsageRatio::Normal`] otherwise or when `max` is `0`.
    ///
    /// In debug builds the threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or unparsable value
    /// falls back to the default rather than panicking.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|raw| raw.trim().parse().ok())
            .unwrap_or(Self::DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = Self::DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to `total` and `max` unchanged.
    ///
    /// The addition saturates at `u64::MAX` instead of overflowing.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// Classification produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    #[default]
    Normal,
    /// Usage has reached the warning threshold but is still below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
/// Queue state of a pending completion, as reported by `Thread::queue_state`.
///
/// NOTE(review): the code that sets these states is outside this part of the file;
/// variant meanings follow their names.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
Sending,
/// Carries the position in the queue.
Queued { position: usize },
Started,
}
/// A thread of conversation with the LLM.
pub struct Thread {
/// Identifier of this thread.
id: ThreadId,
/// Time of the last modification; refreshed by `touch_updated_at`.
updated_at: DateTime<Utc>,
/// Summary state of the thread.
summary: ThreadSummary,
pending_summary: Task<Option<()>>,
detailed_summary_task: Task<Option<()>>,
/// Sending half of the watch channel carrying the detailed summary state.
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
/// Receiving half of the same channel; its current value is what gets serialized.
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
completion_mode: agent_settings::CompletionMode,
/// Messages in insertion order. `message()` finds entries by binary search on the id,
/// so the list is expected to stay sorted by id.
messages: Vec<Message>,
/// Id that the next inserted message receives.
next_message_id: MessageId,
/// Sent as `prompt_id` with each request; replaced by `advance_prompt_id`.
last_prompt_id: PromptId,
project_context: SharedProjectContext,
/// Finalized checkpoints, keyed by the message they belong to.
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
/// Completions currently in flight; non-empty means the thread is generating.
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
/// Set while a checkpoint restore is pending or after it failed; cleared on success.
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
/// Checkpoint recorded with the latest user message that has not been finalized yet.
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
tool_use_limit_reached: bool,
retry_state: Option<RetryState>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
/// Time the most recent chunk was received; read by `is_generation_stale`.
last_received_chunk_at: Option<Instant>,
/// Optional callback installed through `set_request_callback`.
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
/// Number of further `send_to_model` calls that will send a request; at 0 they are no-ops.
remaining_turns: u32,
/// Model selected for this thread, if any.
configured_model: Option<ConfiguredModel>,
profile: AgentProfile,
/// Model and intent consumed by `retry_last_completion` when present.
last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}
/// Bookkeeping for a retry in progress; stored in `Thread::retry_state`.
///
/// NOTE(review): the retry logic that maintains this is outside this part of the file;
/// field meanings follow their names.
#[derive(Clone, Debug)]
struct RetryState {
attempt: u8,
max_attempts: u8,
intent: CompletionIntent,
}
/// State of a thread's summary.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    /// The summary text is available.
    Ready(SharedString),
    Error,
}

impl ThreadSummary {
    /// Summary text used when no real summary is ready.
    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");

    /// Returns the ready summary, or [`ThreadSummary::DEFAULT`] when none is ready.
    pub fn or_default(&self) -> SharedString {
        self.ready().unwrap_or(Self::DEFAULT)
    }

    /// Returns the ready summary, or `message` converted to a string when none is ready.
    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
        match self.ready() {
            Some(summary) => summary,
            None => message.into(),
        }
    }

    /// Returns a clone of the summary text in the `Ready` state and `None` in every
    /// other state.
    pub fn ready(&self) -> Option<SharedString> {
        match self {
            Self::Ready(summary) => Some(summary.clone()),
            Self::Pending | Self::Generating | Self::Error => None,
        }
    }
}
/// Serializable record of a context-window overflow: which model was in use and how many
/// tokens were counted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,
/// Token count including last message
token_count: u64,
}
impl Thread {
/// Creates an empty thread.
///
/// The model is the registry's current default, and the completion mode and profile come
/// from the global agent settings. The initial project snapshot is started in the
/// background and shared.
pub fn new(
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    system_prompt: SharedProjectContext,
    cx: &mut Context<Self>,
) -> Self {
    let configured_model = LanguageModelRegistry::read_global(cx).default_model();
    let settings = AgentSettings::get_global(cx);
    let profile_id = settings.default_profile.clone();
    let completion_mode = settings.preferred_completion_mode;

    let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
    let tool_use = ToolUseState::new(tools.clone());
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let initial_project_snapshot = {
        let snapshot = Self::project_snapshot(project.clone(), cx);
        cx.foreground_executor()
            .spawn(async move { Some(snapshot.await) })
            .shared()
    };
    let profile = AgentProfile::new(profile_id, tools.clone());

    Self {
        id: ThreadId::new(),
        updated_at: Utc::now(),
        summary: ThreadSummary::Pending,
        pending_summary: Task::ready(None),
        detailed_summary_task: Task::ready(None),
        detailed_summary_tx,
        detailed_summary_rx,
        completion_mode,
        messages: Vec::new(),
        next_message_id: MessageId(0),
        last_prompt_id: PromptId::new(),
        project_context: system_prompt,
        checkpoints_by_message: HashMap::default(),
        completion_count: 0,
        pending_completions: Vec::new(),
        project,
        prompt_builder,
        tools,
        last_restore_checkpoint: None,
        pending_checkpoint: None,
        tool_use,
        action_log,
        initial_project_snapshot,
        request_token_usage: Vec::new(),
        cumulative_token_usage: TokenUsage::default(),
        exceeded_window_error: None,
        tool_use_limit_reached: false,
        retry_state: None,
        message_feedback: HashMap::default(),
        last_error_context: None,
        last_received_chunk_at: None,
        request_callback: None,
        remaining_turns: u32::MAX,
        configured_model,
        profile,
    }
}
pub fn deserialize(
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
window: Option<&mut Window>, // None in headless mode
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
serialized
.messages
.last()
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use = ToolUseState::from_serialized_messages(
tools.clone(),
&serialized.messages,
project.clone(),
window,
cx,
);
let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state);
let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
serialized
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
});
let completion_mode = serialized
.completion_mode
.unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
let profile_id = serialized
.profile
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
Self {
id,
updated_at: serialized.updated_at,
summary: ThreadSummary::Ready(serialized.summary),
pending_summary: Task::ready(None),
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode,
retry_state: None,
messages: serialized
.messages
.into_iter()
.map(|message| Message {
id: message.id,
role: message.role,
segments: message
.segments
.into_iter()
.map(|segment| match segment {
SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
SerializedMessageSegment::Thinking { text, signature } => {
MessageSegment::Thinking { text, signature }
}
SerializedMessageSegment::RedactedThinking { data } => {
MessageSegment::RedactedThinking(data)
}
})
.collect(),
loaded_context: LoadedContext {
contexts: Vec::new(),
text: message.context,
images: Vec::new(),
},
creases: message
.creases
.into_iter()
.map(|crease| MessageCrease {
range: crease.start..crease.end,
icon_path: crease.icon_path,
label: crease.label,
context: None,
})
.collect(),
is_hidden: message.is_hidden,
ui_only: false, // UI-only messages are not persisted
})
.collect(),
next_message_id,
last_prompt_id: PromptId::new(),
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
last_restore_checkpoint: None,
pending_checkpoint: None,
project: project.clone(),
prompt_builder,
tools: tools.clone(),
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
tool_use_limit_reached: serialized.tool_use_limit_reached,
message_feedback: HashMap::default(),
last_error_context: None,
last_received_chunk_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
profile: AgentProfile::new(profile_id, tools),
}
}
/// Installs `callback`, replacing any previously installed one.
///
/// NOTE(review): the call site is outside this part of the file; presumably the callback
/// receives each request together with the events received for it — confirm.
pub fn set_request_callback(
&mut self,
callback: impl 'static
+ FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
) {
self.request_callback = Some(Box::new(callback));
}
/// Returns the identifier of this thread.
pub fn id(&self) -> &ThreadId {
&self.id
}
/// Returns the agent profile currently used by this thread.
pub fn profile(&self) -> &AgentProfile {
&self.profile
}
/// Switches the thread to the profile with the given id.
///
/// Does nothing when that profile is already active; otherwise emits
/// `ThreadEvent::ProfileChanged` after the switch.
pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
    if self.profile.id() == &id {
        return;
    }
    self.profile = AgentProfile::new(id, self.tools.clone());
    cx.emit(ThreadEvent::ProfileChanged);
}
/// Returns whether the thread has no messages.
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
/// Returns the time the thread was last marked as updated.
pub fn updated_at(&self) -> DateTime<Utc> {
self.updated_at
}
/// Marks the thread as updated now (UTC).
pub fn touch_updated_at(&mut self) {
self.updated_at = Utc::now();
}
/// Replaces the last prompt id with a freshly generated one.
pub fn advance_prompt_id(&mut self) {
self.last_prompt_id = PromptId::new();
}
/// Returns a clone of the shared project context handle.
pub fn project_context(&self) -> SharedProjectContext {
self.project_context.clone()
}
/// Returns the thread's model, first refreshing it from the registry's default when no
/// model is set or while the thread has no messages yet.
///
/// Because of the refresh, an empty thread always follows the current default model.
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
    let should_refresh = self.configured_model.is_none() || self.messages.is_empty();
    if should_refresh {
        let registry = LanguageModelRegistry::read_global(cx);
        self.configured_model = registry.default_model();
    }
    self.configured_model.clone()
}
/// Returns a clone of the thread's model without touching the registry.
pub fn configured_model(&self) -> Option<ConfiguredModel> {
self.configured_model.clone()
}
/// Replaces the thread's model (possibly with `None`) and notifies observers.
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
self.configured_model = model;
cx.notify();
}
/// Returns the current summary state.
pub fn summary(&self) -> &ThreadSummary {
&self.summary
}
pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
let current_summary = match &self.summary {
ThreadSummary::Pending | ThreadSummary::Generating => return,
ThreadSummary::Ready(summary) => summary,
ThreadSummary::Error => &ThreadSummary::DEFAULT,
};
let mut new_summary = new_summary.into();
if new_summary.is_empty() {
new_summary = ThreadSummary::DEFAULT;
}
if current_summary != &new_summary {
self.summary = ThreadSummary::Ready(new_summary);
cx.emit(ThreadEvent::SummaryChanged);
}
}
/// Returns the completion mode of this thread.
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
/// Sets the completion mode of this thread; no event is emitted.
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
/// Looks up a message by id.
///
/// Uses a binary search, so it relies on `messages` being sorted by id.
pub fn message(&self, id: MessageId) -> Option<&Message> {
    self.messages
        .binary_search_by_key(&id, |message| message.id)
        .ok()
        .and_then(|index| self.messages.get(index))
}
/// Iterates over all messages in stored order.
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
self.messages.iter()
}
/// Returns whether the thread is still busy: a completion is pending, or some tool use
/// has not finished.
pub fn is_generating(&self) -> bool {
    let idle = self.pending_completions.is_empty() && self.all_tools_finished();
    !idle
}
/// Indicates whether streaming of language model events is stale.
///
/// Returns `None` when no chunk timestamp is recorded; otherwise `Some(true)` when more
/// than 250 ms have passed since the last recorded chunk.
///
/// NOTE(review): the previous doc stated this returns `None` whenever `is_generating()`
/// is false. That depends on the timestamp being cleared elsewhere, which is not visible
/// in this part of the file — confirm.
pub fn is_generation_stale(&self) -> Option<bool> {
    const STALE_THRESHOLD: u128 = 250;
    let last_chunk_at = self.last_received_chunk_at?;
    Some(last_chunk_at.elapsed().as_millis() > STALE_THRESHOLD)
}
/// Records that a chunk was received just now.
fn received_chunk(&mut self) {
self.last_received_chunk_at = Some(Instant::now());
}
/// Returns the queue state of the first pending completion, or `None` when no
/// completion is pending.
pub fn queue_state(&self) -> Option<QueueState> {
    let first_pending = self.pending_completions.first()?;
    Some(first_pending.queue_state)
}
/// Returns the tool working set used by this thread.
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
/// Returns the pending tool use with the given id, if there is one.
pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
    for tool_use in self.tool_use.pending_tool_uses() {
        if &tool_use.id == id {
            return Some(tool_use);
        }
    }
    None
}
/// Iterates over the pending tool uses whose status needs confirmation.
pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
    let pending = self.tool_use.pending_tool_uses();
    pending
        .into_iter()
        .filter(|tool_use| tool_use.status.needs_confirmation())
}
/// Returns whether any tool use is pending, whatever its status.
pub fn has_pending_tool_uses(&self) -> bool {
!self.tool_use.pending_tool_uses().is_empty()
}
/// Returns a clone of the checkpoint stored for the given message, if any.
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
self.checkpoints_by_message.get(&id).cloned()
}
pub fn restore_checkpoint(
&mut self,
checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
message_id: checkpoint.message_id,
});
cx.emit(ThreadEvent::CheckpointChanged);
cx.notify();
let git_store = self.project().read(cx).git_store().clone();
let restore = git_store.update(cx, |git_store, cx| {
git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
});
cx.spawn(async move |this, cx| {
let result = restore.await;
this.update(cx, |this, cx| {
if let Err(err) = result.as_ref() {
this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
message_id: checkpoint.message_id,
error: err.to_string(),
});
} else {
this.truncate(checkpoint.message_id, cx);
this.last_restore_checkpoint = None;
}
this.pending_checkpoint = None;
cx.emit(ThreadEvent::CheckpointChanged);
cx.notify();
})?;
result
})
}
/// Finalizes the pending checkpoint, if there is one and the thread is not generating.
///
/// While the thread is generating the pending checkpoint is left in place.
fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
    if self.is_generating() {
        return;
    }
    let Some(pending_checkpoint) = self.pending_checkpoint.take() else {
        return;
    };
    self.finalize_checkpoint(pending_checkpoint, cx);
}
/// Decides what to do with `pending_checkpoint` by comparing it with a checkpoint of the
/// git store taken now. The work runs in a detached task.
///
/// - If the two checkpoints compare equal, `pending_checkpoint` is put back as the
///   thread's pending checkpoint.
/// - Otherwise `pending_checkpoint` is stored for its message, and the new checkpoint
///   becomes the pending one, associated with `next_message_id`.
/// - If taking the new checkpoint fails, `pending_checkpoint` is stored for its message.
fn finalize_checkpoint(
&mut self,
pending_checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) {
let git_store = self.project.read(cx).git_store().clone();
let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
cx.spawn(async move |this, cx| match final_checkpoint.await {
Ok(final_checkpoint) => {
let equal = git_store
.update(cx, |store, cx| {
store.compare_checkpoints(
pending_checkpoint.git_checkpoint.clone(),
final_checkpoint.clone(),
cx,
)
})?
.await
// A failed comparison is treated as "not equal".
.unwrap_or(false);
this.update(cx, |this, cx| {
this.pending_checkpoint = if equal {
Some(pending_checkpoint)
} else {
this.insert_checkpoint(pending_checkpoint, cx);
Some(ThreadCheckpoint {
message_id: this.next_message_id,
git_checkpoint: final_checkpoint,
})
}
})?;
Ok(())
}
// The error from taking the new checkpoint is discarded here.
Err(_) => this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
}),
})
.detach();
}
/// Stores `checkpoint` under its message id, replacing any checkpoint already stored
/// for that message, and emits `ThreadEvent::CheckpointChanged`.
fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
    let message_id = checkpoint.message_id;
    self.checkpoints_by_message.insert(message_id, checkpoint);
    cx.emit(ThreadEvent::CheckpointChanged);
    cx.notify();
}
/// Returns the state of the last checkpoint restore while it is pending or after it
/// failed; `None` otherwise.
pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
self.last_restore_checkpoint.as_ref()
}
/// Removes the message with the given id and every message after it, together with the
/// checkpoints stored for the removed messages.
///
/// Does nothing (and does not notify) when no message has that id.
pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
    let message_ix = match self
        .messages
        .iter()
        .rposition(|message| message.id == message_id)
    {
        Some(ix) => ix,
        None => return,
    };
    let removed = self.messages.split_off(message_ix);
    for deleted_message in &removed {
        self.checkpoints_by_message.remove(&deleted_message.id);
    }
    cx.notify();
}
/// Iterates over the loaded contexts of the message with the given id; the iterator is
/// empty when there is no such message.
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
    let message = self.messages.iter().find(|message| message.id == id);
    message
        .into_iter()
        .flat_map(|message| message.loaded_context.contexts.iter())
}
/// Returns whether the message at index `ix` ends a turn.
///
/// That is the case for the last message when the thread is not generating, and for an
/// assistant message that is directly followed by a visible user message.
pub fn is_turn_end(&self, ix: usize) -> bool {
    let Some(last_ix) = self.messages.len().checked_sub(1) else {
        // No messages at all.
        return false;
    };
    if ix == last_ix && !self.is_generating() {
        return true;
    }
    match self.messages.get(ix) {
        Some(message) if message.role == Role::Assistant => {}
        _ => return false,
    }
    let Some(following) = self.messages.get(ix + 1) else {
        return false;
    };
    match self.message(following.id) {
        Some(next_message) => next_message.role == Role::User && !next_message.is_hidden,
        None => false,
    }
}
/// Returns the stored "tool use limit reached" flag.
pub fn tool_use_limit_reached(&self) -> bool {
self.tool_use_limit_reached
}
/// Returns whether all of the tool uses have finished running.
pub fn all_tools_finished(&self) -> bool {
    // Pending tool uses that are in an error state no longer run, so only a pending
    // tool use without an error means that work is still outstanding.
    let pending = self.tool_use.pending_tool_uses();
    !pending
        .iter()
        .any(|pending_tool_use| !pending_tool_use.status.is_error())
}
/// Returns whether any pending tool uses may perform edits
///
/// Pending tool uses that are in an error state are ignored.
pub fn has_pending_edit_tool_uses(&self) -> bool {
    let pending = self.tool_use.pending_tool_uses();
    pending.iter().any(|pending_tool_use| {
        !pending_tool_use.status.is_error() && pending_tool_use.may_perform_edits
    })
}
/// Returns the tool uses recorded for the given message (delegates to the tool-use state).
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id, &self.project, cx)
}
/// Returns the tool results recorded for the given assistant message (delegates to the
/// tool-use state).
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(assistant_message_id)
}
/// Returns the result recorded for the given tool use, if any.
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
self.tool_use.tool_result(id)
}
/// Returns the text output of the given tool use.
///
/// `None` when the tool use has no result, or when its result is an image.
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
    let result = self.tool_use.tool_result(id)?;
    match &result.content {
        LanguageModelToolResultContent::Text(text) => Some(text),
        // TODO: We should display image
        LanguageModelToolResultContent::Image(_) => None,
    }
}
/// Returns a clone of the result card recorded for the given tool use, if any.
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
self.tool_use.tool_result_card(id).cloned()
}
/// Return tools that are both enabled and supported by the model
///
/// The result is empty when the model does not support tools. Enabled tools whose input
/// schema cannot be produced for the model's tool input format are left out.
pub fn available_tools(
    &self,
    cx: &App,
    model: Arc<dyn LanguageModel>,
) -> Vec<LanguageModelRequestTool> {
    if !model.supports_tools() {
        return Vec::default();
    }
    let mut request_tools = Vec::new();
    for (name, tool) in self.profile.enabled_tools(cx) {
        // Skip tools that cannot be supported
        let Ok(input_schema) = tool.input_schema(model.tool_input_format()) else {
            continue;
        };
        request_tools.push(LanguageModelRequestTool {
            name: name.into(),
            description: tool.description(),
            input_schema,
        });
    }
    request_tools
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
loaded_context: ContextLoadResult,
git_checkpoint: Option<GitStoreCheckpoint>,
creases: Vec<MessageCrease>,
cx: &mut Context<Self>,
) -> MessageId {
if !loaded_context.referenced_buffers.is_empty() {
self.action_log.update(cx, |log, cx| {
for buffer in loaded_context.referenced_buffers {
log.buffer_read(buffer, cx);
}
});
}
let message_id = self.insert_message(
Role::User,
vec![MessageSegment::Text(text.into())],
loaded_context.loaded_context,
creases,
false,
cx,
);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
message_id,
git_checkpoint,
});
}
message_id
}
/// Appends a hidden user message asking the model to continue, clears the pending
/// checkpoint, and returns the id of the new message.
pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
    let segments = vec![MessageSegment::Text("Continue where you left off".into())];
    let is_hidden = true;
    let id = self.insert_message(
        Role::User,
        segments,
        LoadedContext::default(),
        Vec::new(),
        is_hidden,
        cx,
    );
    self.pending_checkpoint = None;
    id
}
/// Appends a visible assistant message made of `segments`, with no context and no
/// creases, and returns its id.
pub fn insert_assistant_message(
    &mut self,
    segments: Vec<MessageSegment>,
    cx: &mut Context<Self>,
) -> MessageId {
    let loaded_context = LoadedContext::default();
    let creases = Vec::new();
    let is_hidden = false;
    self.insert_message(
        Role::Assistant,
        segments,
        loaded_context,
        creases,
        is_hidden,
        cx,
    )
}
/// Appends a message with the next free id and returns that id.
///
/// The message is never UI-only. The thread is marked as updated and
/// `ThreadEvent::MessageAdded` is emitted.
pub fn insert_message(
    &mut self,
    role: Role,
    segments: Vec<MessageSegment>,
    loaded_context: LoadedContext,
    creases: Vec<MessageCrease>,
    is_hidden: bool,
    cx: &mut Context<Self>,
) -> MessageId {
    let id = self.next_message_id.post_inc();
    let message = Message {
        id,
        role,
        segments,
        loaded_context,
        creases,
        is_hidden,
        ui_only: false,
    };
    self.messages.push(message);
    self.touch_updated_at();
    cx.emit(ThreadEvent::MessageAdded(id));
    id
}
/// Replaces the role, segments and creases of the message with the given id.
///
/// The loaded context is replaced only when `loaded_context` is `Some`, and a
/// checkpoint is stored for the message only when `checkpoint` is `Some`. Returns
/// `false` without changing anything when no message has that id; otherwise marks the
/// thread as updated, emits `ThreadEvent::MessageEdited` and returns `true`.
pub fn edit_message(
    &mut self,
    id: MessageId,
    new_role: Role,
    new_segments: Vec<MessageSegment>,
    creases: Vec<MessageCrease>,
    loaded_context: Option<LoadedContext>,
    checkpoint: Option<GitStoreCheckpoint>,
    cx: &mut Context<Self>,
) -> bool {
    let message = match self.messages.iter_mut().find(|message| message.id == id) {
        Some(message) => message,
        None => return false,
    };
    message.role = new_role;
    message.segments = new_segments;
    message.creases = creases;
    if let Some(context) = loaded_context {
        message.loaded_context = context;
    }

    if let Some(git_checkpoint) = checkpoint {
        let replacement = ThreadCheckpoint {
            message_id: id,
            git_checkpoint,
        };
        self.checkpoints_by_message.insert(id, replacement);
    }

    self.touch_updated_at();
    cx.emit(ThreadEvent::MessageEdited(id));
    true
}
/// Removes the message with the given id.
///
/// Returns `false` when no message has that id; otherwise marks the thread as updated,
/// emits `ThreadEvent::MessageDeleted` and returns `true`. A checkpoint stored for the
/// message is left in place.
pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
    match self.messages.iter().position(|message| message.id == id) {
        Some(index) => {
            self.messages.remove(index);
            self.touch_updated_at();
            cx.emit(ThreadEvent::MessageDeleted(id));
            true
        }
        None => false,
    }
}
/// Returns the representation of this [`Thread`] in a textual form.
///
/// This is the representation we use when attaching a thread as context to another thread.
///
/// Each message contributes a speaker line (`User:`, `Agent:` or `System:`), then its
/// segments back to back, then a newline. Thinking is wrapped in `<think>` tags and
/// redacted thinking is left out.
pub fn text(&self) -> String {
    let mut transcript = String::new();
    for message in self.messages.iter() {
        let speaker = match message.role {
            language_model::Role::User => "User:",
            language_model::Role::Assistant => "Agent:",
            language_model::Role::System => "System:",
        };
        transcript.push_str(speaker);
        transcript.push('\n');

        for segment in message.segments.iter() {
            match segment {
                MessageSegment::Text(content) => transcript.push_str(content),
                MessageSegment::Thinking { text: content, .. } => {
                    transcript.push_str("<think>");
                    transcript.push_str(content);
                    transcript.push_str("</think>");
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }
        transcript.push('\n');
    }
    transcript
}
/// Serializes this thread into a format for storage or telemetry.
///
/// The returned task first waits for the initial project snapshot and then reads the
/// thread's state. UI-only messages are left out. The summary falls back to the default
/// text when none is ready.
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
    let snapshot_task = self.initial_project_snapshot.clone();
    cx.spawn(async move |this, cx| {
        let initial_project_snapshot = snapshot_task.await;
        this.read_with(cx, |this, cx| {
            let messages = this
                .messages
                .iter()
                .filter(|message| !message.ui_only)
                .map(|message| {
                    let segments = message
                        .segments
                        .iter()
                        .map(|segment| match segment {
                            MessageSegment::Text(text) => {
                                SerializedMessageSegment::Text { text: text.clone() }
                            }
                            MessageSegment::Thinking { text, signature } => {
                                SerializedMessageSegment::Thinking {
                                    text: text.clone(),
                                    signature: signature.clone(),
                                }
                            }
                            MessageSegment::RedactedThinking(data) => {
                                SerializedMessageSegment::RedactedThinking {
                                    data: data.clone(),
                                }
                            }
                        })
                        .collect();
                    let tool_uses = this
                        .tool_uses_for_message(message.id, cx)
                        .into_iter()
                        .map(|tool_use| SerializedToolUse {
                            id: tool_use.id,
                            name: tool_use.name,
                            input: tool_use.input,
                        })
                        .collect();
                    let tool_results = this
                        .tool_results_for_message(message.id)
                        .into_iter()
                        .map(|tool_result| SerializedToolResult {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            is_error: tool_result.is_error,
                            content: tool_result.content.clone(),
                            output: tool_result.output.clone(),
                        })
                        .collect();
                    let creases = message
                        .creases
                        .iter()
                        .map(|crease| SerializedCrease {
                            start: crease.range.start,
                            end: crease.range.end,
                            icon_path: crease.icon_path.clone(),
                            label: crease.label.clone(),
                        })
                        .collect();
                    SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments,
                        tool_uses,
                        tool_results,
                        context: message.loaded_context.text.clone(),
                        creases,
                        is_hidden: message.is_hidden,
                    }
                })
                .collect();

            let model = this
                .configured_model
                .as_ref()
                .map(|configured| SerializedLanguageModel {
                    provider: configured.provider.id().0.to_string(),
                    model: configured.model.id().0.to_string(),
                });

            SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages,
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model,
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            }
        })
    })
}
/// Returns how many more times `send_to_model` will send a request.
pub fn remaining_turns(&self) -> u32 {
self.remaining_turns
}
/// Sets how many more times `send_to_model` will send a request.
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
self.remaining_turns = remaining_turns;
}
/// Sends the thread to `model` and streams the completion.
///
/// Does nothing when no turns remain; otherwise uses up one turn. Before the request is
/// built, notifications are flushed and the pending checkpoint is finalized.
pub fn send_to_model(
    &mut self,
    model: Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    let Some(turns_left) = self.remaining_turns.checked_sub(1) else {
        return;
    };
    self.remaining_turns = turns_left;

    self.flush_notifications(model.clone(), intent, cx);
    self.finalize_pending_checkpoint(cx);

    let request = self.to_completion_request(model.clone(), intent, cx);
    self.stream_completion(request, model, intent, window, cx);
}
/// Retries the last completion.
///
/// Clears the retry state. The model and intent are taken from the last error context
/// when one is stored (and it is consumed). Otherwise the thread's model is used,
/// initializing it from the registry if none is set, with the intent `ToolResults` when
/// tool uses are pending and `UserPrompt` otherwise. Does nothing further when no model
/// can be determined.
pub fn retry_last_completion(
    &mut self,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    // Clear any existing error state
    self.retry_state = None;

    let (model, intent) = match self.last_error_context.take() {
        // Prefer the context of the last error when it is available.
        Some(error_context) => error_context,
        None => {
            let model = if let Some(configured_model) = self.configured_model.as_ref() {
                configured_model.model.clone()
            } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
                configured_model.model
            } else {
                return;
            };
            let intent = if self.has_pending_tool_uses() {
                CompletionIntent::ToolResults
            } else {
                CompletionIntent::UserPrompt
            };
            (model, intent)
        }
    };

    self.send_to_model(model, intent, window, cx);
}
/// Switches the thread to Burn Mode and immediately retries the last completion.
///
/// Emits `ThreadEvent::ProfileChanged` after changing the completion mode and
/// before the retry is started.
pub fn enable_burn_mode_and_retry(
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
self.completion_mode = CompletionMode::Burn;
cx.emit(ThreadEvent::ProfileChanged);
self.retry_last_completion(window, cx);
}
/// Reports whether any tool results were produced since the last user message.
///
/// Messages are scanned from newest to oldest. The scan stops at the first
/// message that has tool results (`true`) or, failing that, at the first user
/// message (`false`). An empty or inconclusive history yields `false`.
pub fn used_tools_since_last_user_message(&self) -> bool {
    self.messages
        .iter()
        .rev()
        .find_map(|message| {
            // Tool results are checked first, so a user message that itself
            // carries tool results still counts as tool usage.
            if self.tool_use.message_has_tool_results(message.id) {
                Some(true)
            } else if message.role == Role::User {
                Some(false)
            } else {
                None
            }
        })
        .unwrap_or(false)
}
/// Builds the completion request for `model` from the thread's current state.
///
/// The request consists of:
/// - a system prompt generated from the project context (an error is logged and
///   a `ThreadEvent::ShowError` is emitted if it cannot be produced; the request
///   is still built without it),
/// - every non-UI-only message, with its loaded context and segments,
/// - for each message, the finished tool uses plus a follow-up user message
///   carrying the matching tool results,
/// - the tools available for `model`, and the completion mode.
///
/// At most one conversation message is flagged with `cache = true`: the last
/// one whose tool uses have all finished.
pub fn to_completion_request(
&self,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
cx: &mut Context<Self>,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.last_prompt_id.to_string()),
intent: Some(intent),
mode: None,
messages: vec![],
tools: Vec::new(),
tool_choice: None,
stop: Vec::new(),
temperature: AgentSettings::temperature_for_model(&model, cx),
thinking_allowed: true,
};
// The system prompt is told which tools exist, so collect their names first.
let available_tools = self.available_tools(cx, model.clone());
let available_tool_names = available_tools
.iter()
.map(|tool| tool.name.clone())
.collect();
let model_context = &ModelContext {
available_tools: available_tool_names,
};
if let Some(project_context) = self.project_context.borrow().as_ref() {
match self
.prompt_builder
.generate_assistant_system_prompt(project_context, model_context)
{
Err(err) => {
// Failure is reported to the user, but the request is still built
// (without a system prompt).
let message = format!("{err:?}").into();
log::error!("{message}");
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Error generating system prompt".into(),
message,
}));
}
Ok(system_prompt) => {
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
});
}
}
} else {
let message = "Context for system prompt unexpectedly not ready.".into();
log::error!("{message}");
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Error generating system prompt".into(),
message,
}));
}
// Index into `request.messages` of the last message that may be marked as cached.
let mut message_ix_to_cache = None;
for message in &self.messages {
// ui_only messages are for the UI only, not for the model
if message.ui_only {
continue;
}
let mut request_message = LanguageModelRequestMessage {
role: message.role,
content: Vec::new(),
cache: false,
};
message
.loaded_context
.add_to_request_message(&mut request_message);
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => {
// Trailing whitespace is stripped and empty text segments are omitted.
let text = text.trim_end();
if !text.is_empty() {
request_message
.content
.push(MessageContent::Text(text.into()));
}
}
MessageSegment::Thinking { text, signature } => {
if !text.is_empty() {
request_message.content.push(MessageContent::Thinking {
text: text.into(),
signature: signature.clone(),
});
}
}
MessageSegment::RedactedThinking(data) => {
request_message
.content
.push(MessageContent::RedactedThinking(data.clone()));
}
};
}
// Cleared below as soon as one of this message's tool uses has no result yet.
let mut cache_message = true;
// Tool results are sent in a separate user-role message that follows the
// message containing the tool uses.
let mut tool_results_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
if let Some(tool_result) = tool_result {
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
tool_results_message
.content
.push(MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_use.id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
output: None,
}));
} else {
// A pending tool use is left out of the request entirely.
cache_message = false;
log::debug!(
"skipped tool use {:?} because it is still pending",
tool_use
);
}
}
// `request.messages.len()` is the index the next pushed message will get.
if cache_message {
message_ix_to_cache = Some(request.messages.len());
}
request.messages.push(request_message);
if !tool_results_message.content.is_empty() {
// When tool results follow, they become the cache candidate instead.
if cache_message {
message_ix_to_cache = Some(request.messages.len());
}
request.messages.push(tool_results_message);
}
}
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
if let Some(message_ix_to_cache) = message_ix_to_cache {
request.messages[message_ix_to_cache].cache = true;
}
request.tools = available_tools;
// Models without Burn Mode support always get the normal completion mode.
request.mode = if model.supports_burn_mode() {
Some(self.completion_mode.into())
} else {
Some(CompletionMode::Normal.into())
};
request
}
/// Builds the request used to summarize this thread.
///
/// Only the text segments of each message are forwarded; thinking and redacted
/// thinking segments are dropped, and messages left without any content are
/// skipped. UI-only messages (such as the transient "Retrying…" notices) are
/// excluded, matching `to_completion_request`. `added_user_message` is appended
/// as the final user message and carries the summarization instructions.
///
/// The request has no thread or prompt id, no tools, and thinking disabled.
fn to_summarize_request(
    &self,
    model: &Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    added_user_message: String,
    cx: &App,
) -> LanguageModelRequest {
    let mut request = LanguageModelRequest {
        thread_id: None,
        prompt_id: None,
        intent: Some(intent),
        mode: None,
        messages: vec![],
        tools: Vec::new(),
        tool_choice: None,
        stop: Vec::new(),
        temperature: AgentSettings::temperature_for_model(model, cx),
        thinking_allowed: false,
    };

    for message in &self.messages {
        // ui_only messages are for the UI only, not for the model
        if message.ui_only {
            continue;
        }

        let mut request_message = LanguageModelRequestMessage {
            role: message.role,
            content: Vec::new(),
            cache: false,
        };

        for segment in &message.segments {
            match segment {
                MessageSegment::Text(text) => request_message
                    .content
                    .push(MessageContent::Text(text.clone())),
                // Thinking content is not part of the summary input.
                MessageSegment::Thinking { .. } => {}
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        // A message consisting solely of thinking segments contributes nothing.
        if request_message.content.is_empty() {
            continue;
        }

        request.messages.push(request_message);
    }

    request.messages.push(LanguageModelRequestMessage {
        role: Role::User,
        content: vec![MessageContent::Text(added_user_message)],
        cache: false,
    });

    request
}
/// Inserts auto-generated notifications (if any) into the thread.
///
/// Only `UserPrompt` and `ToolResults` requests receive notifications; every
/// other intent is a no-op. When a notification was attached, a
/// `ThreadEvent::ToolFinished` event is emitted for the simulated tool use.
fn flush_notifications(
    &mut self,
    model: Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    cx: &mut Context<Self>,
) {
    // Kept as an exhaustive match so that adding an intent forces a decision here.
    match intent {
        CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {}
        CompletionIntent::ThreadSummarization
        | CompletionIntent::ThreadContextSummarization
        | CompletionIntent::CreateFile
        | CompletionIntent::EditFile
        | CompletionIntent::InlineAssist
        | CompletionIntent::TerminalInlineAssist
        | CompletionIntent::GenerateGitCommitMessage => return,
    }

    let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) else {
        return;
    };
    cx.emit(ThreadEvent::ToolFinished {
        tool_use_id: pending_tool_use.id.clone(),
        pending_tool_use: Some(pending_tool_use),
    });
}
/// Attaches a simulated `project_notifications` tool call to the latest
/// assistant message when the action log reports unnotified user edits.
///
/// Returns `None` when the tool is unknown or disabled for the current profile,
/// when there are no unnotified edits, or when the thread has no assistant
/// message yet. Otherwise returns whatever `insert_tool_output` yields for the
/// simulated tool use.
fn attach_tracked_files_state(
    &mut self,
    model: Arc<dyn LanguageModel>,
    cx: &mut App,
) -> Option<PendingToolUse> {
    // Represent notification as a simulated `project_notifications` tool call
    let tool_name = Arc::from("project_notifications");
    let tool = self.tools.read(cx).tool(&tool_name, cx)?;
    if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
        return None;
    }

    let nothing_to_report = self
        .action_log
        .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none());
    if nothing_to_report {
        return None;
    }

    let tool_result = tool.run(
        serde_json::json!({}),
        Arc::new(LanguageModelRequest::default()), // unused
        self.project.clone(),
        self.action_log.clone(),
        model.clone(),
        None,
        cx,
    );

    // The id is derived from the current message count.
    let tool_use_id =
        LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
    let tool_use = LanguageModelToolUse {
        id: tool_use_id.clone(),
        name: tool_name.clone(),
        raw_input: "{}".to_string(),
        input: serde_json::json!({}),
        is_input_complete: true,
    };

    // The tool output is awaited synchronously on the current thread.
    let tool_output = cx.background_executor().block(tool_result.output);

    // Attach a project_notification tool call to the latest existing
    // Assistant message. We cannot create a new Assistant message
    // because thinking models require a `thinking` block that we
    // cannot mock. We cannot send a notification as a normal
    // (non-tool-use) User message because this distracts Agent
    // too much.
    //
    // NOTE(review): the tool has already run by the time this lookup can bail
    // out; confirm that discarding its output in that case is intended.
    let tool_message_id = self
        .messages
        .iter()
        .rev()
        .find(|message| message.role == Role::Assistant)
        .map(|message| message.id)?;

    let tool_use_metadata = ToolUseMetadata {
        model: model.clone(),
        thread_id: self.id.clone(),
        prompt_id: self.last_prompt_id.clone(),
    };
    self.tool_use
        .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);
    self.tool_use.insert_tool_output(
        tool_use_id,
        tool_name,
        tool_output,
        self.configured_model.as_ref(),
        self.completion_mode,
    )
}
/// Streams `request` to `model` on a spawned task and applies every completion
/// event to the thread as it arrives.
///
/// The task is stored as a `PendingCompletion` (initially `QueueState::Sending`).
/// While streaming, assistant messages are created or extended, token usage is
/// tracked, and tool uses are registered. Once the stream ends — successfully
/// or not — the task:
/// - runs pending tools, clears the agent location, or removes a refused turn,
///   depending on the stop reason,
/// - surfaces errors and possibly schedules a retry,
/// - emits `ThreadEvent::Stopped` unless a retry was scheduled,
/// - invokes the request callback (if installed) and reports telemetry.
pub fn stream_completion(
&mut self,
request: LanguageModelRequest,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
// Every new request starts with the tool-use limit flag cleared.
self.tool_use_limit_reached = false;
let pending_completion_id = post_inc(&mut self.completion_count);
// The request and its response events are only recorded when a callback wants them.
let mut request_callback_parameters = if self.request_callback.is_some() {
Some((request.clone(), Vec::new()))
} else {
None
};
let prompt_id = self.last_prompt_id.clone();
let tool_use_metadata = ToolUseMetadata {
model: model.clone(),
thread_id: self.id.clone(),
prompt_id: prompt_id.clone(),
};
let completion_mode = request
.mode
.unwrap_or(cloud_llm_client::CompletionMode::Normal);
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion(request, cx);
// Snapshot of the cumulative usage, used for the telemetry delta at the end.
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let mut events = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
thread
.update(cx, |_thread, cx| {
cx.emit(ThreadEvent::NewRequest);
})
.ok();
// Id of the assistant message created for this request, once there is one.
let mut request_assistant_message_id = None;
while let Some(event) = events.next().await {
if let Some((_, response_events)) = request_callback_parameters.as_mut() {
response_events
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
thread.update(cx, |thread, cx| {
match event? {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(String::new())],
cx,
));
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
thread.update_token_usage_at_last_message(token_usage);
// Replace this request's previous contribution to the
// cumulative total with the newly reported value.
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
thread.received_chunk();
cx.emit(ThreadEvent::ReceivedTextChunk);
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
{
last_message.push_text(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
));
} else {
// If we don't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(chunk.to_string())],
cx,
));
};
}
}
LanguageModelCompletionEvent::Thinking {
text: chunk,
signature,
} => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
{
last_message.push_thinking(&chunk, signature);
cx.emit(ThreadEvent::StreamedAssistantThinking(
last_message.id,
chunk,
));
} else {
// If we don't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx,
));
};
}
}
LanguageModelCompletionEvent::RedactedThinking { data } => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
{
last_message.push_redacted_thinking(data);
} else {
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::RedactedThinking(data)],
cx,
));
};
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
// A tool use needs an assistant message to belong to; create an
// empty one if this request has not produced a message yet.
let last_assistant_message_id = request_assistant_message_id
.unwrap_or_else(|| {
let new_assistant_message_id =
thread.insert_assistant_message(vec![], cx);
request_assistant_message_id =
Some(new_assistant_message_id);
new_assistant_message_id
});
let tool_use_id = tool_use.id.clone();
// Partial input is only forwarded while the tool input is still streaming.
let streamed_input = if tool_use.is_input_complete {
None
} else {
Some(tool_use.input.clone())
};
let ui_text = thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
tool_use_metadata.clone(),
cx,
);
if let Some(input) = streamed_input {
cx.emit(ThreadEvent::StreamedToolUse {
tool_use_id,
ui_text,
input,
});
}
}
LanguageModelCompletionEvent::ToolUseJsonParseError {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
} => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
}
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
// Status updates are ignored when this completion is no longer
// in the pending list.
if let Some(completion) = thread
.pending_completions
.iter_mut()
.find(|completion| completion.id == pending_completion_id)
{
match status_update {
CompletionRequestStatus::Queued { position } => {
completion.queue_state =
QueueState::Queued { position };
}
CompletionRequestStatus::Started => {
completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
code,
message,
request_id: _,
retry_after,
} => {
return Err(
LanguageModelCompletionError::from_cloud_failure(
model.upstream_provider_name(),
code,
message,
retry_after.map(Duration::from_secs_f64),
),
);
}
CompletionRequestStatus::UsageUpdated { amount, limit } => {
thread.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
cx.emit(ThreadEvent::ToolUseLimitReached);
}
}
}
}
}
thread.touch_updated_at();
cx.emit(ThreadEvent::StreamedCompletion);
cx.notify();
Ok(())
// The outer `?` propagates a failure of `thread.update` itself, the inner
// one an error returned by the event handler above.
})??;
smol::future::yield_now().await;
}
thread.update(cx, |thread, cx| {
thread.last_received_chunk_at = None;
thread
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
// If there is a response without tool use, summarize the message. Otherwise,
// allow two tool uses before summarizing.
if matches!(thread.summary, ThreadSummary::Pending)
&& thread.messages.len() >= 2
&& (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
{
thread.summarize(cx);
}
})?;
anyhow::Ok(stop_reason)
};
let result = stream_completion.await;
// Set when an automatic retry was scheduled; suppresses cancellation and
// the `Stopped` event below.
let mut retry_scheduled = false;
thread
.update(cx, |thread, cx| {
thread.finalize_pending_checkpoint(cx);
match result.as_ref() {
Ok(stop_reason) => {
match stop_reason {
StopReason::ToolUse => {
let tool_uses =
thread.use_pending_tools(window, model.clone(), cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn | StopReason::MaxTokens => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
}
StopReason::Refusal => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
// Remove the turn that was refused.
//
// https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
{
// Walk backwards collecting messages until reaching a user
// message that is first in the thread or directly preceded
// by an assistant message.
let mut messages_to_remove = Vec::new();
for (ix, message) in
thread.messages.iter().enumerate().rev()
{
messages_to_remove.push(message.id);
if message.role == Role::User {
if ix == 0 {
break;
}
if let Some(prev_message) =
thread.messages.get(ix - 1)
&& prev_message.role == Role::Assistant {
break;
}
}
}
for message_id in messages_to_remove {
thread.delete_message(message_id, cx);
}
}
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Language model refusal".into(),
message:
"Model refused to generate content for safety reasons."
.into(),
}));
}
}
// We successfully completed, so cancel any remaining retries.
thread.retry_state = None;
}
Err(error) => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if let Some(error) =
error.downcast_ref::<ModelRequestLimitReachedError>()
{
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
} else if let Some(completion_error) =
error.downcast_ref::<LanguageModelCompletionError>()
{
match &completion_error {
LanguageModelCompletionError::PromptTooLarge {
tokens, ..
} => {
let tokens = tokens.unwrap_or_else(|| {
// We didn't get an exact token count from the API, so fall back on our estimate.
thread
.total_token_usage()
.map(|usage| usage.total)
.unwrap_or(0)
// We know the context window was exceeded in practice, so if our estimate was
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
.max(
model
.max_token_count_for_mode(completion_mode)
.saturating_add(1),
)
});
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: tokens,
});
cx.notify();
}
_ => {
if let Some(retry_strategy) =
Thread::get_retry_strategy(completion_error)
{
log::info!(
"Retrying with {:?} for language model completion error {:?}",
retry_strategy,
completion_error
);
retry_scheduled = thread
.handle_retryable_error_with_delay(
completion_error,
Some(retry_strategy),
model.clone(),
intent,
window,
cx,
);
}
}
}
}
if !retry_scheduled {
thread.cancel_last_completion(window, cx);
}
}
}
if !retry_scheduled {
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
}
if let Some((request_callback, (request, response_events))) = thread
.request_callback
.as_mut()
.zip(request_callback_parameters.as_ref())
{
request_callback(request, response_events);
}
// Telemetry reports only the usage accumulated since this request started.
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage - initial_usage;
telemetry::event!(
"Assistant Thread Completion",
thread_id = thread.id().to_string(),
prompt_id = prompt_id,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
input_tokens = usage.input_tokens,
output_tokens = usage.output_tokens,
cache_creation_input_tokens = usage.cache_creation_input_tokens,
cache_read_input_tokens = usage.cache_read_input_tokens,
);
}
})
.ok();
});
// The task is owned by the pending completion entry registered here.
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
queue_state: QueueState::Sending,
_task: task,
});
}
/// Starts generating a short, single-line summary of the thread.
///
/// Does nothing when no thread summary model is configured or its provider is
/// not authenticated. Otherwise the summary state becomes
/// `ThreadSummary::Generating` and a task is spawned that streams the model's
/// answer. The task finishes with `ThreadSummary::Ready` (non-empty text) or
/// `ThreadSummary::Error` (empty text or request failure) and then emits
/// `ThreadEvent::SummaryGenerated`.
pub fn summarize(&mut self, cx: &mut Context<Self>) {
    let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
        // Reported through the logger rather than stdout, like the rest of this file.
        log::debug!("No thread summary model");
        return;
    };

    if !model.provider.is_authenticated(cx) {
        return;
    }

    let request = self.to_summarize_request(
        &model.model,
        CompletionIntent::ThreadSummarization,
        SUMMARIZE_THREAD_PROMPT.into(),
        cx,
    );

    self.summary = ThreadSummary::Generating;

    self.pending_summary = cx.spawn(async move |this, cx| {
        let result = async {
            let mut messages = model.model.stream_completion(request, cx).await?;

            let mut new_summary = String::new();
            while let Some(event) = messages.next().await {
                // Individual stream errors are skipped rather than aborting the summary.
                let Ok(event) = event else {
                    continue;
                };
                let text = match event {
                    LanguageModelCompletionEvent::Text(text) => text,
                    LanguageModelCompletionEvent::StatusUpdate(
                        CompletionRequestStatus::UsageUpdated { amount, limit },
                    ) => {
                        this.update(cx, |thread, cx| {
                            thread.update_model_request_usage(amount as u32, limit, cx);
                        })?;
                        continue;
                    }
                    _ => continue,
                };

                // Only the first line of each chunk is kept.
                let mut lines = text.lines();
                new_summary.extend(lines.next());

                // Stop if the LLM generated multiple lines.
                if lines.next().is_some() {
                    break;
                }
            }

            anyhow::Ok(new_summary)
        }
        .await;

        this.update(cx, |this, cx| {
            match result {
                Ok(new_summary) => {
                    if new_summary.is_empty() {
                        this.summary = ThreadSummary::Error;
                    } else {
                        this.summary = ThreadSummary::Ready(new_summary.into());
                    }
                }
                Err(err) => {
                    this.summary = ThreadSummary::Error;
                    log::error!("Failed to generate thread summary: {}", err);
                }
            }
            cx.emit(ThreadEvent::SummaryGenerated);
        })
        .log_err()?;

        Some(())
    });
}
/// Chooses how (and whether) a failed completion should be retried.
///
/// Returns `None` when retrying cannot help, otherwise the `RetryStrategy`
/// (delay scheme and maximum number of attempts) to use for `error`.
/// Arm order matters: specific `HttpResponseError` status codes and the guarded
/// arms must come before the catch-all arm at the end.
fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
use LanguageModelCompletionError::*;
// General strategy here:
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to `MAX_RETRY_ATTEMPTS` times,
//   honoring the server-provided `retry_after` delay when there is one.
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
match error {
// A plain HTTP 429 carries no `retry_after` here, so back off exponentially.
HttpResponseError {
status_code: StatusCode::TOO_MANY_REQUESTS,
..
} => Some(RetryStrategy::ExponentialBackoff {
initial_delay: BASE_RETRY_DELAY,
max_attempts: MAX_RETRY_ATTEMPTS,
}),
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
}
UpstreamProviderError {
status,
retry_after,
..
} => match *status {
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
}
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
// Internal Server Error could be anything, retry up to 3 times.
max_attempts: 3,
}),
status => {
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
// but we frequently get them in practice. See https://http.dev/529
if status.as_u16() == 529 {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
} else {
// Any other upstream status gets up to 2 attempts.
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: 2,
})
}
}
},
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
}),
ApiReadResponseError { .. }
| HttpSend { .. }
| DeserializeResponse { .. }
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
}),
// Retrying these errors definitely shouldn't help.
HttpResponseError {
status_code:
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
..
}
| AuthenticationError { .. }
| PermissionError { .. }
| NoApiKey { .. }
| ApiEndpointNotFound { .. }
| PromptTooLarge { .. } => None,
// These errors might be transient, so retry them once.
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 1,
}),
// Retry all other 4xx and 5xx errors up to 3 times.
HttpResponseError { status_code, .. }
if status_code.is_client_error() || status_code.is_server_error() =>
{
Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
})
}
Other(err)
if err.is::<PaymentRequiredError>()
|| err.is::<ModelRequestLimitReachedError>() =>
{
// Retrying won't help for Payment Required or Model Request Limit errors (where
// the user must upgrade to usage-based billing to get more requests, or else wait
// for a significant amount of time for the request limit to reset).
None
}
// Any remaining error is of unknown nature; be conservative and allow at most 2 attempts.
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 2,
}),
}
}
fn handle_retryable_error_with_delay(
&mut self,
error: &LanguageModelCompletionError,
strategy: Option<RetryStrategy>,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
// Store context for the Retry button
self.last_error_context = Some((model.clone(), intent));
// Only auto-retry if Burn Mode is enabled
if self.completion_mode != CompletionMode::Burn {
// Show error with retry options
cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
message: format!(
"{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
error
)
.into(),
can_enable_burn_mode: true,
}));
return false;
}
let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
return false;
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let retry_state = self.retry_state.get_or_insert(RetryState {
attempt: 0,
max_attempts,
intent,
});
retry_state.attempt += 1;
let attempt = retry_state.attempt;
let max_attempts = retry_state.max_attempts;
let intent = retry_state.intent;
if attempt <= max_attempts {
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
// Add a transient message to inform the user
let delay_secs = delay.as_secs();
let retry_message = if max_attempts == 1 {
format!("{error}. Retrying in {delay_secs} seconds...")
} else {
format!(
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds..."
)
};
log::warn!(
"Retrying completion request (attempt {attempt} of {max_attempts}) \
in {delay_secs} seconds: {error:?}",
);
// Add a UI-only message instead of a regular message
let id = self.next_message_id.post_inc();
self.messages.push(Message {
id,
role: Role::System,
segments: vec![MessageSegment::Text(retry_message)],
loaded_context: LoadedContext::default(),
creases: Vec::new(),
is_hidden: false,
ui_only: true,
});
cx.emit(ThreadEvent::MessageAdded(id));
// Schedule the retry
let thread_handle = cx.entity().downgrade();
cx.spawn(async move |_thread, cx| {
cx.background_executor().timer(delay).await;
thread_handle
.update(cx, |thread, cx| {
// Retry the completion
thread.send_to_model(model, intent, window, cx);
})
.log_err();
})
.detach();
true
} else {
// Max retries exceeded
self.retry_state = None;
// Stop generating since we're giving up on retrying.
self.pending_completions.clear();
// Show error alongside a Retry button, but no
// Enable Burn Mode button (since it's already enabled)
cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
message: format!("Failed after retrying: {}", error).into(),
can_enable_burn_mode: false,
}));
false
}
}
/// Starts generating a detailed summary unless one already exists (or is being
/// generated) for the thread's latest message.
///
/// Nothing happens when the thread is empty, when no thread summary model is
/// configured, or when its provider is not authenticated. Starting a new
/// generation replaces any previously stored detailed-summary task. When the
/// summary is complete, the thread is saved through `thread_store`.
pub fn start_generating_detailed_summary_if_needed(
    &mut self,
    thread_store: WeakEntity<ThreadStore>,
    cx: &mut Context<Self>,
) {
    let last_message_id = match self.messages.last() {
        Some(message) => message.id,
        None => return,
    };

    // Already up-to-date if the current state refers to the latest message.
    let already_current = matches!(
        &*self.detailed_summary_rx.borrow(),
        DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
            if *message_id == last_message_id
    );
    if already_current {
        return;
    }

    let registry = LanguageModelRegistry::read_global(cx);
    let Some(ConfiguredModel { model, provider }) = registry.thread_summary_model(cx) else {
        return;
    };
    if !provider.is_authenticated(cx) {
        return;
    }

    let request = self.to_summarize_request(
        &model,
        CompletionIntent::ThreadContextSummarization,
        SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
        cx,
    );

    *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
        message_id: last_message_id,
    };

    // Replace the detailed summarization task if there is one, cancelling it. It would probably
    // be better to allow the old task to complete, but this would require logic for choosing
    // which result to prefer (the old task could complete after the new one, resulting in a
    // stale summary).
    self.detailed_summary_task = cx.spawn(async move |thread, cx| {
        let stream = model.stream_completion_text(request, cx);
        let mut messages = match stream.await.log_err() {
            Some(messages) => messages,
            None => {
                // The request could not be started: fall back to "not generated".
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            }
        };

        // Failed chunks are logged and skipped; the rest is concatenated.
        let mut new_detailed_summary = String::new();
        while let Some(chunk) = messages.stream.next().await {
            let Some(chunk) = chunk.log_err() else {
                continue;
            };
            new_detailed_summary.push_str(&chunk);
        }

        thread
            .update(cx, |thread, _cx| {
                *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                    text: new_detailed_summary.into(),
                    message_id: last_message_id,
                };
            })
            .ok()?;

        // Save thread so its summary can be reused later
        if let Some(thread) = thread.upgrade() {
            let save_result = cx.update(|cx| {
                thread_store
                    .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            });
            if let Ok(Ok(save_task)) = save_result {
                save_task.await.log_err();
            }
        }

        Some(())
    });
}
/// Waits until the detailed summary settles and returns it, or the thread's
/// plain text when no summary was generated.
///
/// Returns `None` when the thread cannot be read or when the summary state
/// receiver stops yielding values before a final state is seen.
pub async fn wait_for_detailed_summary_or_text(
    this: &Entity<Self>,
    cx: &mut AsyncApp,
) -> Option<SharedString> {
    let mut summary_states = this
        .read_with(cx, |thread, _cx| thread.detailed_summary_rx.clone())
        .ok()?;

    while let Some(state) = summary_states.recv().await {
        match state {
            DetailedSummaryState::Generated { text, .. } => return Some(text),
            DetailedSummaryState::NotGenerated => {
                return this.read_with(cx, |thread, _cx| thread.text().into()).ok();
            }
            // Still in progress: keep waiting for the next state.
            DetailedSummaryState::Generating { .. } => {}
        }
    }

    None
}
/// Returns the text of the current detailed summary state, or the thread's
/// plain text when that state has none.
pub fn latest_detailed_summary_or_text(&self) -> SharedString {
    let summary_text = self.detailed_summary_rx.borrow().text();
    match summary_text {
        Some(text) => text,
        None => self.text().into(),
    }
}
/// Reports whether a detailed summary is currently being generated.
pub fn is_generating_detailed_summary(&self) -> bool {
    let state = self.detailed_summary_rx.borrow();
    matches!(*state, DetailedSummaryState::Generating { .. })
}
/// Starts every pending tool use that is still idle and returns those tool uses.
///
/// All of them share one completion request, built with the `ToolResults`
/// intent before any tool is started.
pub fn use_pending_tools(
    &mut self,
    window: Option<AnyWindowHandle>,
    model: Arc<dyn LanguageModel>,
    cx: &mut Context<Self>,
) -> Vec<PendingToolUse> {
    let request =
        Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));

    // Cloned up front so that starting tools can mutate the tool-use state.
    let idle_tool_uses: Vec<PendingToolUse> = self
        .tool_use
        .pending_tool_uses()
        .into_iter()
        .filter(|tool_use| tool_use.status.is_idle())
        .cloned()
        .collect();

    for tool_use in &idle_tool_uses {
        self.use_pending_tool(
            tool_use.clone(),
            Arc::clone(&request),
            Arc::clone(&model),
            window,
            cx,
        );
    }

    idle_tool_uses
}
/// Starts a single pending tool use.
///
/// A tool that does not exist or is not enabled for the current profile is
/// handled as a hallucinated tool use. A tool that needs confirmation (and is
/// not covered by the "always allow tool actions" setting) is parked until the
/// user confirms; any other tool is run right away.
fn use_pending_tool(
    &mut self,
    tool_use: PendingToolUse,
    request: Arc<LanguageModelRequest>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    // Unknown and disabled tools take the same fallback path.
    let enabled_tool = self
        .tools
        .read(cx)
        .tool(&tool_use.name, cx)
        .filter(|tool| self.profile.is_tool_enabled(tool.source(), tool.name(), cx));
    let Some(tool) = enabled_tool else {
        return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
    };

    let requires_confirmation = tool.needs_confirmation(&tool_use.input, &self.project, cx)
        && !AgentSettings::get_global(cx).always_allow_tool_actions;

    if !requires_confirmation {
        self.run_tool(
            tool_use.id,
            tool_use.ui_text,
            tool_use.input,
            request,
            tool,
            model,
            window,
            cx,
        );
        return;
    }

    self.tool_use.confirm_tool_use(
        tool_use.id,
        tool_use.ui_text,
        tool_use.input,
        request,
        tool,
    );
    cx.emit(ThreadEvent::ToolConfirmationNeeded);
}
/// Handles a request for a tool that does not exist or is not enabled.
///
/// An error result listing the tools enabled for the current profile is
/// recorded for the tool use, `ThreadEvent::MissingToolUse` is emitted, and the
/// tool use is marked as finished.
pub fn handle_hallucinated_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    hallucinated_tool_name: Arc<str>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) {
    // One "- name: description" line per enabled tool.
    let tool_list = self
        .profile
        .enabled_tools(cx)
        .iter()
        .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
        .collect::<Vec<_>>()
        .join("\n");

    let error_message = format!(
        "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
        hallucinated_tool_name, tool_list
    );

    let tool_error = anyhow!("Missing tool call: {error_message}");
    let pending_tool_use = self.tool_use.insert_tool_output(
        tool_use_id.clone(),
        hallucinated_tool_name,
        Err(tool_error),
        self.configured_model.as_ref(),
        self.completion_mode,
    );

    cx.emit(ThreadEvent::MissingToolUse {
        tool_use_id: tool_use_id.clone(),
        ui_text: error_message.into(),
    });

    self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
pub fn receive_invalid_tool_json(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
invalid_json: Arc<str>,
error: String,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
log::error!("The model returned invalid input JSON: {invalid_json}");
let pending_tool_use = self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
self.configured_model.as_ref(),
self.completion_mode,
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
} else {
log::error!(
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
);
format!("Unknown tool {}", tool_use_id).into()
};
cx.emit(ThreadEvent::InvalidToolInput {
tool_use_id: tool_use_id.clone(),
ui_text,
invalid_input_json: invalid_json,
});
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
/// Starts running `tool` for the given tool use and registers the resulting task.
///
/// The task is created by `spawn_tool_use` and then handed, together with the
/// UI label, to the tool-use state via `run_pending_tool`.
pub fn run_tool(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    ui_text: impl Into<SharedString>,
    input: serde_json::Value,
    request: Arc<LanguageModelRequest>,
    tool: Arc<dyn Tool>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) {
    let spawned_id = tool_use_id.clone();
    let running_tool = self.spawn_tool_use(spawned_id, request, input, tool, model, window, cx);
    let label: SharedString = ui_text.into();
    self.tool_use
        .run_pending_tool(tool_use_id, label, running_tool);
}
/// Invokes `tool` and returns a task that records its output once it resolves.
///
/// If the tool result carries a card, the card is stored on the tool-use state
/// immediately. When the output future completes, the output is inserted for
/// `tool_use_id` and the tool use is finished via `tool_finished` (as not canceled).
/// If the thread can no longer be updated at that point, the output is dropped.
fn spawn_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    request: Arc<LanguageModelRequest>,
    input: serde_json::Value,
    tool: Arc<dyn Tool>,
    model: Arc<dyn LanguageModel>,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) -> Task<()> {
    // Read the name before `run` consumes the tool handle.
    let tool_name: Arc<str> = tool.name().into();
    let tool_result = tool.run(
        input,
        request,
        self.project.clone(),
        self.action_log.clone(),
        model,
        window,
        cx,
    );

    // Store the card separately if it exists
    if let Some(card) = tool_result.card.clone() {
        self.tool_use
            .insert_tool_result_card(tool_use_id.clone(), card);
    }

    let pending_output = tool_result.output;
    cx.spawn(async move |thread: WeakEntity<Thread>, cx| {
        let output = pending_output.await;
        thread
            .update(cx, |thread, cx| {
                let pending_tool_use = thread.tool_use.insert_tool_output(
                    tool_use_id.clone(),
                    tool_name,
                    output,
                    thread.configured_model.as_ref(),
                    thread.completion_mode,
                );
                thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
            })
            .ok();
    })
}
/// Finalizes one tool use and emits `ThreadEvent::ToolFinished`.
///
/// When all tools have finished, the tool use was not canceled, and a model is
/// configured, the thread is first sent back to that model with
/// `CompletionIntent::ToolResults`.
fn tool_finished(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    pending_tool_use: Option<PendingToolUse>,
    canceled: bool,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) {
    let resume_with = if self.all_tools_finished() && !canceled {
        self.configured_model
            .as_ref()
            .map(|configured| configured.model.clone())
    } else {
        None
    };
    if let Some(model) = resume_with {
        self.send_to_model(model, CompletionIntent::ToolResults, window, cx);
    }

    cx.emit(ThreadEvent::ToolFinished {
        tool_use_id,
        pending_tool_use,
    });
}
/// Cancels the last pending completion, if there are any pending.
///
/// Also clears any retry state and cancels every pending tool use, finishing each
/// one via `tool_finished` with `canceled = true`.
///
/// Returns whether anything (a completion, a retry, or a tool use) was canceled.
pub fn cancel_last_completion(
    &mut self,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
) -> bool {
    let dropped_completion = self.pending_completions.pop().is_some();
    let dropped_retry = self.retry_state.take().is_some();
    let mut canceled = dropped_completion || dropped_retry;

    for pending_tool_use in self.tool_use.cancel_pending() {
        canceled = true;
        let tool_use_id = pending_tool_use.id.clone();
        self.tool_finished(tool_use_id, Some(pending_tool_use), true, window, cx);
    }

    if !canceled {
        self.finalize_pending_checkpoint(cx);
        return false;
    }

    cx.emit(ThreadEvent::CompletionCanceled);
    // When canceled, we always want to insert the checkpoint.
    // (We skip over finalize_pending_checkpoint, because it
    // would conclude we didn't have anything to insert here.)
    if let Some(checkpoint) = self.pending_checkpoint.take() {
        self.insert_checkpoint(checkpoint, cx);
    }
    true
}
/// Signals that any in-progress editing should be canceled.
///
/// This only emits `ThreadEvent::CancelEditing`; it changes no thread state itself.
/// Listeners (like ActiveThread) are expected to cancel their editing operations
/// in response.
pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
cx.emit(ThreadEvent::CancelEditing);
}
/// Returns the feedback recorded for `message_id`, or `None` if none has been recorded.
pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
self.message_feedback.get(&message_id).copied()
}
/// Records `feedback` for a message and reports it through telemetry.
///
/// Returns immediately with `Ok(())` when the same feedback is already recorded for
/// the message. Otherwise the feedback is stored, observers are notified, and a
/// background task sends an "Assistant Thread Rated" event containing the serialized
/// thread, a project snapshot, the enabled tool names, and the message content, then
/// flushes telemetry.
///
/// # Errors
///
/// The returned task fails if serializing the thread fails.
pub fn report_message_feedback(
    &mut self,
    message_id: MessageId,
    feedback: ThreadFeedback,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    if self.message_feedback(message_id) == Some(feedback) {
        return Task::ready(Ok(()));
    }

    let snapshot_task = Self::project_snapshot(self.project.clone(), cx);
    let serialize_task = self.serialize(cx);
    let thread_id = self.id().clone();
    let client = self.project.read(cx).client();
    let enabled_tool_names: Vec<String> = self
        .profile
        .enabled_tools(cx)
        .iter()
        .map(|(name, _)| name.clone().into())
        .collect();

    self.message_feedback.insert(message_id, feedback);
    cx.notify();

    let message_content = self
        .message(message_id)
        .map(|msg| msg.to_message_content())
        .unwrap_or_default();

    cx.background_spawn(async move {
        // The local names below double as telemetry field names; keep them stable.
        let final_project_snapshot = snapshot_task.await;
        let serialized_thread = serialize_task.await?;
        // Best effort: a thread that cannot be converted to JSON is reported as null.
        let thread_data =
            serde_json::to_value(serialized_thread).unwrap_or(serde_json::Value::Null);
        let rating = match feedback {
            ThreadFeedback::Positive => "positive",
            ThreadFeedback::Negative => "negative",
        };
        telemetry::event!(
            "Assistant Thread Rated",
            rating,
            thread_id,
            enabled_tool_names,
            message_id = message_id.0,
            message_content,
            thread_data,
            final_project_snapshot
        );
        client.telemetry().flush_events().await;
        Ok(())
    })
}
/// Creates a snapshot of the current project state.
///
/// The snapshot holds one `WorktreeSnapshot` per visible worktree (path plus git
/// state), the paths of dirty buffers that are backed by a file, and a UTC timestamp.
/// If the app can no longer be accessed when buffers are inspected, the list of
/// unsaved buffer paths is left empty.
fn project_snapshot(
    project: Entity<Project>,
    cx: &mut Context<Self>,
) -> Task<Arc<ProjectSnapshot>> {
    let git_store = project.read(cx).git_store().clone();
    let worktree_tasks: Vec<_> = project
        .read(cx)
        .visible_worktrees(cx)
        .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
        .collect();

    cx.spawn(async move |_, cx| {
        let worktree_snapshots = futures::future::join_all(worktree_tasks).await;

        let unsaved_buffer_paths = cx
            .update(|app_cx| {
                let mut dirty_paths = Vec::new();
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if !buffer.is_dirty() {
                        continue;
                    }
                    if let Some(file) = buffer.file() {
                        dirty_paths.push(file.path().to_string_lossy().to_string());
                    }
                }
                dirty_paths
            })
            .unwrap_or_default();

        Arc::new(ProjectSnapshot {
            worktree_snapshots,
            unsaved_buffer_paths,
            timestamp: Utc::now(),
        })
    })
}
/// Captures the absolute path of `worktree` and, when available, the git state of the
/// repository that contains it.
///
/// The repository is the first one in `git_store` for which the worktree's absolute
/// path maps to a repository path. For a local repository the git state holds the
/// `origin` remote URL, the HEAD SHA, the current branch name, and a HEAD-to-worktree
/// diff; for any other repository state only the branch name is filled in.
///
/// Every step is best effort: if the app, the git store, or the repository can no
/// longer be accessed, or the git job does not complete, `git_state` is `None`. If
/// even the worktree path cannot be read, the returned path is empty.
fn worktree_snapshot(
    worktree: Entity<project::Worktree>,
    git_store: Entity<GitStore>,
    cx: &App,
) -> Task<WorktreeSnapshot> {
    cx.spawn(async move |cx| {
        // Only the path is needed here; no worktree snapshot is taken.
        let Ok(worktree_path) = cx.update(|app_cx| {
            worktree
                .read(app_cx)
                .abs_path()
                .to_string_lossy()
                .to_string()
        }) else {
            return WorktreeSnapshot {
                worktree_path: String::new(),
                git_state: None,
            };
        };

        let repository = git_store
            .update(cx, |git_store, cx| {
                // Read the worktree path once instead of once per repository.
                let worktree_abs_path = worktree.read(cx).abs_path();
                git_store
                    .repositories()
                    .values()
                    .find(|repo| {
                        repo.read(cx)
                            .abs_path_to_repo_path(&worktree_abs_path)
                            .is_some()
                    })
                    .cloned()
            })
            .ok()
            .flatten();

        let git_state_job = repository.map(|repo| {
            repo.update(cx, |repo, _| {
                let current_branch =
                    repo.branch.as_ref().map(|branch| branch.name().to_owned());
                repo.send_job(None, |state, _| async move {
                    let RepositoryState::Local { backend, .. } = state else {
                        return GitState {
                            remote_url: None,
                            head_sha: None,
                            current_branch,
                            diff: None,
                        };
                    };
                    let remote_url = backend.remote_url("origin");
                    let head_sha = backend.head_sha().await;
                    let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
                    GitState {
                        remote_url,
                        head_sha,
                        current_branch,
                        diff,
                    }
                })
            })
        });

        // No repository, a failed repository update, or an unfinished job all mean
        // "no git state".
        let git_state = match git_state_job {
            Some(Ok(job)) => job.await.ok(),
            _ => None,
        };

        WorktreeSnapshot {
            worktree_path,
            git_state,
        }
    })
}
/// Renders the thread as a Markdown document.
///
/// The output starts with the thread summary as a level-1 heading. Each message then
/// gets a level-2 heading with its role, followed by its loaded context text, a note
/// about attached images, its text and thinking segments (redacted thinking is
/// omitted), its tool uses with pretty-printed JSON input, and its tool results.
///
/// # Errors
///
/// Fails if writing to the buffer fails or if a tool input or tool output cannot be
/// serialized to JSON.
pub fn to_markdown(&self, cx: &App) -> Result<String> {
    let mut out: Vec<u8> = Vec::new();

    let summary = self.summary().or_default();
    writeln!(out, "# {summary}\n")?;

    for message in self.messages() {
        let role = match message.role {
            Role::User => "User",
            Role::Assistant => "Agent",
            Role::System => "System",
        };
        writeln!(out, "## {role}\n")?;

        let context = &message.loaded_context;
        if !context.text.is_empty() {
            writeln!(out, "{}", context.text)?;
        }
        let image_count = context.images.len();
        if image_count > 0 {
            writeln!(out, "\n{} images attached as context.\n", image_count)?;
        }

        for segment in &message.segments {
            match segment {
                MessageSegment::Text(text) => writeln!(out, "{}\n", text)?,
                MessageSegment::Thinking { text, .. } => {
                    writeln!(out, "<think>\n{}\n</think>\n", text)?
                }
                MessageSegment::RedactedThinking(_) => {}
            }
        }

        for tool_use in self.tool_uses_for_message(message.id, cx) {
            writeln!(
                out,
                "**Use Tool: {} ({})**",
                tool_use.name, tool_use.id
            )?;
            writeln!(out, "```json")?;
            let pretty_input = serde_json::to_string_pretty(&tool_use.input)?;
            writeln!(out, "{}", pretty_input)?;
            writeln!(out, "```")?;
        }

        for tool_result in self.tool_results_for_message(message.id) {
            write!(out, "\n**Tool Results: {}", tool_result.tool_use_id)?;
            if tool_result.is_error {
                write!(out, " (Error)")?;
            }
            writeln!(out, "**\n")?;

            match &tool_result.content {
                LanguageModelToolResultContent::Text(text) => {
                    writeln!(out, "{text}")?;
                }
                LanguageModelToolResultContent::Image(image) => {
                    writeln!(out, "![Image](data:base64,{})", image.source)?;
                }
            }

            if let Some(output) = &tool_result.output {
                writeln!(
                    out,
                    "\n\nDebug Output:\n\n```json\n{}\n```\n",
                    serde_json::to_string_pretty(output)?
                )?;
            }
        }
    }

    Ok(String::from_utf8_lossy(&out).into_owned())
}
/// Forwards to the action log's `keep_edits_in_range` for `buffer_range` of `buffer`.
pub fn keep_edits_in_range(
    &mut self,
    buffer: Entity<language::Buffer>,
    buffer_range: Range<language::Anchor>,
    cx: &mut Context<Self>,
) {
    let action_log = &self.action_log;
    action_log.update(cx, |log, cx| {
        log.keep_edits_in_range(buffer, buffer_range, cx)
    });
}
/// Forwards to the action log's `keep_all_edits`.
pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
    let action_log = &self.action_log;
    action_log.update(cx, |log, cx| log.keep_all_edits(cx));
}
/// Forwards to the action log's `reject_edits_in_ranges` for the given ranges of
/// `buffer` and returns the task it produces.
pub fn reject_edits_in_ranges(
    &mut self,
    buffer: Entity<language::Buffer>,
    buffer_ranges: Vec<Range<language::Anchor>>,
    cx: &mut Context<Self>,
) -> Task<Result<()>> {
    let action_log = &self.action_log;
    action_log.update(cx, |log, cx| {
        log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
    })
}
/// Returns the action log entity used by this thread.
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
/// Returns the project entity this thread operates on.
pub fn project(&self) -> &Entity<Project> {
&self.project
}
/// Returns a copy of the thread's `cumulative_token_usage` field.
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage
}
/// Returns the token usage recorded just before the message with `message_id`.
///
/// `max` is the configured model's token limit for the current completion mode.
/// `total` is the usage recorded for the preceding message; it is `0` when the message
/// is the first one or is not found, and the default usage's total when no usage was
/// recorded for the preceding message. Without a configured model the default
/// `TotalTokenUsage` is returned.
pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
    let Some(configured) = self.configured_model.as_ref() else {
        return TotalTokenUsage::default();
    };
    let max = configured
        .model
        .max_token_count_for_mode(self.completion_mode().into());

    let message_index = self.messages.iter().position(|msg| msg.id == message_id);
    let total = match message_index {
        // An unknown message is treated like the first one: nothing precedes it.
        None | Some(0) => 0,
        Some(index) => self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default()
            .total_tokens(),
    };

    TotalTokenUsage { total, max }
}
/// Returns the thread's current token usage against the configured model's limit,
/// or `None` when no model is configured.
///
/// If a context-window-exceeded error is recorded for the configured model, its token
/// count is reported as the total; otherwise the usage at the last message is used.
pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
    let configured = self.configured_model.as_ref()?;
    let max = configured
        .model
        .max_token_count_for_mode(self.completion_mode().into());

    let exceeded_total = self
        .exceeded_window_error
        .as_ref()
        .filter(|exceeded| configured.model.id() == exceeded.model_id)
        .map(|exceeded| exceeded.token_count);

    let total = match exceeded_total {
        Some(total) => total,
        None => self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens(),
    };

    Some(TotalTokenUsage { total, max })
}
/// Returns the usage recorded at the index of the last message, falling back to the
/// last recorded usage when there is no entry at that index. Returns `None` when no
/// usage has been recorded at all.
fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
    let usages = &self.request_token_usage;
    let last_message_index = self.messages.len().saturating_sub(1);
    match usages.get(last_message_index) {
        Some(usage) => Some(usage.clone()),
        None => usages.last().cloned(),
    }
}
/// Stores `token_usage` as the usage for the last message.
///
/// The usage list is first resized to one entry per message; new entries are filled
/// with the previously known usage at the last message (or the default usage). With no
/// messages the list becomes empty and nothing is stored.
fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
    let fill = self.token_usage_at_last_message().unwrap_or_default();
    let message_count = self.messages.len();
    self.request_token_usage.resize(message_count, fill);

    if let Some(slot) = self.request_token_usage.last_mut() {
        *slot = token_usage;
    }
}
/// Pushes the latest model request usage (`amount` used out of `limit`) to the
/// project's user store.
///
/// `RequestUsage::amount` is an `i32`, so an `amount` above `i32::MAX` is saturated to
/// `i32::MAX` instead of wrapping around to a negative number.
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
    let amount = i32::try_from(amount).unwrap_or(i32::MAX);
    let usage = ModelRequestUsage(RequestUsage { amount, limit });
    self.project
        .read(cx)
        .user_store()
        .update(cx, |user_store, cx| {
            user_store.update_model_request_usage(usage, cx)
        });
}
pub fn deny_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let err = Err(anyhow::anyhow!(
"Permission to run tool action denied by user"
));
self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
err,
self.configured_model.as_ref(),
self.completion_mode,
);
self.tool_finished(tool_use_id, None, true, window, cx);
}
}
/// Errors a thread can report; carried by `ThreadEvent::ShowError`.
///
/// The `#[error(...)]` attributes define each variant's display text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
/// Displayed as "Payment required".
#[error("Payment required")]
PaymentRequired,
/// Displayed as "Model request limit reached"; carries the plan involved.
#[error("Model request limit reached")]
ModelRequestLimitReached { plan: Plan },
/// A general error made of a header and a message body.
#[error("Message {header}: {message}")]
Message {
header: SharedString,
message: SharedString,
},
/// An error described by `message` that is labeled as retryable.
#[error("Retryable error: {message}")]
RetryableError {
message: SharedString,
// NOTE(review): presumably tells the UI whether to offer enabling burn mode;
// the consumer is outside this view — confirm.
can_enable_burn_mode: bool,
},
}
/// Events emitted by `Thread` (see the `EventEmitter<ThreadEvent>` impl).
///
/// Only variants whose emission is visible next to this definition are documented in
/// detail; the others are emitted elsewhere in the file.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
ReceivedTextChunk,
NewRequest,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
StreamedToolUse {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
input: serde_json::Value,
},
/// Emitted by `handle_hallucinated_tool_use` when the requested tool does not exist
/// or is not enabled. `ui_text` holds the error message, including the list of
/// available tools.
MissingToolUse {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
},
/// Emitted by `receive_invalid_tool_json` when a tool call's input JSON could not
/// be parsed. `invalid_input_json` is the raw JSON that was received.
InvalidToolInput {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
invalid_input_json: Arc<str>,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
SummaryGenerated,
SummaryChanged,
UsePendingTools {
tool_uses: Vec<PendingToolUse>,
},
/// Emitted by `tool_finished` once a tool use has completed, failed, or been
/// canceled.
ToolFinished {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
},
CheckpointChanged,
/// Emitted by `use_pending_tool` when a tool use is waiting for confirmation.
ToolConfirmationNeeded,
ToolUseLimitReached,
/// Emitted by `cancel_editing`.
CancelEditing,
/// Emitted by `cancel_last_completion` when something was actually canceled.
CompletionCanceled,
ProfileChanged,
}
// Marker impl that lets `Thread` emit `ThreadEvent`s through `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
/// An entry in `Thread::pending_completions`; `cancel_last_completion` pops the most
/// recent one.
struct PendingCompletion {
// NOTE(review): presumably identifies the completion so it can be matched later;
// the usages are outside this view — confirm.
id: usize,
queue_state: QueueState,
// Never read by name (underscore prefix). NOTE(review): presumably held so the
// task lives as long as this entry and stops when the entry is dropped — confirm
// against gpui's `Task` drop semantics.
_task: Task<()>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
};
// Test-specific constants
// Number of seconds used for rate-limit retries in tests (going by the name; the
// usages are outside this view — confirm).
const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
use assistant_tool::ToolRegistry;
use assistant_tools;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use gpui::TestAppContext;
use http_client;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
use language_model::{
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelToolChoice,
};
use parking_lot::Mutex;
use project::{FakeFs, Project};
use prompt_store::PromptBuilder;
use serde_json::json;
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings;
use util::path;
use workspace::Workspace;
// Verifies that a user message inserted with a file context stores the rendered
// context text on the message and that the completion request places that text
// directly in front of the user's prompt.
#[gpui::test]
async fn test_message_with_context(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
// Exactly one context entry (the file added above) is expected in the store.
let context =
context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Please explain this code",
loaded_context,
None,
Vec::new(),
cx,
)
});
// Check content and context in message object
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
// Use different path format strings based on platform for the test
#[cfg(windows)]
let path_part = r"test\code.rs";
#[cfg(not(windows))]
let path_part = "test/code.rs";
// The raw string below is whitespace-sensitive: it must match the rendered context
// byte for byte. `{{`/`}}` are escaped braces for `format!`.
let expected_context = format!(
r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.
<files>
```rs {path_part}
fn main() {{
println!("Hello, world!");
}}
```
</files>
</context>
"#
);
assert_eq!(message.role, Role::User);
assert_eq!(message.segments.len(), 1);
assert_eq!(
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
assert_eq!(message.loaded_context.text, expected_context);
// Check message in request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
// Index 0 holds a leading message (presumably the system prompt); the user message
// is at index 1.
assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[1].string_contents(), expected_full_message);
}
// Verifies that `new_context_for_thread` only returns contexts that are not already
// attached to earlier messages, and that passing a message id makes the contexts
// attached from that message onwards count as new again.
#[gpui::test]
async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({
"file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n",
"file4.rs": "fn function4() {}\n",
}),
)
.await;
let (_, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
// First message with context 1
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message1_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message3_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
});
// Check what contexts are included in each message
let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
(
thread.message(message1_id).unwrap().clone(),
thread.message(message2_id).unwrap().clone(),
thread.message(message3_id).unwrap().clone(),
)
});
// First message should include context 1
assert!(message1.loaded_context.text.contains("file1.rs"));
// Second message should include only context 2 (not 1)
assert!(!message2.loaded_context.text.contains("file1.rs"));
assert!(message2.loaded_context.text.contains("file2.rs"));
// Third message should include only context 3 (not 1 or 2)
assert!(!message3.loaded_context.text.contains("file1.rs"));
assert!(!message3.loaded_context.text.contains("file2.rs"));
assert!(message3.loaded_context.text.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
});
// The request should contain all 3 user messages, plus one leading message at
// index 0 (presumably the system prompt)
assert_eq!(request.messages.len(), 4);
// Check that the contexts are properly formatted in each message
assert!(request.messages[1].string_contents().contains("file1.rs"));
assert!(!request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
assert!(request.messages[2].string_contents().contains("file2.rs"));
assert!(!request.messages[2].string_contents().contains("file3.rs"));
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
// Ask for new contexts relative to message 2 (presumably what happens when that
// message is edited): file1 stays attached to message 1, while files 2-4 count as new.
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 3);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file4.rs
store.remove_context(&loaded_context.contexts[2].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 2);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file3.rs
store.remove_context(&loaded_context.contexts[1].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(!loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
}
// Verifies that user messages inserted without any context carry an empty context
// text and appear verbatim in the completion request.
#[gpui::test]
async fn test_message_without_files(cx: &mut TestAppContext) {
    const FIRST_PROMPT: &str = "What is the best way to learn Rust?";
    const SECOND_PROMPT: &str = "Are there any good books?";

    init_test_settings(cx);
    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;
    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;

    // First message: no context attached (default, empty load result).
    let first_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            FIRST_PROMPT,
            ContextLoadResult::default(),
            None,
            Vec::new(),
            cx,
        )
    });
    let first_message =
        thread.read_with(cx, |thread, _| thread.message(first_id).unwrap().clone());

    // The stored message holds only the prompt text and no context.
    assert_eq!(first_message.role, Role::User);
    assert_eq!(first_message.segments.len(), 1);
    assert_eq!(
        first_message.segments[0],
        MessageSegment::Text(FIRST_PROMPT.to_string())
    );
    assert_eq!(first_message.loaded_context.text, "");

    // The request contains one leading message plus the user message.
    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.messages.len(), 2);
    assert_eq!(request.messages[1].string_contents(), FIRST_PROMPT);

    // Second message, also without context.
    let second_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            SECOND_PROMPT,
            ContextLoadResult::default(),
            None,
            Vec::new(),
            cx,
        )
    });
    let second_message =
        thread.read_with(cx, |thread, _| thread.message(second_id).unwrap().clone());
    assert_eq!(second_message.loaded_context.text, "");

    // Both user messages appear in the request, in insertion order.
    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.messages.len(), 3);
    assert_eq!(request.messages[1].string_contents(), FIRST_PROMPT);
    assert_eq!(request.messages[2].string_contents(), SECOND_PROMPT);
}
// Verifies that editing a buffer that was attached as context produces exactly one
// `project_notifications` tool result when notifications are flushed, and that a
// second flush does not add a duplicate. Currently ignored (see the attribute below).
#[gpui::test]
#[ignore] // turn this test on when project_notifications tool is re-enabled
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
// Add a buffer to the context. This will be a tracked buffer
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
let context = context_store
.read_with(cx, |store, _| store.context().next().cloned())
.unwrap();
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message and assistant response
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
thread.insert_assistant_message(
vec![MessageSegment::Text("This code prints 42.".into())],
cx,
);
});
cx.run_until_parked();
// We shouldn't have a stale buffer notification yet
let notifications = thread.read_with(cx, |thread, _| {
find_tool_uses(thread, "project_notifications")
});
assert!(
notifications.is_empty(),
"Should not have stale buffer notification before buffer is modified"
);
// Modify the buffer
buffer.update(cx, |buffer, cx| {
buffer.edit(
[(1..1, "\n println!(\"Added a new line\");\n")],
None,
cx,
);
});
// Insert another user message
thread.update(cx, |thread, cx| {
thread.insert_user_message(
"What does the code do now?",
ContextLoadResult::default(),
None,
Vec::new(),
cx,
)
});
cx.run_until_parked();
// Check for the stale buffer warning
thread.update(cx, |thread, cx| {
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
});
cx.run_until_parked();
let notifications = thread.read_with(cx, |thread, _cx| {
find_tool_uses(thread, "project_notifications")
});
// Exactly one notification is expected at this point.
let [notification] = notifications.as_slice() else {
panic!("Should have a `project_notifications` tool use");
};
let Some(notification_content) = notification.content.to_str() else {
panic!("`project_notifications` should return text");
};
assert!(notification_content.contains("These files have changed since the last read:"));
assert!(notification_content.contains("code.rs"));
// Insert another user message and flush notifications again
thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Can you tell me more?",
ContextLoadResult::default(),
None,
Vec::new(),
cx,
)
});
thread.update(cx, |thread, cx| {
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
});
cx.run_until_parked();
// There should be no new notifications (we already flushed one)
let notifications = thread.read_with(cx, |thread, _cx| {
find_tool_uses(thread, "project_notifications")
});
assert_eq!(
notifications.len(),
1,
"Should still have only one notification after second flush - no duplicates"
);
}
/// Collects, across all messages of `thread` and in message order, clones of the tool
/// results whose `tool_name` equals `tool_name`.
fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
    thread
        .messages()
        .flat_map(|message| {
            thread
                .tool_results_for_message(message.id)
                .into_iter()
                // Compare as `&str` so no owned name is allocated per result.
                .filter(|result| &*result.tool_name == tool_name)
                .cloned()
                .collect::<Vec<_>>()
        })
        .collect()
}
// Verifies that a freshly created thread starts with the default agent profile,
// built from the thread store's tool set.
#[gpui::test]
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;
    let (_workspace, thread_store, thread, _context_store, _model) =
        setup_test_environment(cx, project.clone()).await;

    // Check that we are starting with the default profile
    let actual_profile = thread.read_with(cx, |thread, _| thread.profile.clone());
    let tool_set = thread_store.read_with(cx, |store, _| store.tools());
    let expected_profile = AgentProfile::new(AgentProfileId::default(), tool_set);
    assert_eq!(actual_profile, expected_profile);
}
// Verifies that the thread's profile survives a serialize/deserialize round trip:
// the default profile id is written out and restored as the default profile.
#[gpui::test]
async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;
    let (_workspace, thread_store, thread, _context_store, _model) =
        setup_test_environment(cx, project.clone()).await;

    // Profile gets serialized with default values
    let serialize_task = thread.update(cx, |thread, cx| thread.serialize(cx));
    let serialized_thread = serialize_task.await.unwrap();
    assert_eq!(serialized_thread.profile, Some(AgentProfileId::default()));

    // Rebuild a thread from the serialized form, reusing the original thread's
    // id, project, tools, prompt builder, and project context.
    let restored = thread.update(cx, |thread, cx| {
        Thread::deserialize(
            thread.id.clone(),
            serialized_thread,
            thread.project.clone(),
            thread.tools.clone(),
            thread.prompt_builder.clone(),
            thread.project_context.clone(),
            None,
            cx,
        )
    });

    let tool_set = thread_store.read_with(cx, |store, _| store.tools());
    let expected_profile = AgentProfile::new(AgentProfileId::default(), tool_set);
    assert_eq!(restored.profile, expected_profile);
}
// Checks which `model_parameters` entries apply to the active model when a completion
// request is built: an entry applies when every selector it specifies (provider and/or
// model) matches, and is ignored when the provider differs.
#[gpui::test]
async fn test_temperature_setting(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;
    let (_workspace, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;

    let own_provider = model.provider_id().0.to_string();
    // Each row: (provider selector, model selector, temperature expected on the request).
    let scenarios = [
        // Both model and provider
        (
            Some(own_provider.clone().into()),
            Some(model.id().0),
            Some(0.66),
        ),
        // Only model
        (None, Some(model.id().0), Some(0.66)),
        // Only provider
        (Some(own_provider.clone().into()), None, Some(0.66)),
        // Same model name, different provider
        (Some("anthropic".into()), Some(model.id().0), None),
    ];

    for (provider, model_name, expected_temperature) in scenarios {
        // Replace the global settings with a single parameter entry for this scenario.
        cx.update(|cx| {
            let parameters = LanguageModelParameters {
                provider,
                model: model_name,
                temperature: Some(0.66),
            };
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![parameters],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, expected_temperature);
    }
}
// Walks a thread summary through its states (Pending -> Generating -> Ready) and checks
// that `set_summary` only takes effect once the summary is Ready.
#[gpui::test]
async fn test_thread_summary(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;
    // Initial state should be pending
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Pending));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });
    // Manually setting the summary should not be allowed in this state
    thread.update(cx, |thread, cx| {
        thread.set_summary("This should not work", cx);
    });
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Pending));
    });
    // Send a message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
        thread.send_to_model(
            model.clone(),
            CompletionIntent::ThreadSummarization,
            None,
            cx,
        );
    });
    let fake_model = model.as_fake();
    // NOTE(review): presumably streams an assistant reply and ends that completion;
    // confirm against the helper's definition.
    simulate_successful_response(fake_model, cx);
    // Should start generating summary when there are >= 2 messages
    thread.read_with(cx, |thread, _| {
        assert_eq!(*thread.summary(), ThreadSummary::Generating);
    });
    // Should not be able to set the summary while generating
    thread.update(cx, |thread, cx| {
        thread.set_summary("This should not work either", cx);
    });
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Generating));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });
    // Stream the summary text in two chunks, then close the summary completion.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Brief");
    fake_model.send_last_completion_stream_text_chunk(" Introduction");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Summary should be set
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), "Brief Introduction");
    });
    // Now we should be able to set a summary
    thread.update(cx, |thread, cx| {
        thread.set_summary("Brief Intro", cx);
    });
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.summary().or_default(), "Brief Intro");
    });
    // Test setting an empty summary (should default to DEFAULT)
    thread.update(cx, |thread, cx| {
        thread.set_summary("", cx);
    });
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });
}
// After summarization has ended in the Error state, a manually supplied summary must be
// accepted and move the thread to Ready.
#[gpui::test]
async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;
    // Drive the thread into `ThreadSummary::Error` via the shared helper.
    test_summarize_error(&model, &thread, cx);
    // Now we should be able to set a summary
    thread.update(cx, |thread, cx| {
        thread.set_summary("Brief Intro", cx);
    });
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), "Brief Intro");
    });
}
// After summarization has ended in the Error state, further messages must not restart it
// automatically, but an explicit `summarize` call must.
#[gpui::test]
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;
    // Drive the thread into `ThreadSummary::Error` via the shared helper.
    test_summarize_error(&model, &thread, cx);
    // Sending another message should not trigger another summarize request
    thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            "How are you?",
            ContextLoadResult::default(),
            None,
            vec![],
            cx,
        );
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    let fake_model = model.as_fake();
    simulate_successful_response(fake_model, cx);
    thread.read_with(cx, |thread, _| {
        // State is still Error, not Generating
        assert!(matches!(thread.summary(), ThreadSummary::Error));
    });
    // But the summarize request can be invoked manually
    thread.update(cx, |thread, cx| {
        thread.summarize(cx);
    });
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Generating));
    });
    // Stream the summary text and close the summary completion.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("A successful summary");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), "A successful summary");
    });
}
// Failure modes that `ErrorInjector` can be configured to report from `stream_completion`.
enum TestError {
    // Reported as `LanguageModelCompletionError::ServerOverloaded` with no `retry_after`.
    Overloaded,
    // Reported as `LanguageModelCompletionError::ApiInternalServerError`.
    InternalServerError,
}
// Test double whose `LanguageModel` impl delegates metadata queries to a wrapped
// `FakeLanguageModel` but makes every `stream_completion` call yield an error.
struct ErrorInjector {
    // Fake model that answers all non-streaming trait methods.
    inner: Arc<FakeLanguageModel>,
    // Selects which completion error the stream yields.
    error_type: TestError,
}
impl ErrorInjector {
    // Builds an injector around a default `FakeLanguageModel` that reports `error_type`.
    fn new(error_type: TestError) -> Self {
        let inner = Arc::new(FakeLanguageModel::default());
        Self { inner, error_type }
    }
}
impl LanguageModel for ErrorInjector {
fn id(&self) -> LanguageModelId {
self.inner.id()
}
fn name(&self) -> LanguageModelName {
self.inner.name()
}
fn provider_id(&self) -> LanguageModelProviderId {
self.inner.provider_id()
}
fn provider_name(&self) -> LanguageModelProviderName {
self.inner.provider_name()
}
fn supports_tools(&self) -> bool {
self.inner.supports_tools()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
self.inner.supports_tool_choice(choice)
}
fn supports_images(&self) -> bool {
self.inner.supports_images()
}
fn telemetry_id(&self) -> String {
self.inner.telemetry_id()
}
fn max_token_count(&self) -> u64 {
self.inner.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
self.inner.count_tokens(request, cx)
}
fn stream_completion(
&self,
_request: LanguageModelRequest,
_cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let error = match self.error_type {
TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
provider: self.provider_name(),
retry_after: None,
},
TestError::InternalServerError => {
LanguageModelCompletionError::ApiInternalServerError {
provider: self.provider_name(),
message: "I'm a teapot orbiting the sun".to_string(),
}
}
};
async move {
let stream = futures::stream::once(async move { Err(error) });
Ok(stream.boxed())
}
.boxed()
}
fn as_fake(&self) -> &FakeLanguageModel {
&self.inner
}
}
// In Burn mode, an overloaded error must schedule a retry: retry state is recorded with
// MAX_RETRY_ATTEMPTS and exactly one UI-only system notice is added to the thread.
#[gpui::test]
async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Model whose completions always fail with an overloaded error
    let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // The failure should have been recorded as the first retry attempt.
    thread.read_with(cx, |thread, _| {
        let retry_state = thread
            .retry_state
            .as_ref()
            .expect("Should have retry state");
        assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
        assert_eq!(
            retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
            "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
        );
    });

    // A UI-only system message describing the retry should be present.
    thread.read_with(cx, |thread, _| {
        let expected_attempt = format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS);
        let has_retry_notice = thread.messages().any(|msg| {
            msg.role == Role::System
                && msg.ui_only
                && msg.segments.iter().any(|seg| {
                    matches!(
                        seg,
                        MessageSegment::Text(text)
                            if text.contains("overloaded") && text.contains(&expected_attempt)
                    )
                })
        });
        assert!(has_retry_notice, "Should have added a system retry message");
    });

    // And there should be exactly one such notice.
    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds")
                        )
                    })
            })
            .count()
    });
    assert_eq!(retry_count, 1, "Should have one retry message");
}
// In Burn mode, an internal server error must schedule a retry limited to 3 attempts and
// add one UI-only system notice that names the provider.
#[gpui::test]
async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Model whose completions always fail with an internal server error
    let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Check retry state on thread
    thread.read_with(cx, |thread, _| {
        let retry_state = thread
            .retry_state
            .as_ref()
            .expect("Should have retry state");
        assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
        assert_eq!(
            retry_state.max_attempts, 3,
            "Should have correct max attempts"
        );
    });

    // The retry notice must mention every one of these fragments.
    thread.read_with(cx, |thread, _| {
        let required_fragments = ["internal", "Fake", "Retrying", "attempt 1 of 3", "seconds"];
        let has_retry_notice = thread.messages().any(|msg| {
            msg.role == Role::System
                && msg.ui_only
                && msg.segments.iter().any(|seg| {
                    matches!(
                        seg,
                        MessageSegment::Text(text)
                            if required_fragments
                                .iter()
                                .all(|fragment| text.contains(*fragment))
                    )
                })
        });
        assert!(
            has_retry_notice,
            "Should have added a system retry message with provider name"
        );
    });

    // Count retry messages
    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds")
                        )
                    })
            })
            .count()
    });
    assert_eq!(retry_count, 1, "Should have one retry message");
}
// Drives a permanently failing model through its full retry budget for internal server
// errors (3 retries) and checks the notices, the cleared retry state, and the total
// number of requests issued.
#[gpui::test]
async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Model whose completions always fail with an internal server error
    let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });

    // Count every `NewRequest` event, i.e. every completion attempt the thread starts.
    let completion_count = Arc::new(Mutex::new(0));
    let completion_counter = completion_count.clone();
    let _subscription = thread.update(cx, |_, cx| {
        cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
            if matches!(event, ThreadEvent::NewRequest) {
                *completion_counter.lock() += 1;
            }
        })
    });

    // First attempt
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Should have scheduled first retry - count retry messages
    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds")
                        )
                    })
            })
            .count()
    });
    assert_eq!(retry_count, 1, "Should have scheduled first retry");

    // Check retry state
    thread.read_with(cx, |thread, _| {
        let retry_state = thread
            .retry_state
            .as_ref()
            .expect("Should have retry state");
        assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
        assert_eq!(
            retry_state.max_attempts, 3,
            "Internal server errors should retry up to 3 times"
        );
    });

    // Advance the clock once per retry (first, second, third).
    for _ in 0..3 {
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();
    }

    // Should have completed all retries - count retry messages
    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds")
                        )
                    })
            })
            .count()
    });
    assert_eq!(
        retry_count, 3,
        "Should have 3 retries for internal server errors"
    );

    // For internal server errors, we retry 3 times and then give up,
    // at which point the retry state must be cleared.
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.retry_state.is_none(),
            "Retry state should be cleared after all retries"
        );
    });

    // Verify total attempts (1 initial + 3 retries)
    assert_eq!(
        *completion_count.lock(),
        4,
        "Should have attempted once plus 3 retries"
    );
}
// Exhausts the retry budget for overloaded errors and checks that the thread stops with
// an error, clears its retry state, and leaves one notice per retry.
#[gpui::test]
async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Model whose completions always fail with an overloaded error
    let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });

    // Record whether the thread ever emits `Stopped(Err(_))`.
    let stopped_with_error = Arc::new(Mutex::new(false));
    let stopped_flag = stopped_with_error.clone();
    let _subscription = thread.update(cx, |_, cx| {
        cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
            if matches!(event, ThreadEvent::Stopped(Err(_))) {
                *stopped_flag.lock() = true;
            }
        })
    });

    // Start initial completion
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Advance through all retries
    for _ in 0..MAX_RETRY_ATTEMPTS {
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();
    }

    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text)
                                if text.contains("Retrying") && text.contains("seconds")
                        )
                    })
            })
            .count()
    });
    assert_eq!(
        retry_count, MAX_RETRY_ATTEMPTS as usize,
        "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
    );

    // After max retries, should emit Stopped(Err(...)) event
    assert!(
        *stopped_with_error.lock(),
        "Should emit Stopped(Err(...)) event after max retries exceeded"
    );

    thread.read_with(cx, |thread, _| {
        // Retry state should be cleared
        assert!(
            thread.retry_state.is_none(),
            "Retry state should be cleared after max retries"
        );

        // Verify we have the expected number of retry messages
        let retry_messages = thread
            .messages
            .iter()
            .filter(|msg| msg.ui_only && msg.role == Role::System)
            .count();
        assert_eq!(
            retry_messages, MAX_RETRY_ATTEMPTS as usize,
            "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
        );
    });
}
// A model that fails once and then succeeds: after the retry completes, the retry notice
// created for the first failure must still be present as a UI-only system message.
#[gpui::test]
async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Wrapper that reports an overloaded error on the first completion request and
    // forwards every later request to the wrapped fake model.
    struct RetryTestModel {
        inner: Arc<FakeLanguageModel>,
        // Set to `true` by the first `stream_completion` call.
        failed_once: Arc<Mutex<bool>>,
    }

    impl LanguageModel for RetryTestModel {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            request: LanguageModelRequest,
            cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Read the previous value and set the flag under a single lock acquisition,
            // so the "is this the first call?" check and the update cannot be separated.
            let already_failed = std::mem::replace(&mut *self.failed_once.lock(), true);
            if already_failed {
                // Succeed on retry
                return self.inner.stream_completion(request, cx);
            }

            // Return error on first attempt
            let provider = self.provider_name();
            let stream = futures::stream::once(async move {
                Err(LanguageModelCompletionError::ServerOverloaded {
                    provider,
                    retry_after: None,
                })
            });
            async move { Ok(stream.boxed()) }.boxed()
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }

    let model = Arc::new(RetryTestModel {
        inner: Arc::new(FakeLanguageModel::default()),
        failed_once: Arc::new(Mutex::new(false)),
    });

    // Insert a user message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });

    // Track when retry completes successfully
    let retry_completed = Arc::new(Mutex::new(false));
    let retry_completed_clone = retry_completed.clone();
    let _subscription = thread.update(cx, |_, cx| {
        cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
            if let ThreadEvent::StreamedCompletion = event {
                *retry_completed_clone.lock() = true;
            }
        })
    });

    // Start completion
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Get the retry message ID
    let retry_message_id = thread.read_with(cx, |thread, _| {
        thread
            .messages()
            .find(|msg| msg.role == Role::System && msg.ui_only)
            .map(|msg| msg.id)
            .expect("Should have a retry message")
    });

    // Wait for retry
    cx.executor().advance_clock(BASE_RETRY_DELAY);
    cx.run_until_parked();

    // Stream some successful content
    let fake_model = model.as_fake();
    // After the retry, there should be a new pending completion
    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "Should have a pending completion after retry"
    );
    fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
    fake_model.end_completion_stream(&pending[0]);
    cx.run_until_parked();

    // Check that the retry completed successfully
    assert!(
        *retry_completed.lock(),
        "Retry should have completed successfully"
    );

    // Retry message should still exist but be marked as ui_only
    thread.read_with(cx, |thread, _| {
        let retry_msg = thread
            .message(retry_message_id)
            .expect("Retry message should still exist");
        assert!(retry_msg.ui_only, "Retry message should be ui_only");
        assert_eq!(
            retry_msg.role,
            Role::System,
            "Retry message should have System role"
        );
    });
}
// A model that fails once and then succeeds: the first failure must create retry state,
// and the successful retry must clear it and leave a real assistant message.
#[gpui::test]
async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Wrapper that reports an overloaded error on the first completion request and
    // forwards every later request to the wrapped fake model.
    struct FailOnceModel {
        inner: Arc<FakeLanguageModel>,
        // Set to `true` by the first `stream_completion` call.
        failed_once: Arc<Mutex<bool>>,
    }

    impl LanguageModel for FailOnceModel {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            request: LanguageModelRequest,
            cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Read the previous value and set the flag under a single lock acquisition,
            // so the "is this the first call?" check and the update cannot be separated.
            let already_failed = std::mem::replace(&mut *self.failed_once.lock(), true);
            if already_failed {
                // Succeed on retry
                return self.inner.stream_completion(request, cx);
            }

            // Return error on first attempt
            let provider = self.provider_name();
            let stream = futures::stream::once(async move {
                Err(LanguageModelCompletionError::ServerOverloaded {
                    provider,
                    retry_after: None,
                })
            });
            async move { Ok(stream.boxed()) }.boxed()
        }
    }

    let fail_once_model = Arc::new(FailOnceModel {
        inner: Arc::new(FakeLanguageModel::default()),
        failed_once: Arc::new(Mutex::new(false)),
    });

    // Insert a user message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            "Test message",
            ContextLoadResult::default(),
            None,
            vec![],
            cx,
        );
    });

    // Start completion with fail-once model
    thread.update(cx, |thread, cx| {
        thread.send_to_model(
            fail_once_model.clone(),
            CompletionIntent::UserPrompt,
            None,
            cx,
        );
    });
    cx.run_until_parked();

    // Verify retry state exists after first failure
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.retry_state.is_some(),
            "Should have retry state after failure"
        );
    });

    // Wait for retry delay
    cx.executor().advance_clock(BASE_RETRY_DELAY);
    cx.run_until_parked();

    // The retry should now use our FailOnceModel which should succeed
    // We need to help the FakeLanguageModel complete the stream
    let inner_fake = fail_once_model.inner.clone();
    // Wait a bit for the retry to start
    cx.run_until_parked();

    // The retry must have reached the inner fake model. Fail here with a clear message
    // instead of silently skipping the completion and failing later for an unrelated
    // looking reason.
    let pending_completions = inner_fake.pending_completions();
    let pending = pending_completions
        .first()
        .expect("Should have a pending completion after retry");
    inner_fake.send_completion_stream_text_chunk(pending, "Success!");
    inner_fake.end_completion_stream(pending);
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.retry_state.is_none(),
            "Retry state should be cleared after successful completion"
        );
        let has_assistant_message = thread
            .messages
            .iter()
            .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
        assert!(
            has_assistant_message,
            "Should have an assistant message after successful retry"
        );
    });
}
// A rate-limit error carrying `retry_after` must schedule a retry governed by
// MAX_RETRY_ATTEMPTS and add a single UI-only notice that includes the attempt count.
#[gpui::test]
async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Wrapper whose completions always fail with a rate-limit error that carries a
    // `retry_after` duration; everything else is forwarded to the fake model.
    struct RateLimitModel {
        inner: Arc<FakeLanguageModel>,
    }

    impl LanguageModel for RateLimitModel {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            let provider = self.provider_name();
            // The request itself succeeds; the rate-limit error is the stream's only item.
            let rate_limited = futures::stream::once(async move {
                Err(LanguageModelCompletionError::RateLimitExceeded {
                    provider,
                    retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                })
            });
            async move { Ok(rate_limited.boxed()) }.boxed()
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }

    let model = Arc::new(RateLimitModel {
        inner: Arc::new(FakeLanguageModel::default()),
    });

    // Insert a user message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });

    // Start completion
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    let retry_count = thread.update(cx, |thread, _| {
        thread
            .messages
            .iter()
            .filter(|message| {
                message.ui_only
                    && message.segments.iter().any(|segment| {
                        matches!(
                            segment,
                            MessageSegment::Text(text) if text.contains("rate limit exceeded")
                        )
                    })
            })
            .count()
    });
    assert_eq!(retry_count, 1, "Should have scheduled one retry");

    thread.read_with(cx, |thread, _| {
        let retry_state = thread
            .retry_state
            .as_ref()
            .expect("Rate limit errors should set retry_state");
        assert_eq!(
            retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
            "Rate limit errors should use MAX_RETRY_ATTEMPTS"
        );
    });

    // Verify we have one retry message
    thread.read_with(cx, |thread, _| {
        let retry_messages = thread
            .messages
            .iter()
            .filter(|msg| {
                msg.ui_only
                    && msg.segments.iter().any(|seg| {
                        matches!(
                            seg,
                            MessageSegment::Text(text) if text.contains("rate limit exceeded")
                        )
                    })
            })
            .count();
        assert_eq!(
            retry_messages, 1,
            "Should have one rate limit retry message"
        );
    });

    // Check the wording of the retry message
    thread.read_with(cx, |thread, _| {
        let retry_message = thread
            .messages
            .iter()
            .find(|msg| msg.role == Role::System && msg.ui_only)
            .expect("Should have a retry message");
        // The message contains the attempt count since rate limits go through retry_state.
        // Only checked when the first segment is text.
        if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
            let expected_attempt = format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS);
            assert!(
                text.contains(&expected_attempt),
                "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
            );
            assert!(
                text.contains("Retrying"),
                "Rate limit retry message should contain retry text"
            );
        }
    });
}
// UI-only messages stay in the thread but must be excluded from both the completion
// request and the serialized thread.
#[gpui::test]
async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;

    // Insert a regular user message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });

    // Insert a UI-only message (like our retry notifications)
    thread.update(cx, |thread, cx| {
        let id = thread.next_message_id.post_inc();
        let ui_only_message = Message {
            id,
            role: Role::System,
            segments: vec![MessageSegment::Text(
                "This is a UI-only message that should not be sent to the model".to_string(),
            )],
            loaded_context: LoadedContext::default(),
            creases: Vec::new(),
            is_hidden: true,
            ui_only: true,
        };
        thread.messages.push(ui_only_message);
        cx.emit(ThreadEvent::MessageAdded(id));
    });

    // Insert another regular message
    thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            "How are you?",
            ContextLoadResult::default(),
            None,
            vec![],
            cx,
        );
    });

    // Generate the completion request
    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });

    // Verify that the request only contains non-UI-only messages
    // Should have system prompt + 2 user messages, but not the UI-only message
    let user_message_count = request
        .messages
        .iter()
        .filter(|msg| msg.role == Role::User)
        .count();
    assert_eq!(
        user_message_count,
        2,
        "Should have exactly 2 user messages"
    );

    // Verify the UI-only content is not present anywhere in the request
    let request_text = request
        .messages
        .iter()
        .flat_map(|msg| msg.content.iter())
        .filter_map(|content| {
            if let MessageContent::Text(text) = content {
                Some(text.as_str())
            } else {
                None
            }
        })
        .collect::<String>();
    assert!(
        !request_text.contains("UI-only message"),
        "UI-only message content should not be in the request"
    );

    // Verify the thread still has all 3 messages (including UI-only)
    thread.read_with(cx, |thread, _| {
        let total = thread.messages().count();
        assert_eq!(total, 3, "Thread should have 3 messages");
        let ui_only_total = thread.messages().filter(|m| m.ui_only).count();
        assert_eq!(ui_only_total, 1, "Thread should have 1 UI-only message");
    });

    // Verify that UI-only messages are not serialized
    let serialized = thread
        .update(cx, |thread, cx| thread.serialize(cx))
        .await
        .unwrap();
    assert_eq!(
        serialized.messages.len(),
        2,
        "Serialized thread should only have 2 messages (no UI-only)"
    );
}
// In Normal mode an overloaded error must not schedule a retry; instead a retryable
// error that offers enabling Burn mode is shown and generation stops.
#[gpui::test]
async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Ensure we're in Normal mode (not Burn mode)
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Normal);
    });

    // Collect every error surfaced through `ThreadEvent::ShowError`.
    let error_events = Arc::new(Mutex::new(Vec::new()));
    let error_sink = error_events.clone();
    let _subscription = thread.update(cx, |_, cx| {
        cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
            if let ThreadEvent::ShowError(error) = event {
                error_sink.lock().push(error.clone());
            }
        })
    });

    // Model whose completions always fail with an overloaded error
    let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Verify no retry state was created
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.retry_state.is_none(),
            "Should not have retry state in Normal mode"
        );
    });

    // Check that a retryable error was reported
    let errors = error_events.lock();
    assert!(!errors.is_empty(), "Should have received an error event");
    match &errors[0] {
        ThreadError::RetryableError {
            message: _,
            can_enable_burn_mode,
        } => {
            assert!(
                *can_enable_burn_mode,
                "Error should indicate burn mode can be enabled"
            );
        }
        other => panic!("Expected RetryableError, got {:?}", other),
    }

    // Verify the thread is no longer generating
    thread.read_with(cx, |thread, _| {
        assert!(
            !thread.is_generating(),
            "Should not be generating after error without retry"
        );
    });
}
// Cancelling the last completion while a retry is scheduled must prevent the retry from
// running and must clear the retry state.
#[gpui::test]
async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
    init_test_settings(cx);
    let project = create_test_project(cx, json!({})).await;
    let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

    // Enable Burn Mode to allow retries
    thread.update(cx, |thread, _| {
        thread.set_completion_mode(CompletionMode::Burn);
    });

    // Model whose completions always fail with an overloaded error
    let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
    });
    thread.update(cx, |thread, cx| {
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });
    cx.run_until_parked();

    // Verify retry was scheduled by checking for retry message
    let has_retry_message = thread.read_with(cx, |thread, _| {
        thread.messages.iter().any(|message| {
            message.ui_only
                && message.segments.iter().any(|segment| {
                    matches!(
                        segment,
                        MessageSegment::Text(text)
                            if text.contains("Retrying") && text.contains("seconds")
                    )
                })
        })
    });
    assert!(has_retry_message, "Should have scheduled a retry");

    // Cancel the completion before the retry happens
    thread.update(cx, |thread, cx| {
        thread.cancel_last_completion(None, cx);
    });
    cx.run_until_parked();

    // The retry should not have happened - no pending completions
    let fake_model = model.as_fake();
    let pending_after_cancel = fake_model.pending_completions().len();
    assert_eq!(
        pending_after_cancel,
        0,
        "Should have no pending completions after cancellation"
    );

    // Verify the retry was canceled by checking retry state
    thread.read_with(cx, |thread, _| {
        if let Some(leftover) = thread.retry_state.as_ref() {
            panic!(
                "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
                leftover.attempt, leftover.max_attempts, leftover.intent
            );
        }
    });
}
fn test_summarize_error(
model: &Arc<dyn LanguageModel>,
thread: &Entity<Thread>,
cx: &mut TestAppContext,
) {
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(
model.clone(),
CompletionIntent::ThreadSummarization,
None,
cx,
);
});
let fake_model = model.as_fake();
simulate_successful_response(fake_model, cx);
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
// Simulate summary request ending
cx.run_until_parked();
fake_model.end_last_completion_stream();
cx.run_until_parked();
// State is set to Error and default message
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Error));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
}
/// Sends one text chunk on the fake model's last completion stream and then
/// ends that stream, letting the executor settle before and after.
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
    let reply = "Assistant response";

    // Let already-scheduled work run before touching the completion stream.
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk(reply);
    fake_model.end_last_completion_stream();

    // Let the effects of the finished stream run to completion.
    cx.run_until_parked();
}
/// Installs the globals and settings the thread tests rely on: a test
/// `SettingsStore`, the settings/initialisation hooks of the crates used by
/// the agent, the tool registry, and the built-in tools (backed by a fake
/// HTTP client).
fn init_test_settings(cx: &mut TestAppContext) {
    cx.update(|cx| {
        // Fake HTTP client (`FakeHttpClient::with_200_response`) rooted at
        // `http://localhost`; handed to `assistant_tools::init` below.
        let fake_http = Arc::new(http_client::HttpClientWithUrl::new(
            http_client::FakeHttpClient::with_200_response(),
            "http://localhost".to_string(),
            None,
        ));

        // The settings store is set as a global before the `register` / `init` calls below.
        let test_store = SettingsStore::test(cx);
        cx.set_global(test_store);

        language::init(cx);
        Project::init_settings(cx);
        AgentSettings::register(cx);
        prompt_store::init(cx);
        thread_store::init(cx);
        workspace::init_settings(cx);
        language_model::init_settings(cx);
        ThemeSettings::register(cx);

        ToolRegistry::default_global(cx);
        assistant_tool::init(cx);
        assistant_tools::init(fake_http, cx);
    });
}
/// Creates a test [`Project`] rooted at `/test` on a fake file system that
/// has been populated with `files`.
async fn create_test_project(
    cx: &mut TestAppContext,
    files: serde_json::Value,
) -> Entity<Project> {
    let root = path!("/test");

    let fake_fs = FakeFs::new(cx.executor());
    fake_fs.insert_tree(root, files).await;

    Project::test(fake_fs, [root.as_ref()], cx).await
}
/// Builds the fixtures shared by the thread tests for `project`.
///
/// Returns, in order: a test workspace window, a loaded `ThreadStore`, a new
/// thread created from that store, a `ContextStore` for the project, and the
/// fake language model, which is also registered as both the default model
/// and the thread-summary model.
async fn setup_test_environment(
    cx: &mut TestAppContext,
    project: Entity<Project>,
) -> (
    Entity<Workspace>,
    Entity<ThreadStore>,
    Entity<Thread>,
    Entity<ContextStore>,
    Arc<dyn LanguageModel>,
) {
    // `cx` is shadowed by the window-bound context returned alongside the workspace.
    let (workspace, cx) =
        cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

    let load_thread_store = cx.update(|_, cx| {
        let tools = cx.new(|_| ToolWorkingSet::default());
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        ThreadStore::load(project.clone(), tools, None, prompt_builder, cx)
    });
    let thread_store = load_thread_store.await.unwrap();

    let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
    let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

    let provider = Arc::new(FakeLanguageModelProvider::default());
    let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());

    // The same fake provider/model pair serves as default and summary model.
    let default_model = ConfiguredModel {
        provider: provider.clone(),
        model: model.clone(),
    };
    let summary_model = ConfiguredModel {
        provider,
        model: model.clone(),
    };
    cx.update(|_, cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(Some(default_model), cx);
            registry.set_thread_summary_model(Some(summary_model));
        })
    });

    (workspace, thread_store, thread, context_store, model)
}
/// Opens the file at `path` in `project` and adds the resulting buffer to
/// `context_store`, returning the buffer.
///
/// # Errors
///
/// Returns the error produced by `Project::open_buffer` when the buffer
/// cannot be opened. Previously this was unwrapped, so the declared `Result`
/// could never be `Err`.
///
/// # Panics
///
/// Panics if `path` does not resolve to a project path.
async fn add_file_to_context(
    project: &Entity<Project>,
    context_store: &Entity<ContextStore>,
    path: &str,
    cx: &mut TestAppContext,
) -> Result<Entity<language::Buffer>> {
    // The lookup yields an `Option`; a missing path means the test fixture is
    // wrong, so fail loudly and name the offending path.
    let buffer_path = project
        .read_with(cx, |project, cx| project.find_project_path(path, cx))
        .unwrap_or_else(|| panic!("no project path found for {path:?}"));

    // Opening is fallible; hand the failure to the caller through the return type.
    let buffer = project
        .update(cx, |project, cx| {
            project.open_buffer(buffer_path.clone(), cx)
        })
        .await?;

    context_store.update(cx, |context_store, cx| {
        context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
    });

    Ok(buffer)
}
}