diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 21627f38e2..fcb67b0c6d 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -14,7 +14,7 @@ use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use strum::EnumIter; use thiserror::Error; -use crate::LanguageModelAvailability; +use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "provider", rename_all = "lowercase")] @@ -113,6 +113,13 @@ impl CloudModel { }, } } + + pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self { + Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema, + Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset, + } + } } #[derive(Error, Debug)] diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index bc3cc87181..376d5cd3f5 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -15,8 +15,8 @@ use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, - ZED_CLOUD_PROVIDER_ID, + LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, + LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, @@ -559,6 +559,10 @@ impl LanguageModel for CloudLanguageModel { self.model.availability() } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + self.model.tool_input_format() + } + fn max_token_count(&self) -> usize { self.model.max_token_count() }