Add Qwen2-7B to the list of zed.dev models (#15649)
Release Notes:

- N/A

Co-authored-by: Nathan <nathan@zed.dev>
parent 60127f2a8d
commit 21816d1ff5

9 changed files with 112 additions and 2 deletions
@@ -1,5 +1,6 @@
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use strum::EnumIter;
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 #[serde(tag = "provider", rename_all = "lowercase")]
@@ -7,6 +8,33 @@ pub enum CloudModel {
     Anthropic(anthropic::Model),
     OpenAi(open_ai::Model),
     Google(google_ai::Model),
+    Zed(ZedModel),
 }
 
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
+pub enum ZedModel {
+    #[serde(rename = "qwen2-7b-instruct")]
+    Qwen2_7bInstruct,
+}
+
+impl ZedModel {
+    pub fn id(&self) -> &str {
+        match self {
+            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            ZedModel::Qwen2_7bInstruct => 8192,
+        }
+    }
+}
+
 impl Default for CloudModel {
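The EnumIter derive on ZedModel is what lets the provider code enumerate every Zed-hosted model when registering them (see the ZedModel::iter() loop later in this commit). A minimal sketch of that pattern, assuming strum's IntoEnumIterator trait is in scope; illustrative only, not part of the diff:

use strum::IntoEnumIterator;

fn list_zed_models() {
    // #[derive(EnumIter)] generates an IntoEnumIterator impl, so every variant can be walked.
    for model in ZedModel::iter() {
        println!("{} -> {} ({} tokens)", model.id(), model.display_name(), model.max_token_count());
    }
}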
@@ -21,6 +49,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.id(),
             CloudModel::OpenAi(model) => model.id(),
             CloudModel::Google(model) => model.id(),
+            CloudModel::Zed(model) => model.id(),
         }
     }
@@ -29,6 +58,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.display_name(),
             CloudModel::OpenAi(model) => model.display_name(),
             CloudModel::Google(model) => model.display_name(),
+            CloudModel::Zed(model) => model.display_name(),
         }
     }
@@ -37,6 +67,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.max_token_count(),
             CloudModel::OpenAi(model) => model.max_token_count(),
             CloudModel::Google(model) => model.max_token_count(),
+            CloudModel::Zed(model) => model.max_token_count(),
         }
     }
 }
@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::{Client, UserStore};
@@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 models.insert(model.id().to_string(), CloudModel::Google(model));
             }
         }
+        for model in ZedModel::iter() {
+            models.insert(model.id().to_string(), CloudModel::Zed(model));
+        }
 
         // Override with available models from settings
         for model in &AllLanguageModelSettings::get_global(cx)
@@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
                 }
                 .boxed()
             }
+            CloudModel::Zed(_) => {
+                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
+            }
         }
     }
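For the Zed-hosted model, token counts are approximated with the OpenAI GPT-3.5 Turbo tokenizer rather than a Qwen2-specific one. A rough standalone sketch of that kind of estimate, assuming the tiktoken-rs crate (count_open_ai_tokens in this crate may be implemented differently):

// Approximate a prompt's token count with the cl100k_base vocabulary (GPT-3.5/4);
// for Qwen2 this is only an estimate, not an exact count.
fn estimate_tokens(prompt: &str) -> anyhow::Result<usize> {
    let bpe = tiktoken_rs::cl100k_base()?;
    Ok(bpe.encode_with_special_tokens(prompt).len())
}

fn main() -> anyhow::Result<()> {
    println!("~{} tokens", estimate_tokens("Explain Rust's borrow checker in one paragraph.")?);
    Ok(())
}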
@@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
+            CloudModel::Zed(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_open_ai(model.id().into());
+                request.max_tokens = Some(4000);
+                let future = self.request_limiter.stream(async move {
+                    let request = serde_json::to_string(&request)?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Zed as i32,
+                            request,
+                        })
+                        .await?;
+                    Ok(open_ai::extract_text_from_events(
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
         }
     }
@@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
+            CloudModel::Zed(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
+            }
         }
     }
 }
@@ -37,6 +37,7 @@ impl LanguageModelRequest {
             stream: true,
             stop: self.stop,
             temperature: self.temperature,
+            max_tokens: None,
             tools: Vec::new(),
             tool_choice: None,
         }
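Taken together, the new variant plugs into the existing CloudModel plumbing. A small sketch of how calling code might exercise the new model once this lands, using only names defined in the hunks above (the snippet itself is not part of the commit):

fn describe_qwen2() {
    let model = CloudModel::Zed(ZedModel::Qwen2_7bInstruct);
    // Each call delegates to ZedModel via the match arms added in this commit.
    assert_eq!(model.id(), "qwen2-7b-instruct");
    assert_eq!(model.display_name(), "Qwen2 7B Instruct");
    assert_eq!(model.max_token_count(), 8192);
}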