use std::{ error::Error, fmt::{self, Debug}, path::Path, sync::{Arc, Mutex}, time::Duration, }; use crate::{ ToolMetrics, assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, }; use agent::{ContextLoadResult, ThreadEvent}; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; use collections::HashMap; use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; use gpui::{AppContext, AsyncApp, Entity}; use language_model::{LanguageModel, Role, StopReason}; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); #[async_trait(?Send)] pub trait Example { fn meta(&self) -> ExampleMetadata; async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>; fn diff_assertions(&self) -> Vec { Vec::new() } fn thread_assertions(&self) -> Vec { Vec::new() } } #[derive(Clone, Debug)] pub struct JudgeAssertion { pub id: String, pub description: String, } #[derive(Clone, Debug)] pub struct ExampleMetadata { pub name: String, pub url: String, pub revision: String, pub language_server: Option, pub max_assertions: Option, } #[derive(Clone, Debug)] pub struct LanguageServer { pub file_extension: String, pub allow_preexisting_diagnostics: bool, } impl ExampleMetadata { pub fn repo_name(&self) -> String { self.url .split('/') .next_back() .unwrap_or(&"") .trim_end_matches(".git") .into() } } pub struct FailedAssertion(pub String); impl fmt::Debug for FailedAssertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Assertion failure: {}", self.0) } } impl fmt::Display for FailedAssertion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl Error for FailedAssertion {} pub struct ExampleContext { meta: ExampleMetadata, log_prefix: String, agent_thread: Entity, app: AsyncApp, model: Arc, pub assertions: AssertionsReport, pub tool_metrics: Arc>, } impl ExampleContext { pub fn new( meta: ExampleMetadata, log_prefix: String, agent_thread: Entity, model: Arc, app: AsyncApp, ) -> Self { let assertions = AssertionsReport::new(meta.max_assertions); Self { meta, log_prefix, agent_thread, assertions, model, app, tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())), } } pub fn push_user_message(&mut self, text: impl ToString) { self.app .update_entity(&self.agent_thread, |thread, cx| { thread.insert_user_message( text.to_string(), ContextLoadResult::default(), None, cx, ); }) .unwrap(); } pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> { let message = message.to_string(); self.log_assertion( if expected { Ok(()) } else { Err(anyhow::Error::from(FailedAssertion(message.clone()))) }, message, ) } pub fn assert_some(&mut self, option: Option, message: impl ToString) -> Result { let message = message.to_string(); self.log_assertion( match option { Some(value) => Ok(value), None => Err(anyhow::Error::from(FailedAssertion(message.clone()))), }, message, ) } #[allow(dead_code)] pub fn assert_eq( &mut self, left: T, right: T, message: impl ToString, ) -> Result<()> { let message = message.to_string(); self.log_assertion( if left == right { Ok(()) } else { println!("{}{:#?} != {:#?}", self.log_prefix, left, right); Err(anyhow::Error::from(FailedAssertion(message.clone()))) }, message, ) } fn log_assertion(&mut self, result: Result, message: String) -> Result { if let Some(max) = self.meta.max_assertions { if self.assertions.run_count() > max { return Err(anyhow!( "More assertions were run than the stated max_assertions of {}", max )); } } self.assertions.ran.push(RanAssertion { id: message.clone(), result: Ok(RanAssertionResult { analysis: None, passed: result.is_ok(), }), }); if result.is_ok() { println!("{}✅ {}", self.log_prefix, message); } else { println!("{}❌ {}", self.log_prefix, message); } result } pub async fn run_to_end(&mut self) -> Result { self.run_turns(u32::MAX).await } pub async fn run_turn(&mut self) -> Result { self.run_turns(1).await } pub async fn run_turns(&mut self, iterations: u32) -> Result { let (mut tx, mut rx) = mpsc::channel(1); let tool_metrics = self.tool_metrics.clone(); let log_prefix = self.log_prefix.clone(); let _subscription = self.app.subscribe( &self.agent_thread, move |thread, event: &ThreadEvent, cx| match event { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel(); } Ok(StopReason::ToolUse) => { if thread.read(cx).remaining_turns() == 0 { tx.close_channel(); } } Ok(StopReason::MaxTokens) => { tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok(); } Err(err) => { tx.try_send(Err(anyhow!(err.clone()))).ok(); } }, ThreadEvent::StreamedAssistantText(_, _) | ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {} ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, .. } => { thread.update(cx, |thread, _cx| { if let Some(tool_use) = pending_tool_use { let mut tool_metrics = tool_metrics.lock().unwrap(); if let Some(tool_result) = thread.tool_result(&tool_use_id) { let message = if tool_result.is_error { format!("✖︎ {}", tool_use.name) } else { format!("✔︎ {}", tool_use.name) }; println!("{log_prefix}{message}"); tool_metrics .insert(tool_result.tool_name.clone(), !tool_result.is_error); } else { let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); println!("{log_prefix}{message}"); tool_metrics.insert(tool_use.name.clone(), true); } } }); } ThreadEvent::InvalidToolInput { .. } => { println!("{log_prefix} invalid tool input"); } ThreadEvent::ToolConfirmationNeeded => { panic!( "{}Bug: Tool confirmation should not be required in eval", log_prefix ); } ThreadEvent::StreamedCompletion | ThreadEvent::MessageAdded(_) | ThreadEvent::MessageEdited(_) | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryChanged | ThreadEvent::SummaryGenerated | ThreadEvent::ReceivedTextChunk | ThreadEvent::StreamedToolUse { .. } | ThreadEvent::CheckpointChanged | ThreadEvent::UsageUpdated(_) | ThreadEvent::CancelEditing => { tx.try_send(Ok(())).ok(); if std::env::var("ZED_EVAL_DEBUG").is_ok() { println!("{}Event: {:#?}", log_prefix, event); } } }, ); let model = self.model.clone(); let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| { thread.set_remaining_turns(iterations); thread.send_to_model(model, None, cx); thread.messages().len() })?; loop { select_biased! { result = rx.next() => { if let Some(result) = result { result?; } else { break; } } _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT)); } } } let messages = self.app.read_entity(&self.agent_thread, |thread, cx| { let mut messages = Vec::new(); for message in thread.messages().skip(message_count_before) { messages.push(Message { _role: message.role, _text: message.to_string(), tool_use: thread .tool_uses_for_message(message.id, cx) .into_iter() .map(|tool_use| ToolUse { name: tool_use.name.to_string(), value: tool_use.input, }) .collect(), }); } messages })?; let response = Response::new(messages); Ok(response) } pub fn edits(&self) -> HashMap, FileEdits> { self.app .read_entity(&self.agent_thread, |thread, cx| { let action_log = thread.action_log().read(cx); HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map( |(buffer, diff)| { let snapshot = buffer.read(cx).snapshot(); let file = snapshot.file().unwrap(); let diff = diff.read(cx); let base_text = diff.base_text().text(); let hunks = diff .hunks(&snapshot, cx) .map(|hunk| FileEditHunk { base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(), text: snapshot .text_for_range(hunk.range.clone()) .collect::(), status: hunk.status(), }) .collect(); (file.path().clone(), FileEdits { hunks }) }, )) }) .unwrap() } } #[derive(Debug)] pub struct Response { messages: Vec, } impl Response { pub fn new(messages: Vec) -> Self { Self { messages } } pub fn expect_tool( &self, tool_name: &'static str, cx: &mut ExampleContext, ) -> Result<&ToolUse> { let result = self.messages.iter().find_map(|msg| { msg.tool_use .iter() .find(|tool_use| tool_use.name == tool_name) }); cx.assert_some(result, format!("called `{}`", tool_name)) } pub fn tool_uses(&self) -> impl Iterator { self.messages.iter().flat_map(|msg| &msg.tool_use) } } #[derive(Debug)] pub struct Message { _role: Role, _text: String, tool_use: Vec, } #[derive(Debug)] pub struct ToolUse { pub name: String, value: serde_json::Value, } impl ToolUse { pub fn parse_input(&self) -> Result where Input: for<'de> serde::Deserialize<'de>, { serde_json::from_value::(self.value.clone()).map_err(|err| anyhow!(err)) } } #[derive(Debug)] pub struct FileEdits { hunks: Vec, } #[derive(Debug)] struct FileEditHunk { base_text: String, text: String, status: DiffHunkStatus, } impl FileEdits { pub fn has_added_line(&self, line: &str) -> bool { self.hunks.iter().any(|hunk| { hunk.status == DiffHunkStatus::added_none() && hunk.base_text.is_empty() && hunk.text.contains(line) }) } }