copilot: Add support for new models (#19968)
Closes #19963

This PR implements integration with the newly announced GitHub Copilot LLM models, including:

- Claude 3.5 Sonnet
- o1-mini
- o1-preview

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Parent: 070e5914c9
Commit: 67be6ec3b5
2 changed files with 79 additions and 26 deletions
The first of the two changed files extends the Copilot Chat model enum with the new variants and adds a per-model streaming capability check:

```diff
@@ -35,14 +35,30 @@ pub enum Model {
     Gpt4,
     #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
     Gpt3_5Turbo,
+    #[serde(alias = "o1-preview", rename = "o1-preview-2024-09-12")]
+    O1Preview,
+    #[serde(alias = "o1-mini", rename = "o1-mini-2024-09-12")]
+    O1Mini,
+    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
+    Claude3_5Sonnet,
 }

 impl Model {
+    pub fn uses_streaming(&self) -> bool {
+        match self {
+            Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true,
+            Self::O1Mini | Self::O1Preview => false,
+        }
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-4o" => Ok(Self::Gpt4o),
             "gpt-4" => Ok(Self::Gpt4),
             "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
+            "o1-preview" => Ok(Self::O1Preview),
+            "o1-mini" => Ok(Self::O1Mini),
+            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
             _ => Err(anyhow!("Invalid model id: {}", id)),
         }
     }
```
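Each new variant pairs a serde `alias` with a dated `rename`: the short id is what settings and `from_id` accept, while the dated id is what gets serialized toward the Copilot endpoint. A minimal standalone round-trip sketch of that behavior (not part of the PR; assumes the `serde` and `serde_json` crates):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
enum Model {
    // Alias accepts the short id on input; rename controls output.
    #[serde(alias = "o1-mini", rename = "o1-mini-2024-09-12")]
    O1Mini,
}

fn main() {
    // The short id deserializes via the alias...
    let m: Model = serde_json::from_str("\"o1-mini\"").unwrap();
    // ...but serialization always emits the dated id.
    assert_eq!(serde_json::to_string(&m).unwrap(), "\"o1-mini-2024-09-12\"");
}
```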
```diff
@@ -52,6 +68,9 @@ impl Model {
             Self::Gpt3_5Turbo => "gpt-3.5-turbo",
             Self::Gpt4 => "gpt-4",
             Self::Gpt4o => "gpt-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
         }
     }
@@ -60,6 +79,9 @@ impl Model {
             Self::Gpt3_5Turbo => "GPT-3.5",
             Self::Gpt4 => "GPT-4",
             Self::Gpt4o => "GPT-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
         }
     }
@@ -68,6 +90,9 @@ impl Model {
             Self::Gpt4o => 128000,
             Self::Gpt4 => 8192,
             Self::Gpt3_5Turbo => 16385,
+            Self::O1Mini => 128000,
+            Self::O1Preview => 128000,
+            Self::Claude3_5Sonnet => 200_000,
         }
     }
 }
@@ -87,7 +112,7 @@ impl Request {
         Self {
             intent: true,
             n: 1,
-            stream: true,
+            stream: model.uses_streaming(),
             temperature: 0.1,
             model,
             messages,
```
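The `stream` flag was previously hard-coded to `true`; deriving it from the model matters because the o1 models do not accept streaming requests. A pared-down sketch of the construction (illustrative only; the real `Request` carries more fields, e.g. `messages`):

```rust
#[derive(Debug, Clone, Copy)]
enum Model { Gpt4o, O1Mini }

impl Model {
    fn uses_streaming(&self) -> bool {
        matches!(self, Model::Gpt4o)
    }
}

#[derive(Debug)]
struct Request {
    intent: bool,
    n: usize,
    stream: bool,
    temperature: f32,
    model: Model,
}

fn new_request(model: Model) -> Request {
    Request {
        intent: true,
        n: 1,
        // o1-mini/o1-preview must be sent as non-streaming requests.
        stream: model.uses_streaming(),
        temperature: 0.1,
        model,
    }
}

fn main() {
    assert!(new_request(Model::Gpt4o).stream);
    assert!(!new_request(Model::O1Mini).stream);
}
```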
```diff
@@ -113,7 +138,8 @@ pub struct ResponseEvent {
 pub struct ResponseChoice {
     pub index: usize,
     pub finish_reason: Option<String>,
-    pub delta: ResponseDelta,
+    pub delta: Option<ResponseDelta>,
+    pub message: Option<ResponseDelta>,
 }

 #[derive(Debug, Deserialize)]
```
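Both fields become `Option` because the two response shapes differ: a streaming chunk carries its text in `delta`, while a whole non-streaming response carries it in `message`. A self-contained sketch of the two shapes deserializing (the JSON payloads are illustrative, not captured API output):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ResponseDelta {
    content: Option<String>,
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ResponseChoice {
    index: usize,
    finish_reason: Option<String>,
    delta: Option<ResponseDelta>,
    message: Option<ResponseDelta>,
}

fn main() {
    // A streaming chunk populates `delta`; `message` defaults to None.
    let chunk: ResponseChoice = serde_json::from_str(
        r#"{"index":0,"finish_reason":null,"delta":{"content":"Hel"}}"#,
    )
    .unwrap();
    assert!(chunk.delta.is_some() && chunk.message.is_none());

    // A non-streaming response populates `message` instead.
    let whole: ResponseChoice = serde_json::from_str(
        r#"{"index":0,"finish_reason":"stop","message":{"content":"Hello"}}"#,
    )
    .unwrap();
    assert!(whole.message.is_some() && whole.delta.is_none());
}
```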
```diff
@@ -333,9 +359,23 @@ async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.read_timeout(low_speed_timeout);
     }
+    let is_streaming = request.stream;
+
     let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
-    if response.status().is_success() {
+
+    if !response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        return Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ));
+    }
+
+    if is_streaming {
         let reader = BufReader::new(response.into_body());
         Ok(reader
             .lines()
@@ -367,19 +407,9 @@ async fn stream_completion(
     } else {
         let mut body = Vec::new();
         response.body_mut().read_to_end(&mut body).await?;

         let body_str = std::str::from_utf8(&body)?;
-        match serde_json::from_str::<ResponseEvent>(body_str) {
-            Ok(_) => Err(anyhow!(
-                "Unexpected success response while expecting an error: {}",
-                body_str,
-            )),
-            Err(_) => Err(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            )),
-        }
+        let response: ResponseEvent = serde_json::from_str(body_str)?;
+        Ok(futures::stream::once(async move { Ok(response) }).boxed())
     }
 }
```
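With streaming disabled, the function still has to hand callers the same stream type, so the single parsed response is wrapped in a one-item stream via `futures::stream::once`. A self-contained sketch of that pattern (the `tokio` runtime here is an assumption for the example; Zed uses its own executor):

```rust
use futures::stream::{self, BoxStream, StreamExt};

// Expose one already-complete value through the same boxed-stream
// interface that the streaming branch returns.
fn single_event_stream(event: String) -> BoxStream<'static, anyhow::Result<String>> {
    stream::once(async move { Ok(event) }).boxed()
}

#[tokio::main]
async fn main() {
    let mut s = single_event_stream("whole response".into());
    while let Some(item) = s.next().await {
        println!("{}", item.unwrap());
    }
}
```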
The second changed file wires the new models into the Copilot Chat language-model provider, dispatching token counting by model family:

```diff
@@ -30,6 +30,7 @@ use crate::{
 };
 use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};

+use super::anthropic::count_anthropic_tokens;
 use super::open_ai::count_open_ai_tokens;

 const PROVIDER_ID: &str = "copilot_chat";
@@ -179,14 +180,20 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
+            _ => {
                 let model = match self.model {
                     CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
                     CopilotChatModel::Gpt4 => open_ai::Model::Four,
                     CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
+                    CopilotChatModel::O1Preview | CopilotChatModel::O1Mini => open_ai::Model::Four,
+                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
                 };

                 count_open_ai_tokens(request, model, cx)
+            }
+        }
     }

     fn stream_completion(
         &self,
```
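Claude 3.5 Sonnet routes to the Anthropic counter, while the o1 models borrow GPT-4's tokenizer as an approximation (mirroring the `open_ai::Model::Four` arm above). A hedged standalone sketch of the same dispatch shape; `approx_anthropic_tokens` and `approx_open_ai_tokens` are illustrative stand-ins, not the crate's real `count_anthropic_tokens`/`count_open_ai_tokens`:

```rust
#[derive(Clone, Copy)]
enum ChatModel { Gpt4o, Gpt4, O1Mini, O1Preview, Claude3_5Sonnet }

fn approx_anthropic_tokens(text: &str) -> usize {
    // Rough stand-in: the real counter uses Anthropic's tokenizer.
    text.len() / 4
}

fn approx_open_ai_tokens(text: &str, _tokenizer: &str) -> usize {
    // Rough stand-in: the real counter uses tiktoken-style BPE.
    text.len() / 4
}

fn count_tokens(model: ChatModel, text: &str) -> usize {
    match model {
        // Claude routes to the Anthropic-style counter...
        ChatModel::Claude3_5Sonnet => approx_anthropic_tokens(text),
        // ...and the o1 models reuse GPT-4's tokenizer as an approximation.
        ChatModel::Gpt4 | ChatModel::O1Mini | ChatModel::O1Preview => {
            approx_open_ai_tokens(text, "gpt-4")
        }
        ChatModel::Gpt4o => approx_open_ai_tokens(text, "gpt-4o"),
    }
}

fn main() {
    println!("{}", count_tokens(ChatModel::O1Mini, "hello copilot"));
}
```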
```diff
@@ -209,7 +216,8 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }

-        let request = self.to_copilot_chat_request(request);
+        let copilot_request = self.to_copilot_chat_request(request);
+        let is_streaming = copilot_request.stream;
         let Ok(low_speed_timeout) = cx.update(|cx| {
             AllLanguageModelSettings::get_global(cx)
                 .copilot_chat
@@ -220,16 +228,31 @@ impl LanguageModel for CopilotChatLanguageModel {

         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(|cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            let response = CopilotChat::stream_completion(copilot_request, low_speed_timeout, cx);
             request_limiter.stream(async move {
                 let response = response.await?;
                 let stream = response
-                    .filter_map(|response| async move {
+                    .filter_map(move |response| async move {
                         match response {
                             Ok(result) => {
                                 let choice = result.choices.first();
                                 match choice {
-                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    Some(choice) if !is_streaming => {
+                                        match &choice.message {
+                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no message content"
+                                            ))),
+                                        }
+                                    },
+                                    Some(choice) => {
+                                        match &choice.delta {
+                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no delta content"
+                                            ))),
+                                        }
+                                    },
                                     None => Some(Err(anyhow::anyhow!(
                                         "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
                                     ))),
```
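The `filter_map` now branches on whether the request was streamed: non-streaming responses carry their text in `message`, streaming chunks in `delta`. That selection rule, isolated into a standalone sketch (struct definitions repeated here so it compiles on its own; assumes the `anyhow` crate, as the diff does):

```rust
use anyhow::anyhow;

struct ResponseDelta {
    content: Option<String>,
}

struct ResponseChoice {
    delta: Option<ResponseDelta>,
    message: Option<ResponseDelta>,
}

// Mirrors the new match arms: read `message` for whole responses,
// `delta` for streamed chunks, and error if the expected field is absent.
fn extract_text(choice: &ResponseChoice, is_streaming: bool) -> anyhow::Result<String> {
    let part = if is_streaming {
        choice.delta.as_ref()
    } else {
        choice.message.as_ref()
    };
    part.map(|d| d.content.clone().unwrap_or_default())
        .ok_or_else(|| anyhow!("response choice is missing the expected content field"))
}

fn main() {
    let whole = ResponseChoice {
        delta: None,
        message: Some(ResponseDelta { content: Some("hi".into()) }),
    };
    assert_eq!(extract_text(&whole, false).unwrap(), "hi");
}
```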