use std::{ error::Error, fmt::{self, Debug}, path::Path, sync::{Arc, Mutex}, time::Duration, }; use crate::{ ToolMetrics, assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, }; use agent::{ContextLoadResult, Thread, ThreadEvent}; use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; use cloud_llm_client::CompletionIntent; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); #[async_trait(?Send)] pub trait Example { fn meta(&self) -> ExampleMetadata; async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>; fn diff_assertions(&self) -> Vec { Vec::new() } fn thread_assertions(&self) -> Vec { Vec::new() } } #[derive(Clone, Debug)] pub struct JudgeAssertion { pub id: String, pub description: String, } #[derive(Clone, Debug)] pub struct ExampleMetadata { pub name: String, pub url: String, pub revision: String, pub language_server: Option, pub max_assertions: Option, pub profile_id: AgentProfileId, pub existing_thread_json: Option, pub max_turns: Option, } #[derive(Clone, Debug)] pub struct LanguageServer { pub file_extension: String, pub allow_preexisting_diagnostics: bool, } impl ExampleMetadata { pub fn repo_name(&self) -> String { self.url .split('/') .next_back() .unwrap_or("") .trim_end_matches(".git") .into() } } pub struct FailedAssertion(pub String); impl fmt::Debug for FailedAssertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Assertion failure: {}", self.0) } } impl fmt::Display for FailedAssertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl Error for FailedAssertion {} pub struct ExampleContext { meta: ExampleMetadata, log_prefix: String, agent_thread: Entity, app: AsyncApp, model: Arc, pub assertions: AssertionsReport, pub tool_metrics: Arc>, } impl ExampleContext { pub fn new( meta: ExampleMetadata, log_prefix: String, agent_thread: Entity, model: Arc, app: AsyncApp, ) -> Self { let assertions = AssertionsReport::new(meta.max_assertions); Self { meta, log_prefix, agent_thread, assertions, model, app, tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())), } } pub fn push_user_message(&mut self, text: impl ToString) { self.app .update_entity(&self.agent_thread, |thread, cx| { thread.insert_user_message( text.to_string(), ContextLoadResult::default(), None, Vec::new(), cx, ); }) .unwrap(); } pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> { let message = message.to_string(); self.log_assertion( if expected { Ok(()) } else { Err(anyhow::Error::from(FailedAssertion(message.clone()))) }, message, ) } pub fn assert_some(&mut self, option: Option, message: impl ToString) -> Result { let message = message.to_string(); self.log_assertion( match option { Some(value) => Ok(value), None => Err(anyhow::Error::from(FailedAssertion(message.clone()))), }, message, ) } #[allow(dead_code)] pub fn assert_eq( &mut self, left: T, right: T, message: impl ToString, ) -> Result<()> { let message = message.to_string(); self.log_assertion( if left == right { Ok(()) } else { println!( "{}{}", self.log_prefix, pretty_assertions::Comparison::new(&left, &right) ); Err(anyhow::Error::from(FailedAssertion(message.clone()))) }, message, ) } fn log_assertion(&mut self, result: Result, message: String) -> Result { if let Some(max) = self.meta.max_assertions { anyhow::ensure!( self.assertions.run_count() <= max, "More assertions were run than the stated max_assertions of {max}" ); } self.assertions.ran.push(RanAssertion { id: message.clone(), result: Ok(RanAssertionResult { analysis: None, passed: result.is_ok(), }), }); if result.is_ok() { println!("{}✅ {}", self.log_prefix, message); } else { println!("{}❌ {}", self.log_prefix, message); } result } pub async fn run_to_end(&mut self) -> Result { self.run_turns(u32::MAX).await } pub async fn run_turn(&mut self) -> Result { self.run_turns(1).await } pub async fn run_turns(&mut self, iterations: u32) -> Result { let (mut tx, mut rx) = mpsc::channel(1); let tool_metrics = self.tool_metrics.clone(); let log_prefix = self.log_prefix.clone(); let _subscription = self.app.subscribe( &self.agent_thread, move |thread, event: &ThreadEvent, cx| match event { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel(); } Ok(StopReason::ToolUse) => { if thread.read(cx).remaining_turns() == 0 { tx.close_channel(); } } Ok(StopReason::MaxTokens) => { tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok(); } Ok(StopReason::Refusal) => { tx.try_send(Err(anyhow!("Model refused to generate content"))) .ok(); } Err(err) => { tx.try_send(Err(anyhow!(err.clone()))).ok(); } }, ThreadEvent::NewRequest | ThreadEvent::StreamedAssistantText(_, _) | ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } | ThreadEvent::CompletionCanceled => {} ThreadEvent::ToolUseLimitReached => {} ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, .. } => { thread.update(cx, |thread, _cx| { if let Some(tool_use) = pending_tool_use { let mut tool_metrics = tool_metrics.lock().unwrap(); if let Some(tool_result) = thread.tool_result(tool_use_id) { let message = if tool_result.is_error { format!("✖︎ {}", tool_use.name) } else { format!("✔︎ {}", tool_use.name) }; println!("{log_prefix}{message}"); tool_metrics .insert(tool_result.tool_name.clone(), !tool_result.is_error); } else { let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); println!("{log_prefix}{message}"); tool_metrics.insert(tool_use.name.clone(), true); } } }); } ThreadEvent::InvalidToolInput { .. } => { println!("{log_prefix} invalid tool input"); } ThreadEvent::MissingToolUse { tool_use_id: _, ui_text, } => { println!("{log_prefix} {ui_text}"); } ThreadEvent::ToolConfirmationNeeded => { panic!( "{}Bug: Tool confirmation should not be required in eval", log_prefix ); } ThreadEvent::StreamedCompletion | ThreadEvent::MessageAdded(_) | ThreadEvent::MessageEdited(_) | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryChanged | ThreadEvent::SummaryGenerated | ThreadEvent::ProfileChanged | ThreadEvent::ReceivedTextChunk | ThreadEvent::StreamedToolUse { .. } | ThreadEvent::CheckpointChanged | ThreadEvent::CancelEditing => { tx.try_send(Ok(())).ok(); if std::env::var("ZED_EVAL_DEBUG").is_ok() { println!("{}Event: {:#?}", log_prefix, event); } } }, ); let model = self.model.clone(); let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| { thread.set_remaining_turns(iterations); thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx); thread.messages().len() })?; loop { select_biased! { result = rx.next() => { if let Some(result) = result { result?; } else { break; } } _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events"); } } } let messages = self.app.read_entity(&self.agent_thread, |thread, cx| { let mut messages = Vec::new(); for message in thread.messages().skip(message_count_before) { messages.push(Message { _role: message.role, text: message.to_message_content(), tool_use: thread .tool_uses_for_message(message.id, cx) .into_iter() .map(|tool_use| ToolUse { name: tool_use.name.to_string(), value: tool_use.input, }) .collect(), }); } messages })?; let response = Response::new(messages); Ok(response) } pub fn edits(&self) -> HashMap, FileEdits> { self.agent_thread .read_with(&self.app, |thread, cx| { let action_log = thread.action_log().read(cx); HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map( |(buffer, diff)| { let snapshot = buffer.read(cx).snapshot(); let file = snapshot.file().unwrap(); let diff = diff.read(cx); let base_text = diff.base_text().text(); let hunks = diff .hunks(&snapshot, cx) .map(|hunk| FileEditHunk { base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(), text: snapshot .text_for_range(hunk.range.clone()) .collect::(), status: hunk.status(), }) .collect(); (file.path().clone(), FileEdits { hunks }) }, )) }) .unwrap() } pub fn agent_thread(&self) -> Entity { self.agent_thread.clone() } } impl AppContext for ExampleContext { type Result = anyhow::Result; fn new( &mut self, build_entity: impl FnOnce(&mut gpui::Context) -> T, ) -> Self::Result> { self.app.new(build_entity) } fn reserve_entity(&mut self) -> Self::Result> { self.app.reserve_entity() } fn insert_entity( &mut self, reservation: gpui::Reservation, build_entity: impl FnOnce(&mut gpui::Context) -> T, ) -> Self::Result> { self.app.insert_entity(reservation, build_entity) } fn update_entity( &mut self, handle: &Entity, update: impl FnOnce(&mut T, &mut gpui::Context) -> R, ) -> Self::Result where T: 'static, { self.app.update_entity(handle, update) } fn as_mut<'a, T>(&'a mut self, handle: &Entity) -> Self::Result> where T: 'static, { self.app.as_mut(handle) } fn read_entity( &self, handle: &Entity, read: impl FnOnce(&T, &App) -> R, ) -> Self::Result where T: 'static, { self.app.read_entity(handle, read) } fn update_window(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result where F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T, { self.app.update_window(window, f) } fn read_window( &self, window: &gpui::WindowHandle, read: impl FnOnce(Entity, &App) -> R, ) -> Result where T: 'static, { self.app.read_window(window, read) } fn background_spawn( &self, future: impl std::future::Future + Send + 'static, ) -> gpui::Task where R: Send + 'static, { self.app.background_spawn(future) } fn read_global(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result where G: gpui::Global, { self.app.read_global(callback) } } #[derive(Debug)] pub struct Response { messages: Vec, } impl Response { pub fn new(messages: Vec) -> Self { Self { messages } } pub fn expect_tool( &self, tool_name: &'static str, cx: &mut ExampleContext, ) -> Result<&ToolUse> { let result = self.find_tool_call(tool_name); cx.assert_some(result, format!("called `{}`", tool_name)) } pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> { self.messages.iter().rev().find_map(|msg| { msg.tool_use .iter() .find(|tool_use| tool_use.name == tool_name) }) } #[allow(dead_code)] pub fn tool_uses(&self) -> impl Iterator { self.messages.iter().flat_map(|msg| &msg.tool_use) } pub fn texts(&self) -> impl Iterator { self.messages.iter().map(|message| message.text.clone()) } } #[derive(Debug)] pub struct Message { _role: Role, text: String, tool_use: Vec, } #[derive(Debug)] pub struct ToolUse { pub name: String, value: serde_json::Value, } impl ToolUse { pub fn parse_input(&self) -> Result where Input: for<'de> serde::Deserialize<'de>, { serde_json::from_value::(self.value.clone()).map_err(|err| anyhow!(err)) } } #[derive(Debug, Eq, PartialEq)] pub struct FileEdits { pub hunks: Vec, } #[derive(Debug, Eq, PartialEq)] pub struct FileEditHunk { pub base_text: String, pub text: String, pub status: DiffHunkStatus, } impl FileEdits { pub fn has_added_line(&self, line: &str) -> bool { self.hunks.iter().any(|hunk| { hunk.status == DiffHunkStatus::added_none() && hunk.base_text.is_empty() && hunk.text.contains(line) }) } }