agent2: Allow tools to be provider specific (#36111)

Our WebSearch tool requires access to a Zed provider

Release Notes:

- N/A
This commit is contained in:
Ben Brandt 2025-08-13 15:22:05 +02:00 committed by GitHub
parent 7f1a5c6ad7
commit 2b3dbe8815
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 4 deletions

View file

@ -15,9 +15,9 @@ use futures::{
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use log;
use project::Project;
@ -681,10 +681,12 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
let provider_id = self.selected_model.provider_id();
Ok(self
.tools
.iter()
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
.filter_map(|(tool_name, tool)| {
if profile.is_tool_enabled(tool_name) {
Some(tool)
@ -782,6 +784,12 @@ where
schemars::schema_for!(Self::Input)
}
/// Some tools rely on a provider for the underlying billing or other reasons.
/// Allow the tool to check if they are compatible, or should be filtered out.
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
true
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
@ -808,6 +816,9 @@ pub trait AnyAgentTool {
fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
true
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
@ -843,6 +854,10 @@ where
Ok(json)
}
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
self.0.supported_provider(provider)
}
fn run(
self: Arc<Self>,
input: serde_json::Value,

View file

@ -5,7 +5,9 @@ use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse;
use gpui::{App, AppContext, Task};
use language_model::LanguageModelToolResultContent;
use language_model::{
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::prelude::*;
@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
"Searching the Web".into()
}
/// We currently only support Zed Cloud as a provider.
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
provider == &ZED_CLOUD_PROVIDER_ID
}
fn run(
self: Arc<Self>,
input: Self::Input,