Add Qwen2-7B to the list of zed.dev models (#15649)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-08-01 22:26:07 +02:00 committed by GitHub
parent 60127f2a8d
commit 21816d1ff5
9 changed files with 112 additions and 2 deletions

@@ -1,5 +1,6 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
@@ -7,6 +8,33 @@ pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
    Zed(ZedModel),
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    #[serde(rename = "qwen2-7b-instruct")]
    Qwen2_7bInstruct,
}

impl ZedModel {
    pub fn id(&self) -> &str {
        match self {
            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            ZedModel::Qwen2_7bInstruct => 8192,
        }
    }
}
impl Default for CloudModel {
@@ -21,6 +49,7 @@ impl CloudModel {
            CloudModel::Anthropic(model) => model.id(),
            CloudModel::OpenAi(model) => model.id(),
            CloudModel::Google(model) => model.id(),
            CloudModel::Zed(model) => model.id(),
        }
    }
@@ -29,6 +58,7 @@ impl CloudModel {
            CloudModel::Anthropic(model) => model.display_name(),
            CloudModel::OpenAi(model) => model.display_name(),
            CloudModel::Google(model) => model.display_name(),
            CloudModel::Zed(model) => model.display_name(),
        }
    }
@@ -37,6 +67,7 @@ impl CloudModel {
            CloudModel::Anthropic(model) => model.max_token_count(),
            CloudModel::OpenAi(model) => model.max_token_count(),
            CloudModel::Google(model) => model.max_token_count(),
            CloudModel::Zed(model) => model.max_token_count(),
        }
    }
}
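
For context, the `ZedModel::iter()` call used by the provider below comes from strum's `EnumIter` derive (hence the new `use strum::EnumIter;` import). A minimal standalone sketch of that pattern, assuming the `strum` crate with its `derive` feature; the enum and `main` here are illustrative only, not Zed's actual crate layout:

```rust
use strum::{EnumIter, IntoEnumIterator};

#[derive(Clone, Debug, PartialEq, EnumIter)]
enum ZedModel {
    Qwen2_7bInstruct,
}

impl ZedModel {
    fn id(&self) -> &str {
        match self {
            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
        }
    }
}

fn main() {
    // `derive(EnumIter)` is what lets the provider enumerate every variant
    // and register each one, e.g. as `CloudModel::Zed(model)`.
    for model in ZedModel::iter() {
        println!("registering {}", model.id());
    }
}
```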

@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use client::{Client, UserStore};
@@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                models.insert(model.id().to_string(), CloudModel::Google(model));
            }
        }
        for model in ZedModel::iter() {
            models.insert(model.id().to_string(), CloudModel::Zed(model));
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
@@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }
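
Note on the arm above: token counts for the Zed-hosted Qwen2 model are approximated with the OpenAI `gpt-3.5-turbo` tokenizer via `count_open_ai_tokens`. A rough sketch of that kind of approximation using the `tiktoken-rs` crate; that this is the underlying tokenizer library is an assumption, and the helper and prompt below are illustrative, not Zed's actual code:

```rust
// Approximate token counting with an OpenAI BPE vocabulary (cl100k_base is
// the vocabulary used by gpt-3.5-turbo); counts for Qwen2 are only estimates.
fn approximate_token_count(text: &str) -> anyhow::Result<usize> {
    let bpe = tiktoken_rs::cl100k_base()?;
    Ok(bpe.encode_with_special_tokens(text).len())
}

fn main() -> anyhow::Result<()> {
    let prompt = "Explain the borrow checker in one paragraph.";
    println!("~{} tokens", approximate_token_count(prompt)?);
    Ok(())
}
```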
@@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Zed as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
@@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
            }
        }
    }
}

@@ -37,6 +37,7 @@ impl LanguageModelRequest {
            stream: true,
            stop: self.stop,
            temperature: self.temperature,
            max_tokens: None,
            tools: Vec::new(),
            tool_choice: None,
        }
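
The new `max_tokens: None` default leaves the generic OpenAI-style request uncapped, while the Zed provider path above overrides it to `Some(4000)` before sending. A small sketch of how such an optional field typically behaves under serde; the struct and its `skip_serializing_if` attribute are assumptions for illustration, not the actual `open_ai::Request` definition:

```rust
use serde::Serialize;

// Illustrative stand-in for an OpenAI-style request body; the field set is trimmed.
#[derive(Serialize)]
struct Request {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
}

fn main() -> serde_json::Result<()> {
    let default = Request { model: "qwen2-7b-instruct".into(), max_tokens: None };
    let capped = Request { model: "qwen2-7b-instruct".into(), max_tokens: Some(4000) };
    // With None the cap is simply omitted from the JSON; the Zed arm in this
    // commit sets Some(4000) before the request is serialized and streamed.
    println!("{}", serde_json::to_string(&default)?);
    println!("{}", serde_json::to_string(&capped)?);
    Ok(())
}
```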