Connect Native Agent responses to UI display
User-visible improvements: - Native Agent now shows AI responses in the chat interface - Uses configured default model from settings instead of random selection - Streams responses in real-time as the model generates them Technical changes: - Implemented response stream forwarding from Thread to AcpThread - Created Session struct to manage Thread and AcpThread together - Added proper SessionUpdate handling for text chunks and tool calls - Fixed model selection to use LanguageModelRegistry's default - Added comprehensive logging for debugging model interactions - Removed unused cwd parameter - native agent captures context differently than external agents
This commit is contained in:
parent
bc1f861d3f
commit
f81993574e
4 changed files with 307 additions and 55 deletions
|
@ -36,7 +36,7 @@ pub trait ModelSelector: 'static {
|
||||||
/// A task resolving to `Ok(())` on success or an error.
|
/// A task resolving to `Ok(())` on success or an error.
|
||||||
fn select_model(
|
fn select_model(
|
||||||
&self,
|
&self,
|
||||||
session_id: &acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<()>>;
|
) -> Task<Result<()>>;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use acp_thread::ModelSelector;
|
use acp_thread::ModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::Result;
|
use anyhow::{anyhow, Result};
|
||||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -11,15 +11,25 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{templates::Templates, Thread};
|
use crate::{templates::Templates, Thread};
|
||||||
|
|
||||||
|
/// Holds both the internal Thread and the AcpThread for a session
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Session {
|
||||||
|
/// The internal thread that processes messages
|
||||||
|
thread: Entity<Thread>,
|
||||||
|
/// The ACP thread that handles protocol communication
|
||||||
|
acp_thread: Entity<acp_thread::AcpThread>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct NativeAgent {
|
pub struct NativeAgent {
|
||||||
/// Session ID -> Thread entity mapping
|
/// Session ID -> Session mapping
|
||||||
sessions: HashMap<acp::SessionId, Entity<Thread>>,
|
sessions: HashMap<acp::SessionId, Session>,
|
||||||
/// Shared templates for all threads
|
/// Shared templates for all threads
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NativeAgent {
|
impl NativeAgent {
|
||||||
pub fn new(templates: Arc<Templates>) -> Self {
|
pub fn new(templates: Arc<Templates>) -> Self {
|
||||||
|
log::info!("Creating new NativeAgent");
|
||||||
Self {
|
Self {
|
||||||
sessions: HashMap::new(),
|
sessions: HashMap::new(),
|
||||||
templates,
|
templates,
|
||||||
|
@ -33,10 +43,12 @@ pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||||
|
|
||||||
impl ModelSelector for NativeAgentConnection {
|
impl ModelSelector for NativeAgentConnection {
|
||||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||||
|
log::debug!("NativeAgentConnection::list_models called");
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let registry = LanguageModelRegistry::read_global(cx);
|
let registry = LanguageModelRegistry::read_global(cx);
|
||||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||||
|
log::info!("Found {} available models", models.len());
|
||||||
if models.is_empty() {
|
if models.is_empty() {
|
||||||
Err(anyhow::anyhow!("No models available"))
|
Err(anyhow::anyhow!("No models available"))
|
||||||
} else {
|
} else {
|
||||||
|
@ -48,21 +60,26 @@ impl ModelSelector for NativeAgentConnection {
|
||||||
|
|
||||||
fn select_model(
|
fn select_model(
|
||||||
&self,
|
&self,
|
||||||
session_id: &acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
|
log::info!(
|
||||||
|
"Setting model for session {}: {:?}",
|
||||||
|
session_id,
|
||||||
|
model.name()
|
||||||
|
);
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
let session_id = session_id.clone();
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
agent.update(cx, |agent, cx| {
|
agent.update(cx, |agent, cx| {
|
||||||
if let Some(thread) = agent.sessions.get(&session_id) {
|
if let Some(session) = agent.sessions.get(&session_id) {
|
||||||
thread.update(cx, |thread, _| {
|
session.thread.update(cx, |thread, _cx| {
|
||||||
thread.selected_model = model;
|
thread.selected_model = model;
|
||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow::anyhow!("Session not found"))
|
Err(anyhow!("Session not found"))
|
||||||
}
|
}
|
||||||
})?
|
})?
|
||||||
})
|
})
|
||||||
|
@ -76,10 +93,12 @@ impl ModelSelector for NativeAgentConnection {
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
let session_id = session_id.clone();
|
let session_id = session_id.clone();
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let thread = agent
|
let session = agent
|
||||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
||||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
let selected = session
|
||||||
|
.thread
|
||||||
|
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||||
Ok(selected)
|
Ok(selected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -92,32 +111,64 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
cwd: &Path,
|
cwd: &Path,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||||
let _cwd = cwd.to_owned();
|
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
|
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
// Create Thread and store in Agent
|
log::debug!("Starting thread creation in async context");
|
||||||
let (session_id, _thread) =
|
// Create Thread
|
||||||
agent.update(cx, |agent, cx: &mut gpui::Context<NativeAgent>| {
|
let (session_id, thread) = agent.update(
|
||||||
// Fetch default model
|
cx,
|
||||||
let default_model = LanguageModelRegistry::read_global(cx)
|
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||||
.available_models(cx)
|
// Fetch default model from registry settings
|
||||||
.next()
|
let registry = LanguageModelRegistry::read_global(cx);
|
||||||
.unwrap_or_else(|| panic!("No default model available"));
|
|
||||||
|
// Log available models for debugging
|
||||||
|
let available_count = registry.available_models(cx).count();
|
||||||
|
log::debug!("Total available models: {}", available_count);
|
||||||
|
|
||||||
|
let default_model = registry
|
||||||
|
.default_model()
|
||||||
|
.map(|configured| {
|
||||||
|
log::info!(
|
||||||
|
"Using configured default model: {:?} from provider: {:?}",
|
||||||
|
configured.model.name(),
|
||||||
|
configured.provider.name()
|
||||||
|
);
|
||||||
|
configured.model
|
||||||
|
})
|
||||||
|
.ok_or_else(|| {
|
||||||
|
log::warn!("No default model configured in settings");
|
||||||
|
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||||
|
})?;
|
||||||
|
|
||||||
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
|
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
|
||||||
|
|
||||||
|
// Generate session ID
|
||||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||||
agent.sessions.insert(session_id.clone(), thread.clone());
|
log::info!("Created session with ID: {}", session_id);
|
||||||
(session_id, thread)
|
Ok((session_id, thread))
|
||||||
})?;
|
},
|
||||||
|
)??;
|
||||||
|
|
||||||
// Create AcpThread
|
// Create AcpThread
|
||||||
let acp_thread = cx.update(|cx| {
|
let acp_thread = cx.update(|cx| {
|
||||||
cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
|
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
// Store the session
|
||||||
|
agent.update(cx, |agent, _cx| {
|
||||||
|
agent.sessions.insert(
|
||||||
|
session_id,
|
||||||
|
Session {
|
||||||
|
thread,
|
||||||
|
acp_thread: acp_thread.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(acp_thread)
|
Ok(acp_thread)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -137,28 +188,155 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||||
let session_id = params.session_id.clone();
|
let session_id = params.session_id.clone();
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
|
log::info!("Received prompt request for session: {}", session_id);
|
||||||
|
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
// Get thread
|
// Get session
|
||||||
let thread: Entity<Thread> = agent
|
let session = agent
|
||||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
.read_with(cx, |agent, _| {
|
||||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
agent.sessions.get(&session_id).map(|s| Session {
|
||||||
|
thread: s.thread.clone(),
|
||||||
|
acp_thread: s.acp_thread.clone(),
|
||||||
|
})
|
||||||
|
})?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
log::error!("Session not found: {}", session_id);
|
||||||
|
anyhow::anyhow!("Session not found")
|
||||||
|
})?;
|
||||||
|
log::debug!("Found session for: {}", session_id);
|
||||||
|
|
||||||
// Convert prompt to message
|
// Convert prompt to message
|
||||||
let message = convert_prompt_to_message(params.prompt);
|
let message = convert_prompt_to_message(params.prompt);
|
||||||
|
log::info!("Converted prompt to message: {} chars", message.len());
|
||||||
|
log::debug!("Message content: {}", message);
|
||||||
|
|
||||||
// Get model using the ModelSelector capability (always available for agent2)
|
// Get model using the ModelSelector capability (always available for agent2)
|
||||||
// Get the selected model from the thread directly
|
// Get the selected model from the thread directly
|
||||||
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
let model = session
|
||||||
|
.thread
|
||||||
|
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||||
|
|
||||||
// Send to thread
|
// Send to thread
|
||||||
thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||||
|
let response_stream = session
|
||||||
|
.thread
|
||||||
|
.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
||||||
|
|
||||||
|
// Handle response stream and forward to session.acp_thread
|
||||||
|
let acp_thread = session.acp_thread.clone();
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
use futures::StreamExt;
|
||||||
|
use language_model::LanguageModelCompletionEvent;
|
||||||
|
|
||||||
|
let mut response_stream = response_stream;
|
||||||
|
|
||||||
|
while let Some(result) = response_stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(event) => {
|
||||||
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
|
||||||
|
match event {
|
||||||
|
LanguageModelCompletionEvent::Text(text) => {
|
||||||
|
// Send text chunk as agent message
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.handle_session_update(
|
||||||
|
acp::SessionUpdate::AgentMessageChunk {
|
||||||
|
content: acp::ContentBlock::Text(
|
||||||
|
acp::TextContent {
|
||||||
|
text: text.into(),
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})??;
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||||
|
// Convert LanguageModelToolUse to ACP ToolCall
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.handle_session_update(
|
||||||
|
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||||
|
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||||
|
label: tool_use.name.to_string(),
|
||||||
|
kind: acp::ToolKind::Other,
|
||||||
|
status: acp::ToolCallStatus::Pending,
|
||||||
|
content: vec![],
|
||||||
|
locations: vec![],
|
||||||
|
raw_input: Some(tool_use.input),
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})??;
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||||
|
log::debug!("Started new assistant message");
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::UsageUpdate(usage) => {
|
||||||
|
log::debug!("Token usage update: {:?}", usage);
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::Thinking { text, .. } => {
|
||||||
|
// Send thinking text as agent thought chunk
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.handle_session_update(
|
||||||
|
acp::SessionUpdate::AgentThoughtChunk {
|
||||||
|
content: acp::ContentBlock::Text(
|
||||||
|
acp::TextContent {
|
||||||
|
text: text.into(),
|
||||||
|
annotations: None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})??;
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::StatusUpdate(status) => {
|
||||||
|
log::trace!("Status update: {:?}", status);
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::Stop(stop_reason) => {
|
||||||
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::RedactedThinking { .. } => {
|
||||||
|
log::trace!("Redacted thinking event");
|
||||||
|
}
|
||||||
|
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
|
id,
|
||||||
|
tool_name,
|
||||||
|
raw_input,
|
||||||
|
json_parse_error,
|
||||||
|
} => {
|
||||||
|
log::error!(
|
||||||
|
"Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}",
|
||||||
|
tool_name,
|
||||||
|
id,
|
||||||
|
json_parse_error,
|
||||||
|
raw_input
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Error in model response stream: {:?}", e);
|
||||||
|
// TODO: Consider sending an error message to the UI
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Response stream completed");
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
|
log::info!("Successfully sent prompt to thread and started response handler");
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
|
log::info!("Cancelling session: {}", session_id);
|
||||||
self.0.update(cx, |agent, _cx| {
|
self.0.update(cx, |agent, _cx| {
|
||||||
agent.sessions.remove(session_id);
|
agent.sessions.remove(session_id);
|
||||||
});
|
});
|
||||||
|
@ -167,23 +345,29 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
|
|
||||||
/// Convert ACP content blocks to a message string
|
/// Convert ACP content blocks to a message string
|
||||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||||
|
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||||
let mut message = String::new();
|
let mut message = String::new();
|
||||||
|
|
||||||
for block in blocks {
|
for block in blocks {
|
||||||
match block {
|
match block {
|
||||||
acp::ContentBlock::Text(text) => {
|
acp::ContentBlock::Text(text) => {
|
||||||
|
log::trace!("Processing text block: {} chars", text.text.len());
|
||||||
message.push_str(&text.text);
|
message.push_str(&text.text);
|
||||||
}
|
}
|
||||||
acp::ContentBlock::ResourceLink(link) => {
|
acp::ContentBlock::ResourceLink(link) => {
|
||||||
|
log::trace!("Processing resource link: {}", link.uri);
|
||||||
message.push_str(&format!(" @{} ", link.uri));
|
message.push_str(&format!(" @{} ", link.uri));
|
||||||
}
|
}
|
||||||
acp::ContentBlock::Image(_) => {
|
acp::ContentBlock::Image(_) => {
|
||||||
|
log::trace!("Processing image block");
|
||||||
message.push_str(" [image] ");
|
message.push_str(" [image] ");
|
||||||
}
|
}
|
||||||
acp::ContentBlock::Audio(_) => {
|
acp::ContentBlock::Audio(_) => {
|
||||||
|
log::trace!("Processing audio block");
|
||||||
message.push_str(" [audio] ");
|
message.push_str(" [audio] ");
|
||||||
}
|
}
|
||||||
acp::ContentBlock::Resource(resource) => {
|
acp::ContentBlock::Resource(resource) => {
|
||||||
|
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,15 +35,22 @@ impl AgentServer for NativeAgentServer {
|
||||||
_project: &Entity<Project>,
|
_project: &Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
||||||
|
log::info!(
|
||||||
|
"NativeAgentServer::connect called for path: {:?}",
|
||||||
|
_root_dir
|
||||||
|
);
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
|
log::debug!("Creating templates for native agent");
|
||||||
// Create templates (you might want to load these from files or resources)
|
// Create templates (you might want to load these from files or resources)
|
||||||
let templates = Templates::new();
|
let templates = Templates::new();
|
||||||
|
|
||||||
// Create the native agent
|
// Create the native agent
|
||||||
|
log::debug!("Creating native agent entity");
|
||||||
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
||||||
|
|
||||||
// Create the connection wrapper
|
// Create the connection wrapper
|
||||||
let connection = NativeAgentConnection(agent);
|
let connection = NativeAgentConnection(agent);
|
||||||
|
log::info!("NativeAgentServer connection established successfully");
|
||||||
|
|
||||||
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
|
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
|
||||||
})
|
})
|
||||||
|
|
|
@ -9,6 +9,7 @@ use language_model::{
|
||||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||||
};
|
};
|
||||||
|
use log;
|
||||||
use schemars::{JsonSchema, Schema};
|
use schemars::{JsonSchema, Schema};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
|
@ -38,7 +39,6 @@ pub struct Thread {
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
pub selected_model: Arc<dyn LanguageModel>,
|
pub selected_model: Arc<dyn LanguageModel>,
|
||||||
// project: Entity<Project>,
|
|
||||||
// action_log: Entity<ActionLog>,
|
// action_log: Entity<ActionLog>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,22 +80,36 @@ impl Thread {
|
||||||
content: impl Into<MessageContent>,
|
content: impl Into<MessageContent>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||||
|
let content = content.into();
|
||||||
|
log::info!("Thread::send called with model: {:?}", model.name());
|
||||||
|
log::debug!("Thread::send content: {:?}", content);
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
let (events_tx, events_rx) =
|
let (events_tx, events_rx) =
|
||||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||||
|
|
||||||
let system_message = self.build_system_message(cx);
|
let system_message = self.build_system_message(cx);
|
||||||
|
log::debug!(
|
||||||
|
"System messages count: {}",
|
||||||
|
if system_message.is_some() { 1 } else { 0 }
|
||||||
|
);
|
||||||
self.messages.extend(system_message);
|
self.messages.extend(system_message);
|
||||||
|
|
||||||
self.messages.push(AgentMessage {
|
self.messages.push(AgentMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: vec![content.into()],
|
content: vec![content],
|
||||||
});
|
});
|
||||||
|
log::info!("Total messages in thread: {}", self.messages.len());
|
||||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||||
|
log::info!("Starting agent turn execution");
|
||||||
let turn_result = async {
|
let turn_result = async {
|
||||||
// Perform one request, then keep looping if the model makes tool calls.
|
// Perform one request, then keep looping if the model makes tool calls.
|
||||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||||
loop {
|
loop {
|
||||||
|
log::debug!(
|
||||||
|
"Building completion request with intent: {:?}",
|
||||||
|
completion_intent
|
||||||
|
);
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.build_completion_request(completion_intent, cx)
|
thread.build_completion_request(completion_intent, cx)
|
||||||
})?;
|
})?;
|
||||||
|
@ -106,11 +120,14 @@ impl Thread {
|
||||||
// );
|
// );
|
||||||
|
|
||||||
// Stream events, appending to messages and collecting up tool uses.
|
// Stream events, appending to messages and collecting up tool uses.
|
||||||
|
log::info!("Calling model.stream_completion");
|
||||||
let mut events = model.stream_completion(request, cx).await?;
|
let mut events = model.stream_completion(request, cx).await?;
|
||||||
|
log::debug!("Stream completion started successfully");
|
||||||
let mut tool_uses = Vec::new();
|
let mut tool_uses = Vec::new();
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
|
log::trace!("Received completion event: {:?}", event);
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||||
|
@ -122,6 +139,7 @@ impl Thread {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
|
log::error!("Error in completion stream: {:?}", error);
|
||||||
events_tx.unbounded_send(Err(error)).ok();
|
events_tx.unbounded_send(Err(error)).ok();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -130,13 +148,16 @@ impl Thread {
|
||||||
|
|
||||||
// If there are no tool uses, the turn is done.
|
// If there are no tool uses, the turn is done.
|
||||||
if tool_uses.is_empty() {
|
if tool_uses.is_empty() {
|
||||||
|
log::info!("No tool uses found, completing turn");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||||
|
|
||||||
// If there are tool uses, wait for their results to be
|
// If there are tool uses, wait for their results to be
|
||||||
// computed, then send them together in a single message on
|
// computed, then send them together in a single message on
|
||||||
// the next loop iteration.
|
// the next loop iteration.
|
||||||
let tool_results = future::join_all(tool_uses).await;
|
let tool_results = future::join_all(tool_uses).await;
|
||||||
|
log::debug!("Tool execution completed, {} results", tool_results.len());
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, _cx| {
|
.update(cx, |thread, _cx| {
|
||||||
thread.messages.push(AgentMessage {
|
thread.messages.push(AgentMessage {
|
||||||
|
@ -156,13 +177,17 @@ impl Thread {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
if let Err(error) = turn_result {
|
if let Err(error) = turn_result {
|
||||||
|
log::error!("Turn execution failed: {:?}", error);
|
||||||
events_tx.unbounded_send(Err(error)).ok();
|
events_tx.unbounded_send(Err(error)).ok();
|
||||||
|
} else {
|
||||||
|
log::info!("Turn execution completed successfully");
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
events_rx
|
events_rx
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
||||||
|
log::debug!("Building system message");
|
||||||
let mut system_message = AgentMessage {
|
let mut system_message = AgentMessage {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
content: Vec::new(),
|
content: Vec::new(),
|
||||||
|
@ -176,7 +201,9 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(!system_message.content.is_empty()).then_some(system_message)
|
let result = (!system_message.content.is_empty()).then_some(system_message);
|
||||||
|
log::debug!("System message built: {}", result.is_some());
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A helper method that's called on every streamed completion event.
|
/// A helper method that's called on every streamed completion event.
|
||||||
|
@ -188,6 +215,7 @@ impl Thread {
|
||||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Option<Task<LanguageModelToolResult>> {
|
||||||
|
log::trace!("Handling streamed completion event: {:?}", event);
|
||||||
use LanguageModelCompletionEvent::*;
|
use LanguageModelCompletionEvent::*;
|
||||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||||
|
|
||||||
|
@ -329,41 +357,74 @@ impl Thread {
|
||||||
completion_intent: CompletionIntent,
|
completion_intent: CompletionIntent,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> LanguageModelRequest {
|
) -> LanguageModelRequest {
|
||||||
LanguageModelRequest {
|
log::debug!("Building completion request");
|
||||||
|
log::debug!("Completion intent: {:?}", completion_intent);
|
||||||
|
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||||
|
|
||||||
|
let messages = self.build_request_messages();
|
||||||
|
log::info!("Request will include {} messages", messages.len());
|
||||||
|
|
||||||
|
let tools: Vec<LanguageModelRequestTool> = self
|
||||||
|
.tools
|
||||||
|
.values()
|
||||||
|
.filter_map(|tool| {
|
||||||
|
let tool_name = tool.name().to_string();
|
||||||
|
log::trace!("Including tool: {}", tool_name);
|
||||||
|
Some(LanguageModelRequestTool {
|
||||||
|
name: tool_name,
|
||||||
|
description: tool.description(cx).to_string(),
|
||||||
|
input_schema: tool
|
||||||
|
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||||
|
.log_err()?,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
log::info!("Request includes {} tools", tools.len());
|
||||||
|
|
||||||
|
let request = LanguageModelRequest {
|
||||||
thread_id: None,
|
thread_id: None,
|
||||||
prompt_id: None,
|
prompt_id: None,
|
||||||
intent: Some(completion_intent),
|
intent: Some(completion_intent),
|
||||||
mode: Some(self.completion_mode),
|
mode: Some(self.completion_mode),
|
||||||
messages: self.build_request_messages(),
|
messages,
|
||||||
tools: self
|
tools,
|
||||||
.tools
|
|
||||||
.values()
|
|
||||||
.filter_map(|tool| {
|
|
||||||
Some(LanguageModelRequestTool {
|
|
||||||
name: tool.name().to_string(),
|
|
||||||
description: tool.description(cx).to_string(),
|
|
||||||
input_schema: tool
|
|
||||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
|
||||||
.log_err()?,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
temperature: None,
|
temperature: None,
|
||||||
thinking_allowed: false,
|
thinking_allowed: false,
|
||||||
}
|
};
|
||||||
|
|
||||||
|
log::debug!("Completion request built successfully");
|
||||||
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||||
self.messages
|
log::trace!(
|
||||||
|
"Building request messages from {} thread messages",
|
||||||
|
self.messages.len()
|
||||||
|
);
|
||||||
|
let messages = self
|
||||||
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|message| LanguageModelRequestMessage {
|
.map(|message| {
|
||||||
role: message.role,
|
log::trace!(
|
||||||
content: message.content.clone(),
|
" - {} message with {} content items",
|
||||||
cache: false,
|
match message.role {
|
||||||
|
Role::System => "System",
|
||||||
|
Role::User => "User",
|
||||||
|
Role::Assistant => "Assistant",
|
||||||
|
},
|
||||||
|
message.content.len()
|
||||||
|
);
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: message.role,
|
||||||
|
content: message.content.clone(),
|
||||||
|
cache: false,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect()
|
.collect();
|
||||||
|
messages
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue