420 lines
15 KiB
Rust
420 lines
15 KiB
Rust
use crate::templates::Templates;
|
|
use anyhow::{anyhow, Result};
|
|
use futures::{channel::mpsc, future};
|
|
use gpui::{App, Context, SharedString, Task};
|
|
use language_model::{
|
|
CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError,
|
|
LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
|
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason,
|
|
};
|
|
use schemars::{JsonSchema, Schema};
|
|
use serde::Deserialize;
|
|
use smol::stream::StreamExt;
|
|
use std::{collections::BTreeMap, sync::Arc};
|
|
use util::ResultExt;
|
|
|
|
#[derive(Debug)]
|
|
pub struct AgentMessage {
|
|
pub role: Role,
|
|
pub content: Vec<MessageContent>,
|
|
}
|
|
|
|
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
|
|
|
pub trait Prompt {
|
|
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
|
}
|
|
|
|
pub struct Thread {
|
|
messages: Vec<AgentMessage>,
|
|
completion_mode: CompletionMode,
|
|
/// Holds the task that handles agent interaction until the end of the turn.
|
|
/// Survives across multiple requests as the model performs tool calls and
|
|
/// we run tools, report their results.
|
|
running_turn: Option<Task<()>>,
|
|
system_prompts: Vec<Arc<dyn Prompt>>,
|
|
tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
|
|
templates: Arc<Templates>,
|
|
// project: Entity<Project>,
|
|
// action_log: Entity<ActionLog>,
|
|
}
|
|
|
|
impl Thread {
|
|
pub fn new(templates: Arc<Templates>) -> Self {
|
|
Self {
|
|
messages: Vec::new(),
|
|
completion_mode: CompletionMode::Normal,
|
|
system_prompts: Vec::new(),
|
|
running_turn: None,
|
|
tools: BTreeMap::default(),
|
|
templates,
|
|
}
|
|
}
|
|
|
|
pub fn set_mode(&mut self, mode: CompletionMode) {
|
|
self.completion_mode = mode;
|
|
}
|
|
|
|
pub fn messages(&self) -> &[AgentMessage] {
|
|
&self.messages
|
|
}
|
|
|
|
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
|
self.tools.insert(tool.name(), tool.erase());
|
|
}
|
|
|
|
pub fn remove_tool(&mut self, name: &str) -> bool {
|
|
self.tools.remove(name).is_some()
|
|
}
|
|
|
|
/// Sending a message results in the model streaming a response, which could include tool calls.
|
|
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
|
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
|
pub fn send(
|
|
&mut self,
|
|
model: Arc<dyn LanguageModel>,
|
|
content: impl Into<MessageContent>,
|
|
cx: &mut Context<Self>,
|
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
|
cx.notify();
|
|
let (events_tx, events_rx) =
|
|
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
|
|
|
let system_message = self.build_system_message(cx);
|
|
self.messages.extend(system_message);
|
|
|
|
self.messages.push(AgentMessage {
|
|
role: Role::User,
|
|
content: vec![content.into()],
|
|
});
|
|
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
|
let turn_result = async {
|
|
// Perform one request, then keep looping if the model makes tool calls.
|
|
let mut completion_intent = CompletionIntent::UserPrompt;
|
|
loop {
|
|
let request = thread.update(cx, |thread, cx| {
|
|
thread.build_completion_request(completion_intent, cx)
|
|
})?;
|
|
|
|
// println!(
|
|
// "request: {}",
|
|
// serde_json::to_string_pretty(&request).unwrap()
|
|
// );
|
|
|
|
// Stream events, appending to messages and collecting up tool uses.
|
|
let mut events = model.stream_completion(request, cx).await?;
|
|
let mut tool_uses = Vec::new();
|
|
while let Some(event) = events.next().await {
|
|
match event {
|
|
Ok(event) => {
|
|
thread
|
|
.update(cx, |thread, cx| {
|
|
tool_uses.extend(thread.handle_streamed_completion_event(
|
|
event,
|
|
events_tx.clone(),
|
|
cx,
|
|
));
|
|
})
|
|
.ok();
|
|
}
|
|
Err(error) => {
|
|
events_tx.unbounded_send(Err(error)).ok();
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// If there are no tool uses, the turn is done.
|
|
if tool_uses.is_empty() {
|
|
break;
|
|
}
|
|
|
|
// If there are tool uses, wait for their results to be
|
|
// computed, then send them together in a single message on
|
|
// the next loop iteration.
|
|
let tool_results = future::join_all(tool_uses).await;
|
|
thread
|
|
.update(cx, |thread, _cx| {
|
|
thread.messages.push(AgentMessage {
|
|
role: Role::User,
|
|
content: tool_results.into_iter().map(Into::into).collect(),
|
|
});
|
|
})
|
|
.ok();
|
|
completion_intent = CompletionIntent::ToolResults;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
.await;
|
|
|
|
if let Err(error) = turn_result {
|
|
events_tx.unbounded_send(Err(error)).ok();
|
|
}
|
|
}));
|
|
events_rx
|
|
}
|
|
|
|
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
|
let mut system_message = AgentMessage {
|
|
role: Role::System,
|
|
content: Vec::new(),
|
|
};
|
|
|
|
for prompt in &self.system_prompts {
|
|
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
|
system_message
|
|
.content
|
|
.push(MessageContent::Text(rendered_prompt));
|
|
}
|
|
}
|
|
|
|
(!system_message.content.is_empty()).then_some(system_message)
|
|
}
|
|
|
|
/// A helper method that's called on every streamed completion event.
|
|
/// Returns an optional tool result task, which the main agentic loop in
|
|
/// send will send back to the model when it resolves.
|
|
fn handle_streamed_completion_event(
|
|
&mut self,
|
|
event: LanguageModelCompletionEvent,
|
|
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
|
cx: &mut Context<Self>,
|
|
) -> Option<Task<LanguageModelToolResult>> {
|
|
use LanguageModelCompletionEvent::*;
|
|
events_tx.unbounded_send(Ok(event.clone())).ok();
|
|
|
|
match event {
|
|
Text(new_text) => self.handle_text_event(new_text, cx),
|
|
Thinking { text, signature } => {
|
|
todo!()
|
|
}
|
|
ToolUse(tool_use) => {
|
|
return self.handle_tool_use_event(tool_use, cx);
|
|
}
|
|
StartMessage { role, .. } => {
|
|
self.messages.push(AgentMessage {
|
|
role,
|
|
content: Vec::new(),
|
|
});
|
|
}
|
|
UsageUpdate(_) => {}
|
|
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
|
StatusUpdate(_completion_request_status) => {}
|
|
RedactedThinking { data } => todo!(),
|
|
ToolUseJsonParseError {
|
|
id,
|
|
tool_name,
|
|
raw_input,
|
|
json_parse_error,
|
|
} => todo!(),
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
|
match stop_reason {
|
|
StopReason::EndTurn | StopReason::ToolUse => {}
|
|
StopReason::MaxTokens => todo!(),
|
|
StopReason::Refusal => todo!(),
|
|
}
|
|
}
|
|
|
|
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
|
let last_message = self.last_assistant_message();
|
|
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
|
text.push_str(&new_text);
|
|
} else {
|
|
last_message.content.push(MessageContent::Text(new_text));
|
|
}
|
|
|
|
cx.notify();
|
|
}
|
|
|
|
fn handle_tool_use_event(
|
|
&mut self,
|
|
tool_use: LanguageModelToolUse,
|
|
cx: &mut Context<Self>,
|
|
) -> Option<Task<LanguageModelToolResult>> {
|
|
cx.notify();
|
|
|
|
let last_message = self.last_assistant_message();
|
|
|
|
// Ensure the last message ends in the current tool use
|
|
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
|
if let MessageContent::ToolUse(last_tool_use) = content {
|
|
if last_tool_use.id == tool_use.id {
|
|
*last_tool_use = tool_use.clone();
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
if push_new_tool_use {
|
|
last_message.content.push(tool_use.clone().into());
|
|
}
|
|
|
|
if !tool_use.is_input_complete {
|
|
return None;
|
|
}
|
|
|
|
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
|
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
|
|
|
Some(cx.foreground_executor().spawn(async move {
|
|
match pending_tool_result.await {
|
|
Ok(tool_output) => LanguageModelToolResult {
|
|
tool_use_id: tool_use.id,
|
|
tool_name: tool_use.name,
|
|
is_error: false,
|
|
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
|
output: None,
|
|
},
|
|
Err(error) => LanguageModelToolResult {
|
|
tool_use_id: tool_use.id,
|
|
tool_name: tool_use.name,
|
|
is_error: true,
|
|
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
|
output: None,
|
|
},
|
|
}
|
|
}))
|
|
} else {
|
|
Some(Task::ready(LanguageModelToolResult {
|
|
content: LanguageModelToolResultContent::Text(Arc::from(format!(
|
|
"No tool named {} exists",
|
|
tool_use.name
|
|
))),
|
|
tool_use_id: tool_use.id,
|
|
tool_name: tool_use.name,
|
|
is_error: true,
|
|
output: None,
|
|
}))
|
|
}
|
|
}
|
|
|
|
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
|
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
|
if self
|
|
.messages
|
|
.last()
|
|
.map_or(true, |m| m.role != Role::Assistant)
|
|
{
|
|
self.messages.push(AgentMessage {
|
|
role: Role::Assistant,
|
|
content: Vec::new(),
|
|
});
|
|
}
|
|
self.messages.last_mut().unwrap()
|
|
}
|
|
|
|
fn build_completion_request(
|
|
&self,
|
|
completion_intent: CompletionIntent,
|
|
cx: &mut App,
|
|
) -> LanguageModelRequest {
|
|
LanguageModelRequest {
|
|
thread_id: None,
|
|
prompt_id: None,
|
|
intent: Some(completion_intent),
|
|
mode: Some(self.completion_mode),
|
|
messages: self.build_request_messages(),
|
|
tools: self
|
|
.tools
|
|
.values()
|
|
.filter_map(|tool| {
|
|
Some(LanguageModelRequestTool {
|
|
name: tool.name().to_string(),
|
|
description: tool.description(cx).to_string(),
|
|
input_schema: tool
|
|
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
|
.log_err()?,
|
|
})
|
|
})
|
|
.collect(),
|
|
tool_choice: None,
|
|
stop: Vec::new(),
|
|
temperature: None,
|
|
}
|
|
}
|
|
|
|
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
|
self.messages
|
|
.iter()
|
|
.map(|message| LanguageModelRequestMessage {
|
|
role: message.role,
|
|
content: message.content.clone(),
|
|
cache: false,
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
pub trait AgentTool
|
|
where
|
|
Self: 'static + Sized,
|
|
{
|
|
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
|
|
|
fn name(&self) -> SharedString;
|
|
fn description(&self, _cx: &mut App) -> SharedString {
|
|
let schema = schemars::schema_for!(Self::Input);
|
|
SharedString::new(
|
|
schema
|
|
.get("description")
|
|
.and_then(|description| description.as_str())
|
|
.unwrap_or_default(),
|
|
)
|
|
}
|
|
|
|
/// Returns the JSON schema that describes the tool's input.
|
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
|
|
assistant_tools::root_schema_for::<Self::Input>(format)
|
|
}
|
|
|
|
/// Runs the tool with the provided input.
|
|
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
|
|
|
fn erase(self) -> Arc<dyn AgentToolErased> {
|
|
Arc::new(Erased(Arc::new(self)))
|
|
}
|
|
}
|
|
|
|
pub struct Erased<T>(T);
|
|
|
|
pub trait AgentToolErased {
|
|
fn name(&self) -> SharedString;
|
|
fn description(&self, cx: &mut App) -> SharedString;
|
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
|
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
|
}
|
|
|
|
impl<T> AgentToolErased for Erased<Arc<T>>
|
|
where
|
|
T: AgentTool,
|
|
{
|
|
fn name(&self) -> SharedString {
|
|
self.0.name()
|
|
}
|
|
|
|
fn description(&self, cx: &mut App) -> SharedString {
|
|
self.0.description(cx)
|
|
}
|
|
|
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
|
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
|
}
|
|
|
|
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
|
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
|
match parsed_input {
|
|
Ok(input) => self.0.clone().run(input, cx),
|
|
Err(error) => Task::ready(Err(anyhow!(error))),
|
|
}
|
|
}
|
|
}
|