copilot: Add support for new models (#19968)

Closes #19963

This PR implements integration with the newly announced GitHub Copilot
LLM models:
- Claude 3.5 Sonnet
- o1-mini
- o1-preview

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Authored by Jonathan Toledo on 2024-11-04 04:55:20 -05:00; committed by GitHub
parent 070e5914c9
commit 67be6ec3b5
2 changed files with 79 additions and 26 deletions

Changed file 1 of 2: the Copilot Chat client (Model, request/response types, stream_completion)

@@ -35,14 +35,30 @@ pub enum Model {
     Gpt4,
     #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
     Gpt3_5Turbo,
+    #[serde(alias = "o1-preview", rename = "o1-preview-2024-09-12")]
+    O1Preview,
+    #[serde(alias = "o1-mini", rename = "o1-mini-2024-09-12")]
+    O1Mini,
+    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
+    Claude3_5Sonnet,
 }
 
 impl Model {
+    pub fn uses_streaming(&self) -> bool {
+        match self {
+            Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true,
+            Self::O1Mini | Self::O1Preview => false,
+        }
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-4o" => Ok(Self::Gpt4o),
             "gpt-4" => Ok(Self::Gpt4),
             "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
+            "o1-preview" => Ok(Self::O1Preview),
+            "o1-mini" => Ok(Self::O1Mini),
+            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
             _ => Err(anyhow!("Invalid model id: {}", id)),
         }
     }
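A note on the serde attributes above: `rename` sets the wire name used for both serialization and deserialization, while `alias` merely adds an extra accepted spelling on deserialization. So the short ids keep working as input, but serialized requests carry the dated identifiers. A minimal standalone sketch of that behavior (assuming `serde` and `serde_json`, with a pared-down enum rather than the real one):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, PartialEq)]
enum Model {
    #[serde(alias = "o1-preview", rename = "o1-preview-2024-09-12")]
    O1Preview,
}

fn main() {
    // Both the short alias and the dated wire name deserialize to the same variant.
    let a: Model = serde_json::from_str("\"o1-preview\"").unwrap();
    let b: Model = serde_json::from_str("\"o1-preview-2024-09-12\"").unwrap();
    assert_eq!(a, b);

    // Serialization always emits the dated wire name.
    assert_eq!(
        serde_json::to_string(&a).unwrap(),
        "\"o1-preview-2024-09-12\""
    );
}
```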
@@ -52,6 +68,9 @@ impl Model {
             Self::Gpt3_5Turbo => "gpt-3.5-turbo",
             Self::Gpt4 => "gpt-4",
             Self::Gpt4o => "gpt-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
         }
     }
@@ -60,6 +79,9 @@ impl Model {
             Self::Gpt3_5Turbo => "GPT-3.5",
             Self::Gpt4 => "GPT-4",
             Self::Gpt4o => "GPT-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
         }
     }
@@ -68,6 +90,9 @@ impl Model {
             Self::Gpt4o => 128000,
             Self::Gpt4 => 8192,
             Self::Gpt3_5Turbo => 16385,
+            Self::O1Mini => 128000,
+            Self::O1Preview => 128000,
+            Self::Claude3_5Sonnet => 200_000,
         }
     }
 }
@@ -87,7 +112,7 @@ impl Request {
         Self {
             intent: true,
             n: 1,
-            stream: true,
+            stream: model.uses_streaming(),
             temperature: 0.1,
             model,
             messages,
@@ -113,7 +138,8 @@ pub struct ResponseEvent {
 pub struct ResponseChoice {
     pub index: usize,
     pub finish_reason: Option<String>,
-    pub delta: ResponseDelta,
+    pub delta: Option<ResponseDelta>,
+    pub message: Option<ResponseDelta>,
 }
 
 #[derive(Debug, Deserialize)]
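Both fields become `Option` because the two response shapes differ: streaming chunks carry `delta`, while a complete non-streaming reply carries `message`. A small self-contained sketch of that split, with stand-in structs mirroring the ones above and a hypothetical `extract_content` helper (assuming `serde` and `serde_json`):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ResponseDelta {
    content: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ResponseChoice {
    delta: Option<ResponseDelta>,
    message: Option<ResponseDelta>,
}

// Hypothetical helper, not part of the diff: pick whichever field the
// response shape actually populates.
fn extract_content(choice: &ResponseChoice, is_streaming: bool) -> Option<String> {
    let part = if is_streaming {
        choice.delta.as_ref()
    } else {
        choice.message.as_ref()
    };
    part.and_then(|p| p.content.clone())
}

fn main() {
    let chunk: ResponseChoice =
        serde_json::from_str(r#"{"delta": {"content": "Hel"}}"#).unwrap();
    assert_eq!(extract_content(&chunk, true).as_deref(), Some("Hel"));

    let whole: ResponseChoice =
        serde_json::from_str(r#"{"message": {"content": "Hello!"}}"#).unwrap();
    assert_eq!(extract_content(&whole, false).as_deref(), Some("Hello!"));
}
```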
@@ -333,9 +359,23 @@ async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.read_timeout(low_speed_timeout);
     }
+    let is_streaming = request.stream;
+
     let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
-    if response.status().is_success() {
+
+    if !response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        return Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ));
+    }
+
+    if is_streaming {
         let reader = BufReader::new(response.into_body());
         Ok(reader
             .lines()
@@ -367,19 +407,9 @@ async fn stream_completion(
     } else {
         let mut body = Vec::new();
         response.body_mut().read_to_end(&mut body).await?;
         let body_str = std::str::from_utf8(&body)?;
-        match serde_json::from_str::<ResponseEvent>(body_str) {
-            Ok(_) => Err(anyhow!(
-                "Unexpected success response while expecting an error: {}",
-                body_str,
-            )),
-            Err(_) => Err(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            )),
-        }
+        let response: ResponseEvent = serde_json::from_str(body_str)?;
+        Ok(futures::stream::once(async move { Ok(response) }).boxed())
     }
 }
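The non-streaming branch keeps the function's return type unchanged by lifting the single parsed response into a one-item stream. A minimal sketch of that trick in isolation (assuming the `futures` and `anyhow` crates; `single_item_stream` is a hypothetical stand-in, not the real function):

```rust
use futures::executor::block_on;
use futures::stream::{self, BoxStream, StreamExt};

// One already-complete response, lifted into the same BoxStream type the
// streaming branch returns.
fn single_item_stream(response: String) -> BoxStream<'static, anyhow::Result<String>> {
    stream::once(async move { Ok(response) }).boxed()
}

fn main() {
    let mut stream = single_item_stream("whole response".to_string());
    block_on(async {
        // Yields the single item, then the stream ends.
        while let Some(item) = stream.next().await {
            println!("{}", item.unwrap());
        }
    });
}
```

Callers can then consume streaming and non-streaming completions through one code path.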

Changed file 2 of 2: the Copilot Chat language model provider

@@ -30,6 +30,7 @@ use crate::{
 };
 use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
+use super::anthropic::count_anthropic_tokens;
 use super::open_ai::count_open_ai_tokens;
 
 const PROVIDER_ID: &str = "copilot_chat";
@@ -179,14 +180,20 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
+            _ => {
                 let model = match self.model {
                     CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
                     CopilotChatModel::Gpt4 => open_ai::Model::Four,
                     CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
+                    CopilotChatModel::O1Preview | CopilotChatModel::O1Mini => open_ai::Model::Four,
+                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
                 };
                 count_open_ai_tokens(request, model, cx)
+            }
+        }
     }
 
     fn stream_completion(
         &self,
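Two details worth calling out in the hunk above: the outer match peels Claude off first, which is why the inner OpenAI mapping can mark that arm `unreachable!()`, and the o1 models borrow GPT-4's tokenizer as an approximation since no dedicated o1 counter is wired up. A self-contained sketch of the two-level dispatch (the counting functions are crude stand-ins, not Zed's real ones):

```rust
enum Model {
    Gpt4o,
    O1Mini,
    Claude3_5Sonnet,
}

// Stand-in counters; roughly four bytes per token is a common approximation.
fn count_anthropic_tokens(text: &str) -> usize {
    text.len() / 4
}
fn count_open_ai_tokens(text: &str) -> usize {
    text.len() / 4
}

fn count_tokens(model: Model, text: &str) -> usize {
    match model {
        Model::Claude3_5Sonnet => count_anthropic_tokens(text),
        _ => match model {
            // Claude was already handled by the outer match.
            Model::Claude3_5Sonnet => unreachable!(),
            Model::Gpt4o | Model::O1Mini => count_open_ai_tokens(text),
        },
    }
}

fn main() {
    assert_eq!(count_tokens(Model::O1Mini, "hello world!"), 3);
}
```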
@@ -209,7 +216,8 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }
 
-        let request = self.to_copilot_chat_request(request);
+        let copilot_request = self.to_copilot_chat_request(request);
+        let is_streaming = copilot_request.stream;
 
         let Ok(low_speed_timeout) = cx.update(|cx| {
             AllLanguageModelSettings::get_global(cx)
                 .copilot_chat
@@ -220,16 +228,31 @@ impl LanguageModel for CopilotChatLanguageModel {
         let request_limiter = self.request_limiter.clone();
 
         let future = cx.spawn(|cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            let response = CopilotChat::stream_completion(copilot_request, low_speed_timeout, cx);
             request_limiter.stream(async move {
                 let response = response.await?;
                 let stream = response
-                    .filter_map(|response| async move {
+                    .filter_map(move |response| async move {
                         match response {
                             Ok(result) => {
                                 let choice = result.choices.first();
                                 match choice {
-                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    Some(choice) if !is_streaming => {
+                                        match &choice.message {
+                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no message content"
+                                            ))),
+                                        }
+                                    },
+                                    Some(choice) => {
+                                        match &choice.delta {
+                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no delta content"
+                                            ))),
+                                        }
+                                    },
                                     None => Some(Err(anyhow::anyhow!(
                                         "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
                                     ))),
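Note the closure is now `move` so it can capture `is_streaming` by value; each event is then routed to `message` or `delta` accordingly. The `filter_map` shape itself stays simple: every incoming event maps to `Some(..)`, so the output stream yields exactly one `Result` per input event. A standalone sketch of that shape (assuming the `futures` crate, with plain strings standing in for API events):

```rust
use futures::executor::block_on;
use futures::stream::{self, StreamExt};

fn main() {
    // Stand-ins for the API event stream: two content chunks and one error.
    let events = stream::iter(vec![Ok("Hel"), Ok("lo"), Err("boom")]);

    // Same shape as the provider code: every event maps to Some(item).
    let contents = events.filter_map(|event| async move {
        match event {
            Ok(chunk) => Some(Ok::<String, String>(chunk.to_string())),
            Err(e) => Some(Err(format!("API error: {e}"))),
        }
    });

    let collected: Vec<Result<String, String>> = block_on(contents.collect());
    assert_eq!(collected.len(), 3);
    assert!(collected[2].is_err());
}
```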