agent2: Allow tools to be provider specific (#36111)

Our WebSearch tool requires access to a Zed provider

Release Notes:

- N/A
This commit is contained in:
Ben Brandt 2025-08-13 15:22:05 +02:00 committed by GitHub
parent 7f1a5c6ad7
commit 2b3dbe8815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 4 deletions

View file

@ -15,9 +15,9 @@ use futures::{
use gpui::{App, Context, Entity, SharedString, Task}; use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
}; };
use log; use log;
use project::Project; use project::Project;
@ -681,10 +681,12 @@ impl Thread {
.profiles .profiles
.get(&self.profile_id) .get(&self.profile_id)
.context("profile not found")?; .context("profile not found")?;
let provider_id = self.selected_model.provider_id();
Ok(self Ok(self
.tools .tools
.iter() .iter()
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
.filter_map(|(tool_name, tool)| { .filter_map(|(tool_name, tool)| {
if profile.is_tool_enabled(tool_name) { if profile.is_tool_enabled(tool_name) {
Some(tool) Some(tool)
@ -782,6 +784,12 @@ where
schemars::schema_for!(Self::Input) schemars::schema_for!(Self::Input)
} }
/// Some tools rely on a provider for the underlying billing or other reasons.
/// Allow the tool to check if they are compatible, or should be filtered out.
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
true
}
/// Runs the tool with the provided input. /// Runs the tool with the provided input.
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
@ -808,6 +816,9 @@ pub trait AnyAgentTool {
fn kind(&self) -> acp::ToolKind; fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> SharedString; fn initial_title(&self, input: serde_json::Value) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
true
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -843,6 +854,10 @@ where
Ok(json) Ok(json)
} }
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
self.0.supported_provider(provider)
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,

View file

@ -5,7 +5,9 @@ use agent_client_protocol as acp;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse; use cloud_llm_client::WebSearchResponse;
use gpui::{App, AppContext, Task}; use gpui::{App, AppContext, Task};
use language_model::LanguageModelToolResultContent; use language_model::{
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ui::prelude::*; use ui::prelude::*;
@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
"Searching the Web".into() "Searching the Web".into()
} }
/// We currently only support Zed Cloud as a provider.
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
provider == &ZED_CLOUD_PROVIDER_ID
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: Self::Input, input: Self::Input,