language_models: Dynamically detect Copilot Chat models (#29027)
I noticed the discussion in #28881 and had the same idea a few days earlier. This implementation should preserve existing functionality fairly well. I've added a dependency (serde_with) so that the deserializer can skip models which cannot be deserialized, which could happen if, for instance, a new provider is added in the future. Without this modification, such a change could break all models. If extra dependencies aren't desired, a manual implementation could be used instead.

- Closes #29369

Release Notes:

- Dynamically detect available Copilot Chat models, including all models with tool support

---------

Co-authored-by: AidanV <aidanvanduyne@gmail.com>
Co-authored-by: imumesh18 <umesh4257@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
parent
634b275931
commit
f14e48d202
5 changed files with 320 additions and 182 deletions
|
@ -14,7 +14,6 @@ doctest = false
|
|||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
test-support = [
|
||||
"collections/test-support",
|
||||
"gpui/test-support",
|
||||
|
@ -43,16 +42,15 @@ node_runtime.workspace = true
|
|||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
strum.workspace = true
|
||||
task.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
itertools.workspace = true
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
async-std = { version = "1.12.0", features = ["unstable"] }
|
||||
|
|
|
@ -9,13 +9,20 @@ use fs::Fs;
|
|||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use gpui::{App, AsyncApp, Global, prelude::*};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use itertools::Itertools;
|
||||
use paths::home_dir;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::watch_config_dir;
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
|
||||
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
pub const COPILOT_CHAT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
|
||||
|
||||
// Copilot's base model; defined by Microsoft in premium requests table
|
||||
// This will be moved to the front of the Copilot model list, and will be used for
|
||||
// 'fast' requests (e.g. title generation)
|
||||
// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
|
||||
const DEFAULT_MODEL_ID: &str = "gpt-4.1";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
@ -25,132 +32,104 @@ pub enum Role {
|
|||
System,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
|
||||
Gpt4o,
|
||||
#[serde(alias = "gpt-4", rename = "gpt-4")]
|
||||
Gpt4,
|
||||
#[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
|
||||
Gpt4_1,
|
||||
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
|
||||
Gpt3_5Turbo,
|
||||
#[serde(alias = "o1", rename = "o1")]
|
||||
O1,
|
||||
#[serde(alias = "o1-mini", rename = "o3-mini")]
|
||||
O3Mini,
|
||||
#[serde(alias = "o3", rename = "o3")]
|
||||
O3,
|
||||
#[serde(alias = "o4-mini", rename = "o4-mini")]
|
||||
O4Mini,
|
||||
#[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
|
||||
Claude3_7Sonnet,
|
||||
#[serde(
|
||||
alias = "claude-3.7-sonnet-thought",
|
||||
rename = "claude-3.7-sonnet-thought"
|
||||
)]
|
||||
Claude3_7SonnetThinking,
|
||||
#[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
|
||||
Gemini20Flash,
|
||||
#[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
|
||||
Gemini25Pro,
|
||||
#[derive(Deserialize)]
|
||||
struct ModelSchema {
|
||||
#[serde(deserialize_with = "deserialize_models_skip_errors")]
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
|
||||
let models = raw_values
|
||||
.into_iter()
|
||||
.filter_map(|value| match serde_json::from_value::<Model>(value) {
|
||||
Ok(model) => Some(model),
|
||||
Err(err) => {
|
||||
log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct Model {
|
||||
capabilities: ModelCapabilities,
|
||||
id: String,
|
||||
name: String,
|
||||
policy: Option<ModelPolicy>,
|
||||
vendor: ModelVendor,
|
||||
model_picker_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelCapabilities {
|
||||
family: String,
|
||||
#[serde(default)]
|
||||
limits: ModelLimits,
|
||||
supports: ModelSupportedFeatures,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelLimits {
|
||||
#[serde(default)]
|
||||
max_context_window_tokens: usize,
|
||||
#[serde(default)]
|
||||
max_output_tokens: usize,
|
||||
#[serde(default)]
|
||||
max_prompt_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelPolicy {
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelSupportedFeatures {
|
||||
#[serde(default)]
|
||||
streaming: bool,
|
||||
#[serde(default)]
|
||||
tool_calls: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub enum ModelVendor {
|
||||
// Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
|
||||
#[serde(alias = "Azure OpenAI")]
|
||||
OpenAI,
|
||||
Google,
|
||||
Anthropic,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn default_fast() -> Self {
|
||||
Self::Claude3_7Sonnet
|
||||
}
|
||||
|
||||
pub fn uses_streaming(&self) -> bool {
|
||||
match self {
|
||||
Self::Gpt4o
|
||||
| Self::Gpt4
|
||||
| Self::Gpt4_1
|
||||
| Self::Gpt3_5Turbo
|
||||
| Self::O3
|
||||
| Self::O4Mini
|
||||
| Self::Claude3_5Sonnet
|
||||
| Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking => true,
|
||||
Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
|
||||
}
|
||||
self.capabilities.supports.streaming
|
||||
}
|
||||
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
match id {
|
||||
"gpt-4o" => Ok(Self::Gpt4o),
|
||||
"gpt-4" => Ok(Self::Gpt4),
|
||||
"gpt-4.1" => Ok(Self::Gpt4_1),
|
||||
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
|
||||
"o1" => Ok(Self::O1),
|
||||
"o3-mini" => Ok(Self::O3Mini),
|
||||
"o3" => Ok(Self::O3),
|
||||
"o4-mini" => Ok(Self::O4Mini),
|
||||
"claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
|
||||
"claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
|
||||
"claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
|
||||
"gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
|
||||
"gemini-2.5-pro" => Ok(Self::Gemini25Pro),
|
||||
_ => Err(anyhow!("Invalid model id: {}", id)),
|
||||
}
|
||||
pub fn id(&self) -> &str {
|
||||
self.id.as_str()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4_1 => "gpt-4.1",
|
||||
Self::Gpt4o => "gpt-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
Self::O3 => "o3",
|
||||
Self::O4Mini => "o4-mini",
|
||||
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
|
||||
Self::Claude3_7Sonnet => "claude-3-7-sonnet",
|
||||
Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
|
||||
Self::Gemini20Flash => "gemini-2.0-flash-001",
|
||||
Self::Gemini25Pro => "gemini-2.5-pro",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "GPT-3.5",
|
||||
Self::Gpt4 => "GPT-4",
|
||||
Self::Gpt4_1 => "GPT-4.1",
|
||||
Self::Gpt4o => "GPT-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
Self::O3 => "o3",
|
||||
Self::O4Mini => "o4-mini",
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
|
||||
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
|
||||
Self::Gemini20Flash => "Gemini 2.0 Flash",
|
||||
Self::Gemini25Pro => "Gemini 2.5 Pro",
|
||||
}
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Gpt4o => 64_000,
|
||||
Self::Gpt4 => 32_768,
|
||||
Self::Gpt4_1 => 128_000,
|
||||
Self::Gpt3_5Turbo => 12_288,
|
||||
Self::O3Mini => 64_000,
|
||||
Self::O1 => 20_000,
|
||||
Self::O3 => 128_000,
|
||||
Self::O4Mini => 128_000,
|
||||
Self::Claude3_5Sonnet => 200_000,
|
||||
Self::Claude3_7Sonnet => 90_000,
|
||||
Self::Claude3_7SonnetThinking => 90_000,
|
||||
Self::Gemini20Flash => 128_000,
|
||||
Self::Gemini25Pro => 128_000,
|
||||
}
|
||||
self.capabilities.limits.max_prompt_tokens
|
||||
}
|
||||
|
||||
pub fn supports_tools(&self) -> bool {
|
||||
self.capabilities.supports.tool_calls
|
||||
}
|
||||
|
||||
pub fn vendor(&self) -> ModelVendor {
|
||||
self.vendor
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,7 +139,7 @@ pub struct Request {
|
|||
pub n: usize,
|
||||
pub stream: bool,
|
||||
pub temperature: f32,
|
||||
pub model: Model,
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<Tool>,
|
||||
|
@ -306,6 +285,7 @@ impl Global for GlobalCopilotChat {}
|
|||
pub struct CopilotChat {
|
||||
oauth_token: Option<String>,
|
||||
api_token: Option<ApiToken>,
|
||||
models: Option<Vec<Model>>,
|
||||
client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
|
@ -342,31 +322,56 @@ impl CopilotChat {
|
|||
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
|
||||
let dir_path = copilot_chat_config_dir();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let mut parent_watch_rx = watch_config_dir(
|
||||
cx.background_executor(),
|
||||
fs.clone(),
|
||||
dir_path.clone(),
|
||||
config_paths,
|
||||
);
|
||||
while let Some(contents) = parent_watch_rx.next().await {
|
||||
let oauth_token = extract_oauth_token(contents);
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token;
|
||||
cx.notify();
|
||||
});
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| {
|
||||
let mut parent_watch_rx = watch_config_dir(
|
||||
cx.background_executor(),
|
||||
fs.clone(),
|
||||
dir_path.clone(),
|
||||
config_paths,
|
||||
);
|
||||
while let Some(contents) = parent_watch_rx.next().await {
|
||||
let oauth_token = extract_oauth_token(contents);
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token.clone();
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
if let Some(ref oauth_token) = oauth_token {
|
||||
let api_token = request_api_token(oauth_token, client.clone()).await?;
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_token = Some(api_token.clone());
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
let models = get_models(api_token.api_key, client.clone()).await?;
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.models = Some(models);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
}
|
||||
})?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
oauth_token: None,
|
||||
api_token: None,
|
||||
models: None,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
@ -375,6 +380,10 @@ impl CopilotChat {
|
|||
self.oauth_token.is_some()
|
||||
}
|
||||
|
||||
/// Returns the models fetched so far as a slice, or `None` when no model
/// list has been stored yet.
pub fn models(&self) -> Option<&[Model]> {
    self.models.as_ref().map(Vec::as_slice)
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
request: Request,
|
||||
mut cx: AsyncApp,
|
||||
|
@ -409,6 +418,61 @@ impl CopilotChat {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
|
||||
let all_models = request_models(api_token, client).await?;
|
||||
|
||||
let mut models: Vec<Model> = all_models
|
||||
.into_iter()
|
||||
.filter(|model| {
|
||||
// Ensure user has access to the model; Policy is present only for models that must be
|
||||
// enabled in the GitHub dashboard
|
||||
model.model_picker_enabled
|
||||
&& model
|
||||
.policy
|
||||
.as_ref()
|
||||
.is_none_or(|policy| policy.state == "enabled")
|
||||
})
|
||||
// The first model from the API response, in any given family, appear to be the non-tagged
|
||||
// models, which are likely the best choice (e.g. gpt-4o rather than gpt-4o-2024-11-20)
|
||||
.dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
|
||||
.collect();
|
||||
|
||||
if let Some(default_model_position) =
|
||||
models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
|
||||
{
|
||||
let default_model = models.remove(default_model_position);
|
||||
models.insert(0, default_model);
|
||||
}
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
async fn request_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(COPILOT_CHAT_MODELS_URL)
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Copilot-Integration-Id", "vscode-chat");
|
||||
|
||||
let request = request_builder.body(AsyncBody::empty())?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
|
||||
|
||||
Ok(models)
|
||||
} else {
|
||||
Err(anyhow!("Failed to request models: {}", response.status()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
|
@ -527,3 +591,82 @@ async fn stream_completion(
|
|||
Ok(futures::stream::once(async move { Ok(response) }).boxed())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A response that contains one entry which cannot be decoded as a `Model`
    /// must still yield the remaining, well-formed entries in their original
    /// order.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
              "data": [
                {
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        // `json` is already a `&str`; no extra borrow is needed.
        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The entry consisting only of an unknown field is skipped.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");

        // "Azure OpenAI" is accepted as an alias of the OpenAI vendor.
        assert_eq!(schema.data[0].vendor, ModelVendor::OpenAI);
        assert_eq!(schema.data[1].vendor, ModelVendor::Anthropic);

        // Unknown nested fields (e.g. `limits.vision`) do not disturb the known ones.
        assert_eq!(schema.data[1].capabilities.limits.max_prompt_tokens, 90000);
        assert!(schema.data[1].capabilities.supports.tool_calls);
    }
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue