agent2: Allow tools to be provider specific (#36111)
Our WebSearch tool requires access to a Zed provider Release Notes: - N/A
This commit is contained in:
parent
7f1a5c6ad7
commit
2b3dbe8815
2 changed files with 26 additions and 4 deletions
|
@ -15,9 +15,9 @@ use futures::{
|
||||||
use gpui::{App, Context, Entity, SharedString, Task};
|
use gpui::{App, Context, Entity, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||||
};
|
};
|
||||||
use log;
|
use log;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -681,10 +681,12 @@ impl Thread {
|
||||||
.profiles
|
.profiles
|
||||||
.get(&self.profile_id)
|
.get(&self.profile_id)
|
||||||
.context("profile not found")?;
|
.context("profile not found")?;
|
||||||
|
let provider_id = self.selected_model.provider_id();
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.tools
|
.tools
|
||||||
.iter()
|
.iter()
|
||||||
|
.filter(move |(_, tool)| tool.supported_provider(&provider_id))
|
||||||
.filter_map(|(tool_name, tool)| {
|
.filter_map(|(tool_name, tool)| {
|
||||||
if profile.is_tool_enabled(tool_name) {
|
if profile.is_tool_enabled(tool_name) {
|
||||||
Some(tool)
|
Some(tool)
|
||||||
|
@ -782,6 +784,12 @@ where
|
||||||
schemars::schema_for!(Self::Input)
|
schemars::schema_for!(Self::Input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Some tools rely on a provider for the underlying billing or other reasons.
|
||||||
|
/// Allow the tool to check if they are compatible, or should be filtered out.
|
||||||
|
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
/// Runs the tool with the provided input.
|
/// Runs the tool with the provided input.
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
|
@ -808,6 +816,9 @@ pub trait AnyAgentTool {
|
||||||
fn kind(&self) -> acp::ToolKind;
|
fn kind(&self) -> acp::ToolKind;
|
||||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||||
|
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
@ -843,6 +854,10 @@ where
|
||||||
Ok(json)
|
Ok(json)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||||
|
self.0.supported_provider(provider)
|
||||||
|
}
|
||||||
|
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
|
|
|
@ -5,7 +5,9 @@ use agent_client_protocol as acp;
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Result, anyhow};
|
||||||
use cloud_llm_client::WebSearchResponse;
|
use cloud_llm_client::WebSearchResponse;
|
||||||
use gpui::{App, AppContext, Task};
|
use gpui::{App, AppContext, Task};
|
||||||
use language_model::LanguageModelToolResultContent;
|
use language_model::{
|
||||||
|
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
|
||||||
|
};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
|
||||||
"Searching the Web".into()
|
"Searching the Web".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// We currently only support Zed Cloud as a provider.
|
||||||
|
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||||
|
provider == &ZED_CLOUD_PROVIDER_ID
|
||||||
|
}
|
||||||
|
|
||||||
fn run(
|
fn run(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue