language_models: Dynamically detect Copilot Chat models (#29027)

I noticed the discussion in #28881 and had been thinking along exactly the
same lines a few days prior.

This implementation should preserve existing functionality fairly well.

I've added a dependency (serde_with) so the deserializer can skip any model
entry that fails to deserialize, which could happen if, for instance, a future
provider is added to the API response. Without this change, a single
unrecognized entry could break deserialization of the entire model list. If an
extra dependency isn't desired, a manual implementation could be used instead.
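
As a rough illustration of the skip-on-error behaviour described above, here is a minimal sketch assuming serde_with's `VecSkipError` adapter and a deliberately trimmed-down `Model` type (the merged diff below ends up using a manual `deserialize_with` helper instead):

```rust
// Minimal sketch only: the real Model struct in the diff below has more fields.
use serde::Deserialize;
use serde_with::{VecSkipError, serde_as};

#[derive(Debug, Deserialize)]
struct Model {
    id: String,
    name: String,
}

#[serde_as]
#[derive(Debug, Deserialize)]
struct ModelSchema {
    // Entries that fail to deserialize (e.g. one added for an unknown future
    // provider) are dropped instead of failing the whole response.
    #[serde_as(as = "VecSkipError<_>")]
    data: Vec<Model>,
}

fn main() {
    let json = r#"{"data":[{"id":"gpt-4.1","name":"GPT-4.1"},{"unexpected":123}]}"#;
    let schema: ModelSchema = serde_json::from_str(json).unwrap();
    assert_eq!(schema.data.len(), 1); // the malformed entry was skipped
    println!("{:?}", schema.data);
}
```

The trade-off is simply an extra dependency versus a hand-written helper; both approaches keep a bad entry from poisoning the whole list.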

- Closes #29369 

Release Notes:

- Dynamically detect available Copilot Chat models, including all models
with tool support

---------

Co-authored-by: AidanV <aidanvanduyne@gmail.com>
Co-authored-by: imumesh18 <umesh4257@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Liam 2025-05-12 11:28:41 +00:00 committed by GitHub
parent 634b275931
commit f14e48d202
5 changed files with 320 additions and 182 deletions

Cargo.lock (generated)

@@ -3309,6 +3309,7 @@ dependencies = [
  "http_client",
  "indoc",
  "inline_completion",
+ "itertools 0.14.0",
  "language",
  "log",
  "lsp",
@@ -3318,11 +3319,9 @@ dependencies = [
  "paths",
  "project",
  "rpc",
- "schemars",
  "serde",
  "serde_json",
  "settings",
- "strum 0.27.1",
  "task",
  "theme",
  "ui",

Cargo.toml (copilot crate)

@@ -14,7 +14,6 @@ doctest = false

 [features]
 default = []
-schemars = ["dep:schemars"]
 test-support = [
     "collections/test-support",
     "gpui/test-support",
@@ -43,16 +42,15 @@ node_runtime.workspace = true
 parking_lot.workspace = true
 paths.workspace = true
 project.workspace = true
-schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
-strum.workspace = true
 task.workspace = true
 ui.workspace = true
 util.workspace = true
 workspace.workspace = true
 workspace-hack.workspace = true
+itertools.workspace = true

 [target.'cfg(windows)'.dependencies]
 async-std = { version = "1.12.0", features = ["unstable"] }

copilot_chat.rs (copilot crate)

@@ -9,13 +9,20 @@ use fs::Fs;
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use gpui::{App, AsyncApp, Global, prelude::*};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use itertools::Itertools;
 use paths::home_dir;
 use serde::{Deserialize, Serialize};
 use settings::watch_config_dir;
-use strum::EnumIter;

 pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
 pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
+pub const COPILOT_CHAT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
+
+// Copilot's base model; defined by Microsoft in premium requests table
+// This will be moved to the front of the Copilot model list, and will be used for
+// 'fast' requests (e.g. title generation)
+// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
+const DEFAULT_MODEL_ID: &str = "gpt-4.1";

 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
@@ -25,132 +32,104 @@ pub enum Role {
     System,
 }

-#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
-pub enum Model {
-    #[default]
-    #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
-    Gpt4o,
-    #[serde(alias = "gpt-4", rename = "gpt-4")]
-    Gpt4,
-    #[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
-    Gpt4_1,
-    #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
-    Gpt3_5Turbo,
-    #[serde(alias = "o1", rename = "o1")]
-    O1,
-    #[serde(alias = "o1-mini", rename = "o3-mini")]
-    O3Mini,
-    #[serde(alias = "o3", rename = "o3")]
-    O3,
-    #[serde(alias = "o4-mini", rename = "o4-mini")]
-    O4Mini,
-    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
-    Claude3_5Sonnet,
-    #[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
-    Claude3_7Sonnet,
-    #[serde(
-        alias = "claude-3.7-sonnet-thought",
-        rename = "claude-3.7-sonnet-thought"
-    )]
-    Claude3_7SonnetThinking,
-    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
-    Gemini20Flash,
-    #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
-    Gemini25Pro,
-}
+#[derive(Deserialize)]
+struct ModelSchema {
+    #[serde(deserialize_with = "deserialize_models_skip_errors")]
+    data: Vec<Model>,
+}
+
+fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
+    let models = raw_values
+        .into_iter()
+        .filter_map(|value| match serde_json::from_value::<Model>(value) {
+            Ok(model) => Some(model),
+            Err(err) => {
+                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
+                None
+            }
+        })
+        .collect();
+
+    Ok(models)
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct Model {
+    capabilities: ModelCapabilities,
+    id: String,
+    name: String,
+    policy: Option<ModelPolicy>,
+    vendor: ModelVendor,
+    model_picker_enabled: bool,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct ModelCapabilities {
+    family: String,
+    #[serde(default)]
+    limits: ModelLimits,
+    supports: ModelSupportedFeatures,
+}
+
+#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct ModelLimits {
+    #[serde(default)]
+    max_context_window_tokens: usize,
+    #[serde(default)]
+    max_output_tokens: usize,
+    #[serde(default)]
+    max_prompt_tokens: usize,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct ModelPolicy {
+    state: String,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct ModelSupportedFeatures {
+    #[serde(default)]
+    streaming: bool,
+    #[serde(default)]
+    tool_calls: bool,
+}
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub enum ModelVendor {
+    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
+    #[serde(alias = "Azure OpenAI")]
+    OpenAI,
+    Google,
+    Anthropic,
+}

 impl Model {
-    pub fn default_fast() -> Self {
-        Self::Claude3_7Sonnet
-    }
-
     pub fn uses_streaming(&self) -> bool {
-        match self {
-            Self::Gpt4o
-            | Self::Gpt4
-            | Self::Gpt4_1
-            | Self::Gpt3_5Turbo
-            | Self::O3
-            | Self::O4Mini
-            | Self::Claude3_5Sonnet
-            | Self::Claude3_7Sonnet
-            | Self::Claude3_7SonnetThinking => true,
-            Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
-        }
+        self.capabilities.supports.streaming
     }

-    pub fn from_id(id: &str) -> Result<Self> {
-        match id {
-            "gpt-4o" => Ok(Self::Gpt4o),
-            "gpt-4" => Ok(Self::Gpt4),
-            "gpt-4.1" => Ok(Self::Gpt4_1),
-            "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
-            "o1" => Ok(Self::O1),
-            "o3-mini" => Ok(Self::O3Mini),
-            "o3" => Ok(Self::O3),
-            "o4-mini" => Ok(Self::O4Mini),
-            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
-            "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
-            "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
-            "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
-            "gemini-2.5-pro" => Ok(Self::Gemini25Pro),
-            _ => Err(anyhow!("Invalid model id: {}", id)),
-        }
+    pub fn id(&self) -> &str {
+        self.id.as_str()
     }

-    pub fn id(&self) -> &'static str {
-        match self {
-            Self::Gpt3_5Turbo => "gpt-3.5-turbo",
-            Self::Gpt4 => "gpt-4",
-            Self::Gpt4_1 => "gpt-4.1",
-            Self::Gpt4o => "gpt-4o",
-            Self::O3Mini => "o3-mini",
-            Self::O1 => "o1",
-            Self::O3 => "o3",
-            Self::O4Mini => "o4-mini",
-            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
-            Self::Claude3_7Sonnet => "claude-3-7-sonnet",
-            Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
-            Self::Gemini20Flash => "gemini-2.0-flash-001",
-            Self::Gemini25Pro => "gemini-2.5-pro",
-        }
+    pub fn display_name(&self) -> &str {
+        self.name.as_str()
     }

-    pub fn display_name(&self) -> &'static str {
-        match self {
-            Self::Gpt3_5Turbo => "GPT-3.5",
-            Self::Gpt4 => "GPT-4",
-            Self::Gpt4_1 => "GPT-4.1",
-            Self::Gpt4o => "GPT-4o",
-            Self::O3Mini => "o3-mini",
-            Self::O1 => "o1",
-            Self::O3 => "o3",
-            Self::O4Mini => "o4-mini",
-            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
-            Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
-            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
-            Self::Gemini20Flash => "Gemini 2.0 Flash",
-            Self::Gemini25Pro => "Gemini 2.5 Pro",
-        }
-    }
-
     pub fn max_token_count(&self) -> usize {
-        match self {
-            Self::Gpt4o => 64_000,
-            Self::Gpt4 => 32_768,
-            Self::Gpt4_1 => 128_000,
-            Self::Gpt3_5Turbo => 12_288,
-            Self::O3Mini => 64_000,
-            Self::O1 => 20_000,
-            Self::O3 => 128_000,
-            Self::O4Mini => 128_000,
-            Self::Claude3_5Sonnet => 200_000,
-            Self::Claude3_7Sonnet => 90_000,
-            Self::Claude3_7SonnetThinking => 90_000,
-            Self::Gemini20Flash => 128_000,
-            Self::Gemini25Pro => 128_000,
-        }
+        self.capabilities.limits.max_prompt_tokens
+    }
+
+    pub fn supports_tools(&self) -> bool {
+        self.capabilities.supports.tool_calls
+    }
+
+    pub fn vendor(&self) -> ModelVendor {
+        self.vendor
     }
 }
@@ -160,7 +139,7 @@ pub struct Request {
     pub n: usize,
     pub stream: bool,
     pub temperature: f32,
-    pub model: Model,
+    pub model: String,
     pub messages: Vec<ChatMessage>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<Tool>,
@@ -306,6 +285,7 @@ impl Global for GlobalCopilotChat {}

 pub struct CopilotChat {
     oauth_token: Option<String>,
     api_token: Option<ApiToken>,
+    models: Option<Vec<Model>>,
     client: Arc<dyn HttpClient>,
 }
@@ -342,31 +322,56 @@ impl CopilotChat {
         let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
         let dir_path = copilot_chat_config_dir();

-        cx.spawn(async move |cx| {
-            let mut parent_watch_rx = watch_config_dir(
-                cx.background_executor(),
-                fs.clone(),
-                dir_path.clone(),
-                config_paths,
-            );
-            while let Some(contents) = parent_watch_rx.next().await {
-                let oauth_token = extract_oauth_token(contents);
-                cx.update(|cx| {
-                    if let Some(this) = Self::global(cx).as_ref() {
-                        this.update(cx, |this, cx| {
-                            this.oauth_token = oauth_token;
-                            cx.notify();
-                        });
-                    }
-                })?;
-            }
-            anyhow::Ok(())
+        cx.spawn({
+            let client = client.clone();
+            async move |cx| {
+                let mut parent_watch_rx = watch_config_dir(
+                    cx.background_executor(),
+                    fs.clone(),
+                    dir_path.clone(),
+                    config_paths,
+                );
+                while let Some(contents) = parent_watch_rx.next().await {
+                    let oauth_token = extract_oauth_token(contents);
+                    cx.update(|cx| {
+                        if let Some(this) = Self::global(cx).as_ref() {
+                            this.update(cx, |this, cx| {
+                                this.oauth_token = oauth_token.clone();
+                                cx.notify();
+                            });
+                        }
+                    })?;
+
+                    if let Some(ref oauth_token) = oauth_token {
+                        let api_token = request_api_token(oauth_token, client.clone()).await?;
+                        cx.update(|cx| {
+                            if let Some(this) = Self::global(cx).as_ref() {
+                                this.update(cx, |this, cx| {
+                                    this.api_token = Some(api_token.clone());
+                                    cx.notify();
+                                });
+                            }
+                        })?;
+
+                        let models = get_models(api_token.api_key, client.clone()).await?;
+                        cx.update(|cx| {
+                            if let Some(this) = Self::global(cx).as_ref() {
+                                this.update(cx, |this, cx| {
+                                    this.models = Some(models);
+                                    cx.notify();
+                                });
+                            }
+                        })?;
+                    }
+                }
+                anyhow::Ok(())
+            }
         })
         .detach_and_log_err(cx);

         Self {
             oauth_token: None,
             api_token: None,
+            models: None,
             client,
         }
     }
@@ -375,6 +380,10 @@ impl CopilotChat {
         self.oauth_token.is_some()
     }

+    pub fn models(&self) -> Option<&[Model]> {
+        self.models.as_deref()
+    }
+
     pub async fn stream_completion(
         request: Request,
         mut cx: AsyncApp,
@@ -409,6 +418,61 @@ impl CopilotChat {
     }
 }

+async fn get_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
+    let all_models = request_models(api_token, client).await?;
+    let mut models: Vec<Model> = all_models
+        .into_iter()
+        .filter(|model| {
+            // Ensure the user has access to the model; a policy is present only for models that
+            // must be enabled in the GitHub dashboard
+            model.model_picker_enabled
+                && model
+                    .policy
+                    .as_ref()
+                    .is_none_or(|policy| policy.state == "enabled")
+        })
+        // The first model from the API response, in any given family, appears to be the non-tagged
+        // model, which is likely the best choice (e.g. gpt-4o rather than gpt-4o-2024-11-20)
+        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
+        .collect();
+
+    if let Some(default_model_position) =
+        models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
+    {
+        let default_model = models.remove(default_model_position);
+        models.insert(0, default_model);
+    }
+
+    Ok(models)
+}
+
+async fn request_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
+    let request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(COPILOT_CHAT_MODELS_URL)
+        .header("Authorization", format!("Bearer {}", api_token))
+        .header("Content-Type", "application/json")
+        .header("Copilot-Integration-Id", "vscode-chat");
+
+    let request = request_builder.body(AsyncBody::empty())?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+
+        let body_str = std::str::from_utf8(&body)?;
+
+        let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
+
+        Ok(models)
+    } else {
+        Err(anyhow!("Failed to request models: {}", response.status()))
+    }
+}
+
 async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
     let request_builder = HttpRequest::builder()
         .method(Method::GET)
@@ -527,3 +591,82 @@ async fn stream_completion(
         Ok(futures::stream::once(async move { Ok(response) }).boxed())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_resilient_model_schema_deserialize() {
+        let json = r#"{
+            "data": [
+                {
+                    "capabilities": {
+                        "family": "gpt-4",
+                        "limits": {
+                            "max_context_window_tokens": 32768,
+                            "max_output_tokens": 4096,
+                            "max_prompt_tokens": 32768
+                        },
+                        "object": "model_capabilities",
+                        "supports": { "streaming": true, "tool_calls": true },
+                        "tokenizer": "cl100k_base",
+                        "type": "chat"
+                    },
+                    "id": "gpt-4",
+                    "model_picker_enabled": false,
+                    "name": "GPT 4",
+                    "object": "model",
+                    "preview": false,
+                    "vendor": "Azure OpenAI",
+                    "version": "gpt-4-0613"
+                },
+                {
+                    "some-unknown-field": 123
+                },
+                {
+                    "capabilities": {
+                        "family": "claude-3.7-sonnet",
+                        "limits": {
+                            "max_context_window_tokens": 200000,
+                            "max_output_tokens": 16384,
+                            "max_prompt_tokens": 90000,
+                            "vision": {
+                                "max_prompt_image_size": 3145728,
+                                "max_prompt_images": 1,
+                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
+                            }
+                        },
+                        "object": "model_capabilities",
+                        "supports": {
+                            "parallel_tool_calls": true,
+                            "streaming": true,
+                            "tool_calls": true,
+                            "vision": true
+                        },
+                        "tokenizer": "o200k_base",
+                        "type": "chat"
+                    },
+                    "id": "claude-3.7-sonnet",
+                    "model_picker_enabled": true,
+                    "name": "Claude 3.7 Sonnet",
+                    "object": "model",
+                    "policy": {
+                        "state": "enabled",
+                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
+                    },
+                    "preview": false,
+                    "vendor": "Anthropic",
+                    "version": "claude-3.7-sonnet"
+                }
+            ],
+            "object": "list"
+        }"#;
+
+        let schema: ModelSchema = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(schema.data.len(), 2);
+        assert_eq!(schema.data[0].id, "gpt-4");
+        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
+    }
+}

Cargo.toml (language_models crate)

@@ -15,13 +15,15 @@ path = "src/language_models.rs"
 anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 aws-config = { workspace = true, features = ["behavior-version-latest"] }
-aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] }
+aws-credential-types = { workspace = true, features = [
+    "hardcoded-credentials",
+] }
 aws_http_client.workspace = true
 bedrock.workspace = true
 client.workspace = true
 collections.workspace = true
 credentials_provider.workspace = true
-copilot = { workspace = true, features = ["schemars"] }
+copilot.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 editor.workspace = true
 feature_flags.workspace = true

copilot_chat.rs (language_models provider)

@@ -5,8 +5,8 @@ use std::sync::Arc;
 use anyhow::{Result, anyhow};
 use collections::HashMap;
 use copilot::copilot_chat::{
-    ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
-    ResponseEvent, Tool, ToolCall,
+    ChatMessage, CopilotChat, Model as CopilotChatModel, ModelVendor,
+    Request as CopilotChatRequest, ResponseEvent, Tool, ToolCall,
 };
 use copilot::{Copilot, Status};
 use futures::future::BoxFuture;
@@ -20,12 +20,11 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
+    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
 };
 use settings::SettingsStore;
 use std::time::Duration;
-use strum::IntoEnumIterator;
 use ui::prelude::*;

 use super::anthropic::count_anthropic_tokens;
@@ -100,17 +99,26 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         IconName::Copilot
     }

-    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(CopilotChatModel::default()))
+    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
+        models
+            .first()
+            .map(|model| self.create_language_model(model.clone()))
     }

-    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(CopilotChatModel::default_fast()))
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
+        // model (e.g. 4o) and a sensible choice when considering premium requests
+        self.default_model(cx)
     }

-    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        CopilotChatModel::iter()
-            .map(|model| self.create_language_model(model))
+    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
+            return Vec::new();
+        };
+        models
+            .iter()
+            .map(|model| self.create_language_model(model.clone()))
             .collect()
     }
@@ -187,13 +195,15 @@ impl LanguageModel for CopilotChatLanguageModel {
     }

     fn supports_tools(&self) -> bool {
-        match self.model {
-            CopilotChatModel::Gpt4o
-            | CopilotChatModel::Gpt4_1
-            | CopilotChatModel::O4Mini
-            | CopilotChatModel::Claude3_5Sonnet
-            | CopilotChatModel::Claude3_7Sonnet => true,
-            _ => false,
+        self.model.supports_tools()
+    }
+
+    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+        match self.model.vendor() {
+            ModelVendor::OpenAI | ModelVendor::Anthropic => {
+                LanguageModelToolSchemaFormat::JsonSchema
+            }
+            ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
         }
     }
@@ -218,25 +228,13 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
-        match self.model {
-            CopilotChatModel::Claude3_5Sonnet
-            | CopilotChatModel::Claude3_7Sonnet
-            | CopilotChatModel::Claude3_7SonnetThinking => count_anthropic_tokens(request, cx),
-            CopilotChatModel::Gemini20Flash | CopilotChatModel::Gemini25Pro => {
-                count_google_tokens(request, cx)
-            }
-            CopilotChatModel::Gpt4o => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
-            CopilotChatModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
-            CopilotChatModel::Gpt4_1 => {
-                count_open_ai_tokens(request, open_ai::Model::FourPointOne, cx)
-            }
-            CopilotChatModel::Gpt3_5Turbo => {
-                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
-            }
-            CopilotChatModel::O1 => count_open_ai_tokens(request, open_ai::Model::O1, cx),
-            CopilotChatModel::O3Mini => count_open_ai_tokens(request, open_ai::Model::O3Mini, cx),
-            CopilotChatModel::O3 => count_open_ai_tokens(request, open_ai::Model::O3, cx),
-            CopilotChatModel::O4Mini => count_open_ai_tokens(request, open_ai::Model::O4Mini, cx),
+        match self.model.vendor() {
+            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
+            ModelVendor::Google => count_google_tokens(request, cx),
+            ModelVendor::OpenAI => {
+                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
+                count_open_ai_tokens(request, model, cx)
+            }
         }
     }
@@ -430,8 +428,6 @@ impl CopilotChatLanguageModel {
         &self,
         request: LanguageModelRequest,
     ) -> Result<CopilotChatRequest> {
-        let model = self.model.clone();
-
         let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
         for message in request.messages {
             if let Some(last_message) = request_messages.last_mut() {
@@ -545,9 +541,9 @@ impl CopilotChatLanguageModel {
         Ok(CopilotChatRequest {
             intent: true,
             n: 1,
-            stream: model.uses_streaming(),
+            stream: self.model.uses_streaming(),
             temperature: 0.1,
-            model,
+            model: self.model.id().to_string(),
             messages,
             tools,
             tool_choice: request.tool_choice.map(|choice| match choice {