Support profiles in agent2 (#36034)
We still need a profile selector. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
13bf45dd4a
commit
2444321756
15 changed files with 587 additions and 108 deletions
|
@ -1,7 +1,7 @@
|
|||
use crate::{SystemPromptTemplate, Template, Templates};
|
||||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use agent_settings::{AgentProfileId, AgentSettings};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
|
@ -126,6 +126,8 @@ pub struct Thread {
|
|||
running_turn: Option<Task<()>>,
|
||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
|
@ -137,16 +139,21 @@ impl Thread {
|
|||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_tool_uses: HashMap::default(),
|
||||
tools: BTreeMap::default(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
|
@ -179,6 +186,10 @@ impl Thread {
|
|||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
self.profile_id = profile_id;
|
||||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
self.running_turn.take();
|
||||
|
||||
|
@ -298,6 +309,7 @@ impl Thread {
|
|||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
@ -604,21 +616,23 @@ impl Thread {
|
|||
let messages = self.build_request_messages();
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||
tools
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
|
@ -639,6 +653,35 @@ impl Thread {
|
|||
request
|
||||
}
|
||||
|
||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
|
||||
Ok(self
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
if profile.is_tool_enabled(tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
||||
|(server_id, tools)| {
|
||||
tools.iter().filter_map(|(tool_name, tool)| {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
|
@ -686,7 +729,7 @@ where
|
|||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
fn description(&self) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
|
@ -722,13 +765,13 @@ where
|
|||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
pub llm_output: LanguageModelToolResultContent,
|
||||
pub raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn description(&self) -> SharedString;
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
|
@ -748,8 +791,8 @@ where
|
|||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
fn description(&self) -> SharedString {
|
||||
self.0.description()
|
||||
}
|
||||
|
||||
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue