agent2: Allow tools to be provider specific (#36111)
Our WebSearch tool requires access to a Zed provider Release Notes: - N/A
This commit is contained in:
parent
7f1a5c6ad7
commit
2b3dbe8815
2 changed files with 26 additions and 4 deletions
|
@ -15,9 +15,9 @@ use futures::{
|
|||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use project::Project;
|
||||
|
@ -681,10 +681,12 @@ impl Thread {
|
|||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
let provider_id = self.selected_model.provider_id();
|
||||
|
||||
Ok(self
|
||||
.tools
|
||||
.iter()
|
||||
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
if profile.is_tool_enabled(tool_name) {
|
||||
Some(tool)
|
||||
|
@ -782,6 +784,12 @@ where
|
|||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Some tools rely on a provider for the underlying billing or other reasons.
|
||||
/// Allow the tool to check if they are compatible, or should be filtered out.
|
||||
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
|
@ -808,6 +816,9 @@ pub trait AnyAgentTool {
|
|||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
|
||||
true
|
||||
}
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
|
@ -843,6 +854,10 @@ where
|
|||
Ok(json)
|
||||
}
|
||||
|
||||
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||
self.0.supported_provider(provider)
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
|
|
|
@ -5,7 +5,9 @@ use agent_client_protocol as acp;
|
|||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::WebSearchResponse;
|
||||
use gpui::{App, AppContext, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::prelude::*;
|
||||
|
@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
|
|||
"Searching the Web".into()
|
||||
}
|
||||
|
||||
/// We currently only support Zed Cloud as a provider.
|
||||
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||
provider == &ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue